Repository: Wan-Video/Wan2.2
Branch: main
Commit: 42bf4cfaa384
Files: 69
Total size: 748.5 KB
Directory structure:
gitextract_8s745wnd/
├── .gitignore
├── INSTALL.md
├── LICENSE.txt
├── Makefile
├── README.md
├── generate.py
├── pyproject.toml
├── requirements.txt
├── requirements_animate.txt
├── requirements_s2v.txt
├── tests/
│   ├── README.md
│   └── test.sh
└── wan/
    ├── __init__.py
    ├── animate.py
    ├── configs/
    │   ├── __init__.py
    │   ├── shared_config.py
    │   ├── wan_animate_14B.py
    │   ├── wan_i2v_A14B.py
    │   ├── wan_s2v_14B.py
    │   ├── wan_t2v_A14B.py
    │   └── wan_ti2v_5B.py
    ├── distributed/
    │   ├── __init__.py
    │   ├── fsdp.py
    │   ├── sequence_parallel.py
    │   ├── ulysses.py
    │   └── util.py
    ├── image2video.py
    ├── modules/
    │   ├── __init__.py
    │   ├── animate/
    │   │   ├── __init__.py
    │   │   ├── animate_utils.py
    │   │   ├── clip.py
    │   │   ├── face_blocks.py
    │   │   ├── model_animate.py
    │   │   ├── motion_encoder.py
    │   │   ├── preprocess/
    │   │   │   ├── UserGuider.md
    │   │   │   ├── __init__.py
    │   │   │   ├── human_visualization.py
    │   │   │   ├── pose2d.py
    │   │   │   ├── pose2d_utils.py
    │   │   │   ├── preprocess_data.py
    │   │   │   ├── process_pipepline.py
    │   │   │   ├── retarget_pose.py
    │   │   │   ├── sam_utils.py
    │   │   │   ├── utils.py
    │   │   │   └── video_predictor.py
    │   │   └── xlm_roberta.py
    │   ├── attention.py
    │   ├── model.py
    │   ├── s2v/
    │   │   ├── __init__.py
    │   │   ├── audio_encoder.py
    │   │   ├── audio_utils.py
    │   │   ├── auxi_blocks.py
    │   │   ├── model_s2v.py
    │   │   ├── motioner.py
    │   │   └── s2v_utils.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── vae2_1.py
    │   └── vae2_2.py
    ├── speech2video.py
    ├── text2video.py
    ├── textimage2video.py
    └── utils/
        ├── __init__.py
        ├── fm_solvers.py
        ├── fm_solvers_unipc.py
        ├── prompt_extend.py
        ├── qwen_vl_utils.py
        ├── system_prompt.py
        └── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
.DS_Store
.vscode*
tmp_examples*
new_checkpoint*
batch_test*
nohup*
================================================
FILE: INSTALL.md
================================================
# Installation Guide
## Install with pip
```bash
pip install .
pip install .[dev] # Also installs the dev tools
```
## Install with Poetry
Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
To install all dependencies:
```bash
poetry install
```
### Handling `flash-attn` Installation Issues
If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
#### No-Build-Isolation Installation (Recommended)
```bash
poetry run pip install --upgrade pip setuptools wheel
poetry run pip install flash-attn --no-build-isolation
poetry install
```
#### Install from Git (Alternative)
```bash
poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
```
---
### Running the Model
Once the installation is complete, you can run **Wan2.2** using:
```bash
poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
#### Test
```bash
bash tests/test.sh
```
#### Format
```bash
black .
isort .
```
================================================
FILE: LICENSE.txt
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
.PHONY: format
format:
	isort generate.py wan
	yapf -i -r *.py generate.py wan
================================================
FILE: README.md
================================================
# Wan2.2
<p align="center">
<img src="assets/logo.png" width="400"/>
</p>
<p align="center">
💜 <a href="https://wan.video"><b>Wan</b></a>    |    🖥️ <a href="https://github.com/Wan-Video/Wan2.2">GitHub</a>    |   🤗 <a href="https://huggingface.co/Wan-AI/">Hugging Face</a>   |   🤖 <a href="https://modelscope.cn/organization/Wan-AI">ModelScope</a>   |    📑 <a href="https://arxiv.org/abs/2503.20314">Paper</a>    |    📑 <a href="https://wan.video/welcome?spm=a2ty_o02.30011076.0.0.6c9ee41eCcluqg">Blog</a>    |    💬 <a href="https://discord.gg/AKNgpMK4Yj">Discord</a>  
<br>
📕 <a href="https://alidocs.dingtalk.com/i/nodes/jb9Y4gmKWrx9eo4dCql9LlbYJGXn6lpz">使用指南(中文)</a>   |    📘 <a href="https://alidocs.dingtalk.com/i/nodes/EpGBa2Lm8aZxe5myC99MelA2WgN7R35y">User Guide(English)</a>   |   💬 <a href="https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg">WeChat(微信)</a>  
<br>
-----
[**Wan: Open and Advanced Large-Scale Video Generative Models**](https://arxiv.org/abs/2503.20314) <br>
We are excited to introduce **Wan2.2**, a major upgrade to our foundational video models. With **Wan2.2**, we have focused on incorporating the following innovations:
- 👍 **Effective MoE Architecture**: Wan2.2 introduces a Mixture-of-Experts (MoE) architecture into video diffusion models. By splitting the denoising process across timesteps among specialized, powerful expert models, it enlarges the overall model capacity while maintaining the same computational cost.
- 👍 **Cinematic-level Aesthetics**: Wan2.2 incorporates meticulously curated aesthetic data, complete with detailed labels for lighting, composition, contrast, color tone, and more. This allows for more precise and controllable cinematic style generation, facilitating the creation of videos with customizable aesthetic preferences.
- 👍 **Complex Motion Generation**: Compared to Wan2.1, Wan2.2 is trained on significantly more data, with 65.6% more images and 83.2% more videos. This expansion notably enhances the model's generalization across multiple dimensions such as motion, semantics, and aesthetics, achieving top performance among all open-source and closed-source models.
- 👍 **Efficient High-Definition Hybrid TI2V**: Wan2.2 open-sources a 5B model built with our advanced Wan2.2-VAE, which achieves a compression ratio of **16×16×4**. This model supports both text-to-video and image-to-video generation at 720P resolution and 24fps, and it can run on consumer-grade graphics cards such as the RTX 4090. It is one of the fastest **720P@24fps** models currently available, and it can serve both industrial and academic users.
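The MoE bullet above describes routing different parts of the denoising schedule to different expert models. The idea can be sketched as follows. This is a minimal illustration, not the repository's implementation: the two-expert split, the function name `route_timesteps`, the labels `high_noise` / `low_noise`, and the boundary value `0.8` are all assumptions chosen for the example.

```python
from __future__ import annotations

from typing import Sequence


def route_timesteps(timesteps: Sequence[float], boundary: float) -> list[str]:
    """Assign each normalized timestep (1.0 = pure noise, 0.0 = clean) to one expert.

    Hypothetical routing rule: steps at or above `boundary` go to a
    "high_noise" expert, the rest to a "low_noise" expert.
    """
    return ["high_noise" if t >= boundary else "low_noise" for t in timesteps]


if __name__ == "__main__":
    schedule = [1.0, 0.9, 0.75, 0.5, 0.25, 0.0]
    for t, expert in zip(schedule, route_timesteps(schedule, boundary=0.8)):
        print(f"t={t:.2f} -> {expert}")
```

Because exactly one expert handles each step in this sketch, the per-step compute matches a single model even though the combined parameter count is larger, which is the trade-off the bullet describes.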
## Video Demos
<div align="center">
<video src="https://github.com/user-attachments/assets/b63bfa58-d5d7-4de6-a1a2-98970b06d9a7" width="70%" poster=""> </video>
</div>
## 🔥 Latest News!!
* Nov 13, 2025: 👋 Wan2.2-Animate-14B has been integrated into Diffusers ([PR](https://github.com/huggingface/diffusers/pull/12526),[Weights](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)). Thanks to all community contributors. Enjoy!
* Sep 19, 2025: 💃 We introduce **[Wan2.2-Animate-14B](https://humanaigc.github.io/wan-animate)**, a unified model for character animation and replacement with holistic movement and expression replication. We released the [model weights](#model-download) and [inference code](#run-wan-animate). You can try it on [wan.video](https://wan.video/), [ModelScope Studio](https://www.modelscope.cn/studios/Wan-AI/Wan2.2-Animate) or [HuggingFace Space](https://huggingface.co/spaces/Wan-AI/Wan2.2-Animate)!
* Aug 26, 2025: 🎵 We introduce **[Wan2.2-S2V-14B](https://humanaigc.github.io/wan-s2v-webpage)**, an audio-driven cinematic video generation model, including [inference code](#run-speech-to-video-generation), [model weights](#model-download), and [technical report](https://humanaigc.github.io/wan-s2v-webpage/content/wan-s2v.pdf)! Now you can try it on [wan.video](https://wan.video/), [ModelScope Gradio](https://www.modelscope.cn/studios/Wan-AI/Wan2.2-S2V) or [HuggingFace Gradio](https://huggingface.co/spaces/Wan-AI/Wan2.2-S2V)!
* Jul 28, 2025: 👋 We have opened a [HF space](https://huggingface.co/spaces/Wan-AI/Wan-2.2-5B) using the TI2V-5B model. Enjoy!
* Jul 28, 2025: 👋 Wan2.2 has been integrated into ComfyUI ([CN](https://docs.comfy.org/zh-CN/tutorials/video/wan/wan2_2) | [EN](https://docs.comfy.org/tutorials/video/wan/wan2_2)). Enjoy!
* Jul 28, 2025: 👋 Wan2.2's T2V, I2V and TI2V have been integrated into Diffusers ([T2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) | [I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | [TI2V-5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)). Feel free to give it a try!
* Jul 28, 2025: 👋 We've released the inference code and model weights of **Wan2.2**.
* Sep 5, 2025: 👋 We add text-to-speech synthesis support with [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for Speech-to-Video generation task.
## Community Works
If your research or project builds upon [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) or [**Wan2.2**](https://github.com/Wan-Video/Wan2.2), and you would like more people to see it, please inform us.
- [Prompt Relay](https://github.com/GordonChen19/Prompt-Relay), a plug-and-play, inference-time method for temporal control in video generation. Prompt Relay improves video quality and gives users precise control over what happens at each moment in the video. Visit their [webpage](https://gordonchen19.github.io/Prompt-Relay/) for more details.
- [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model based on **Wan2.1** that achieves minute-scale, high-quality video synthesis at 19.5 FPS on a single H100 GPU (about 10 FPS on a single Ascend NPU), without relying on conventional long video anti-drifting strategies or standard video acceleration techniques. Visit their [webpage](https://pku-yuangroup.github.io/Helios-Page/) for more details.
- [LightX2V](https://github.com/ModelTC/LightX2V), a lightweight and efficient video generation framework that integrates **Wan2.1** and **Wan2.2**, supporting multiple engineering acceleration techniques for fast inference. [LightX2V-HuggingFace](https://huggingface.co/lightx2v) offers a variety of Wan-based step-distillation models, quantized models, and lightweight VAE models.
- [HuMo](https://github.com/Phantom-video/HuMo) proposed a unified, human-centric framework based on **Wan** to produce high-quality, fine-grained, and controllable human videos from multimodal inputs—including text, images, and audio. Visit their [webpage](https://phantom-video.github.io/HuMo/) for more details.
- [FastVideo](https://github.com/hao-ai-lab/FastVideo) includes distilled **Wan** models with sparse attention that significantly speed up inference.
- [Cache-dit](https://github.com/vipshop/cache-dit) offers Fully Cache Acceleration support for **Wan2.2** MoE with DBCache, TaylorSeer and Cache CFG. Visit their [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) for more details.
- [Kijai's ComfyUI WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper) is an alternative implementation of **Wan** models for ComfyUI. Thanks to its Wan-only focus, it's on the frontline of getting cutting edge optimizations and hot research features, which are often hard to integrate into ComfyUI quickly due to its more rigid structure.
- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides comprehensive support for **Wan 2.2**, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training.
## 📑 Todo List
- Wan2.2 Text-to-Video
- [x] Multi-GPU Inference code of the A14B and 14B models
- [x] Checkpoints of the A14B and 14B models
- [x] ComfyUI integration
- [x] Diffusers integration
- Wan2.2 Image-to-Video
- [x] Multi-GPU Inference code of the A14B model
- [x] Checkpoints of the A14B model
- [x] ComfyUI integration
- [x] Diffusers integration
- Wan2.2 Text-Image-to-Video
- [x] Multi-GPU Inference code of the 5B model
- [x] Checkpoints of the 5B model
- [x] ComfyUI integration
- [x] Diffusers integration
- Wan2.2-S2V Speech-to-Video
- [x] Inference code of Wan2.2-S2V
- [x] Checkpoints of Wan2.2-S2V-14B
- [x] ComfyUI integration
- [x] Diffusers integration
- Wan2.2-Animate Character Animation and Replacement
- [x] Inference code of Wan2.2-Animate
- [x] Checkpoints of Wan2.2-Animate
- [x] ComfyUI integration
- [x] Diffusers integration
## Run Wan2.2
#### Installation
Clone the repo:
```sh
git clone https://github.com/Wan-Video/Wan2.2.git
cd Wan2.2
```
Install dependencies:
```sh
# Ensure torch >= 2.4.0
# If the installation of `flash_attn` fails, try installing the other packages first and install `flash_attn` last
pip install -r requirements.txt
# If you want to use CosyVoice to synthesize speech for Speech-to-Video Generation, please install requirements_s2v.txt additionally
pip install -r requirements_s2v.txt
```
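The install notes above require `torch >= 2.4.0` and warn that `flash_attn` may fail to install. A quick pre-flight check along those lines might look like the sketch below. The helper names (`version_tuple`, `meets_minimum`, `check_environment`) are hypothetical and not part of this repository; only standard-library calls are used.

```python
import importlib.util
import re
from importlib import metadata


def version_tuple(version: str) -> tuple:
    """Parse the leading release numbers, e.g. '2.4.0+cu121' -> (2, 4, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part or 0) for part in match.groups())


def meets_minimum(version: str, minimum=(2, 4, 0)) -> bool:
    """True when `version` is at least `minimum` (default taken from the README)."""
    return version_tuple(version) >= tuple(minimum)


def check_environment() -> dict:
    """Report the installed torch version and whether flash_attn is importable."""
    report = {}
    try:
        report["torch"] = metadata.version("torch")
        report["torch_ok"] = meets_minimum(report["torch"])
    except metadata.PackageNotFoundError:
        report["torch"] = None
        report["torch_ok"] = False
    # find_spec only locates the package; it does not import or build it.
    report["flash_attn_installed"] = importlib.util.find_spec("flash_attn") is not None
    return report


if __name__ == "__main__":
    print(check_environment())
```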
#### Model Download
| Models | Download Links | Description |
|--------------------|---------------------------------------------------------------------------------------------------------------------------------------------|-------------|
| T2V-A14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | Text-to-Video MoE model, supports 480P & 720P |
| I2V-A14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | Image-to-Video MoE model, supports 480P & 720P |
| TI2V-5B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | High-compression VAE, T2V+I2V, supports 720P |
| S2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | Speech-to-Video model, supports 480P & 720P |
| Animate-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | Character animation and replacement |
> 💡Note:
> The TI2V-5B model supports 720P video generation at **24 FPS**.
Download models using huggingface-cli:
``` sh
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
```
Download models using modelscope-cli:
``` sh
pip install modelscope
modelscope download Wan-AI/Wan2.2-T2V-A14B --local_dir ./Wan2.2-T2V-A14B
```
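The same download can be scripted from Python. The sketch below assumes the `huggingface_hub` package installed above and uses its `snapshot_download` function; the `REPOS` mapping and the helper names are illustrative, with repo ids copied from the table and the task names taken from the commands in this README.

```python
from pathlib import Path

# Task name -> Hugging Face repo id, as listed in this README.
REPOS = {
    "t2v-A14B": "Wan-AI/Wan2.2-T2V-A14B",
    "i2v-A14B": "Wan-AI/Wan2.2-I2V-A14B",
    "ti2v-5B": "Wan-AI/Wan2.2-TI2V-5B",
    "s2v-14B": "Wan-AI/Wan2.2-S2V-14B",
}


def local_dir_for(task: str, root: str = ".") -> Path:
    """Mirror the README layout, e.g. ./Wan2.2-T2V-A14B for task t2v-A14B."""
    return Path(root) / REPOS[task].split("/", 1)[1]


def download(task: str, root: str = ".") -> Path:
    """Fetch all files of the checkpoint repo into the matching local directory."""
    # Imported lazily so the helpers above work without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    target = local_dir_for(task, root)
    snapshot_download(repo_id=REPOS[task], local_dir=str(target))
    return target


if __name__ == "__main__":
    print(download("t2v-A14B"))
```

The returned path can be passed directly as `--ckpt_dir` to `generate.py`.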
#### Run Text-to-Video Generation
This repository supports the `Wan2.2-T2V-A14B` Text-to-Video model, which generates video at both 480P and 720P resolutions.
##### (1) Without Prompt Extension
To keep things simple, we start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extension) step.
- Single-GPU inference
``` sh
python generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --offload_model True --convert_model_dtype --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
> 💡 This command can run on a GPU with at least 80GB VRAM.
> 💡If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`, `--convert_model_dtype` and `--t5_cpu` options to reduce GPU memory usage.
- Multi-GPU inference using FSDP + DeepSpeed Ulysses
We use [PyTorch FSDP](https://docs.pytorch.org/docs/stable/fsdp.html) and [DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509) to accelerate inference.
``` sh
torchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
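When these commands are launched from a script, building the argument list in Python avoids shell quoting problems, for instance the `*` in `1280*720` being treated as a glob. The helper below is a sketch that uses only flags shown in the commands above; `build_command` is a hypothetical name, and setting `--ulysses_size` equal to the GPU count simply mirrors the README example rather than a documented requirement.

```python
import shlex


def build_command(task, size, ckpt_dir, prompt, num_gpus=1):
    """Assemble an argv list for generate.py from flags shown in the README."""
    args = ["generate.py", "--task", task, "--size", size, "--ckpt_dir", ckpt_dir]
    if num_gpus > 1:
        launcher = ["torchrun", f"--nproc_per_node={num_gpus}"]
        # Assumption: match --ulysses_size to the GPU count, as in the README example.
        args += ["--dit_fsdp", "--t5_fsdp", "--ulysses_size", str(num_gpus)]
    else:
        launcher = ["python"]
        # Memory-saving options from the single-GPU example.
        args += ["--offload_model", "True", "--convert_model_dtype"]
    return launcher + args + ["--prompt", prompt]


if __name__ == "__main__":
    cmd = build_command(
        task="t2v-A14B",
        size="1280*720",
        ckpt_dir="./Wan2.2-T2V-A14B",
        prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves "
        "fight intensely on a spotlighted stage.",
        num_gpus=8,
    )
    # Print a copy-pastable command; use subprocess.run(cmd, check=True) to launch it.
    print(shlex.join(cmd))
```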
##### (2) Using Prompt Extension
Extending the prompts can effectively enrich the details in the generated videos, further enhancing the video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:
- Use the Dashscope API for extension.
- Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
- Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
- Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
- You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:
```sh
DASH_API_KEY=your_key torchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'zh'
```
- Using a local model for extension.
- By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
- For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
- For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
- Larger models generally provide better extension results but require more GPU memory.
- You can modify the model used for extension with the parameter `--prompt_extend_model`, which accepts either a local model path or a Hugging Face model name. For example:
``` sh
torchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'zh'
```
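The Dashscope settings described above can be resolved in a few lines. This is a sketch under stated assumptions: the environment variable names and default models come from this README, while the function name `dashscope_settings`, the `task_family` keys, and the returned dictionary layout are hypothetical.

```python
import os

# Default Dashscope models per task family, as stated in the README.
DASHSCOPE_DEFAULTS = {"t2v": "qwen-plus", "i2v": "qwen-vl-max"}


def dashscope_settings(task_family, env=None, model=None):
    """Collect the API key, optional endpoint, and model for prompt extension."""
    env = os.environ if env is None else env
    api_key = env.get("DASH_API_KEY")
    if not api_key:
        raise RuntimeError("DASH_API_KEY is not set")
    return {
        "api_key": api_key,
        # Optional: the README asks international-site users to set this.
        "api_url": env.get("DASH_API_URL"),
        # An explicit model plays the role of --prompt_extend_model.
        "model": model or DASHSCOPE_DEFAULTS[task_family],
    }


if __name__ == "__main__":
    fake_env = {
        "DASH_API_KEY": "your_key",
        "DASH_API_URL": "https://dashscope-intl.aliyuncs.com/api/v1",
    }
    print(dashscope_settings("i2v", env=fake_env))
```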
#### Run Image-to-Video Generation
This repository supports the `Wan2.2-I2V-A14B` Image-to-Video model, which generates video at both 480P and 720P resolutions.
- Single-GPU inference
```sh
python generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --offload_model True --convert_model_dtype --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> This command can run on a GPU with at least 80GB VRAM.
> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
- Multi-GPU inference using FSDP + DeepSpeed Ulysses
```sh
torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
- Image-to-Video Generation without prompt
```sh
DASH_API_KEY=your_key torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --prompt '' --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --use_prompt_extend --prompt_extend_method 'dashscope'
```
> 💡The model can generate videos solely from the input image. You can use prompt extension to generate prompt from the image.
> The process of prompt extension can be referenced [here](#2-using-prompt-extension).
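
As noted above, for Image-to-Video the `size` parameter sets the pixel area while the aspect ratio follows the input image. The arithmetic can be sketched as below. The function name `fit_to_area` and the rounding of each side down to a multiple of 16 are assumptions made for illustration; the repository's exact rounding may differ.

```python
import math


def fit_to_area(image_width, image_height, size="1280*720", multiple=16):
    """Return (width, height) with about the requested area and the image's aspect ratio."""
    target_w, target_h = (int(v) for v in size.split("*"))
    max_area = target_w * target_h
    aspect = image_height / image_width
    # Solve height * width = max_area with height / width = aspect,
    # then round each side down to the assumed granularity.
    height = round(math.sqrt(max_area * aspect)) // multiple * multiple
    width = round(math.sqrt(max_area / aspect)) // multiple * multiple
    return width, height


if __name__ == "__main__":
    for w, h in [(1920, 1080), (1024, 1024), (1080, 1920)]:
        print((w, h), "->", fit_to_area(w, h))
```

A square input therefore yields a square video of roughly the same area as `1280*720`, rather than a letterboxed 16:9 frame.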
#### Run Text-Image-to-Video Generation
This repository supports the `Wan2.2-TI2V-5B` Text-Image-to-Video model, which generates video at 720P resolution.
- Single-GPU Text-to-Video inference
```sh
python generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --offload_model True --convert_model_dtype --t5_cpu --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage"
```
> 💡Unlike other tasks, the 720P resolution of the Text-Image-to-Video task is `1280*704` or `704*1280`.
> This command can run on a GPU with at least 24GB VRAM (e.g., RTX 4090 GPU).
> 💡If you are running on a GPU with at least 80GB VRAM, you can remove the `--offload_model True`, `--convert_model_dtype` and `--t5_cpu` options to speed up execution.
- Single-GPU Image-to-Video inference
```sh
python generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --offload_model True --convert_model_dtype --t5_cpu --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> 💡If the `--image` parameter is provided, the command runs Image-to-Video generation; otherwise, it defaults to Text-to-Video generation.
> 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
- Multi-GPU inference using FSDP + DeepSpeed Ulysses
```sh
torchrun --nproc_per_node=8 generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --dit_fsdp --t5_fsdp --ulysses_size 8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> The process of prompt extension can be referenced [here](#2-using-prompt-extention).
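
The `size` notes above say the value acts as a pixel-area budget while the aspect ratio follows the input image. A minimal sketch of that idea follows; the function name `fit_to_area` and the rounding to a multiple of 32 are illustrative assumptions, not the repository's implementation.

```python
import math


def fit_to_area(img_w: int, img_h: int, max_area: int,
                multiple: int = 32) -> tuple[int, int]:
    """Return (width, height) near max_area pixels with the image's aspect ratio.

    Both sides are rounded down to a multiple of `multiple` (an assumed stride),
    so the result never exceeds the area budget.
    """
    aspect = img_h / img_w
    height = int(math.sqrt(max_area * aspect)) // multiple * multiple
    width = int(math.sqrt(max_area / aspect)) // multiple * multiple
    return width, height


# A 4:3 landscape photo with the 1280*704 area budget
print(fit_to_area(4000, 3000, 1280 * 704))  # (1088, 800)
```

Rounding down keeps the output within the requested area at the cost of a slightly smaller frame.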
#### Run Speech-to-Video Generation
This repository supports the `Wan2.2-S2V-14B` Speech-to-Video model, which supports video generation at both 480P and 720P resolutions.
- Single-GPU Speech-to-Video inference
```sh
python generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --offload_model True --convert_model_dtype --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard." --image "examples/i2v_input.JPG" --audio "examples/talk.wav"
# Without setting --num_clip, the generated video length will automatically adjust based on the input audio length
# You can use CosyVoice to generate audio with --enable_tts
python generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --offload_model True --convert_model_dtype --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard." --image "examples/i2v_input.JPG" --enable_tts --tts_prompt_audio "examples/zero_shot_prompt.wav" --tts_prompt_text "希望你以后能够做的比我还好呦。" --tts_text "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
```
> 💡 This command can run on a GPU with at least 80GB VRAM.
- Multi-GPU inference using FSDP + DeepSpeed Ulysses
```sh
torchrun --nproc_per_node=8 generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard." --image "examples/i2v_input.JPG" --audio "examples/talk.wav"
```
- Pose + Audio driven generation
```sh
torchrun --nproc_per_node=8 generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "a person is singing" --image "examples/pose.png" --audio "examples/sing.MP3" --pose_video "./examples/pose.mp4"
```
> 💡For the Speech-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
> 💡The model can generate videos from audio input combined with reference image and optional text prompt.
> 💡The `--pose_video` parameter enables pose-driven generation, allowing the model to follow specific pose sequences while generating videos synchronized with audio input.
> 💡The `--num_clip` parameter controls the number of video clips generated, useful for quick preview with shorter generation time.
Please visit our project page to see more examples and learn about the scenarios suitable for this model.
#### Run Wan-Animate
Wan-Animate takes a video and a character image as input, and generates a video in either "animation" or "replacement" mode.
1. animation mode: The model generates a video of the character image that mimics the human motion in the input video.
2. replacement mode: The model replaces the character in the input video with the character from the character image.
Please visit our [project page](https://humanaigc.github.io/wan-animate) to see more examples and learn about the scenarios suitable for this model.
##### (1) Preprocessing
The input video must be preprocessed into several intermediate materials before being fed into the inference process. Follow the processing flow below; more details about preprocessing can be found in [UserGuider](https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/animate/preprocess/UserGuider.md).
* For animation
```bash
python ./wan/modules/animate/preprocess/preprocess_data.py \
--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint \
--video_path ./examples/wan_animate/animate/video.mp4 \
--refer_path ./examples/wan_animate/animate/image.jpeg \
--save_path ./examples/wan_animate/animate/process_results \
--resolution_area 1280 720 \
--retarget_flag \
--use_flux
```
* For replacement
```bash
python ./wan/modules/animate/preprocess/preprocess_data.py \
--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint \
--video_path ./examples/wan_animate/replace/video.mp4 \
--refer_path ./examples/wan_animate/replace/image.jpeg \
--save_path ./examples/wan_animate/replace/process_results \
--resolution_area 1280 720 \
--iterations 3 \
--k 7 \
--w_len 1 \
--h_len 1 \
--replace_flag
```
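Before starting inference, it can help to confirm that preprocessing produced the files the later steps read. The sketch below is an assumption-laden helper, not part of the repository: the file names are taken from the Diffusers examples in this README, and treating them as the complete output set of `preprocess_data.py` is a guess. `missing_inputs` is a hypothetical name.

```python
from pathlib import Path

# File names as they appear in the Diffusers examples of this README.
# Assumption: these are what preprocessing writes into --save_path.
COMMON_FILES = ("src_ref.png", "src_pose.mp4", "src_face.mp4")
REPLACE_ONLY_FILES = ("src_bg.mp4", "src_mask.mp4")


def missing_inputs(src_root: str, replace: bool = False) -> list[str]:
    """List the expected preprocessing outputs that are absent from src_root."""
    root = Path(src_root)
    expected = COMMON_FILES + (REPLACE_ONLY_FILES if replace else ())
    return [name for name in expected if not (root / name).is_file()]
```

For example, `missing_inputs("./examples/wan_animate/replace/process_results", replace=True)` returns an empty list only if all five files are present.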
##### (2) Run in animation mode
* Single-GPU inference
```bash
python generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/animate/process_results/ --refert_num 1
```
* Multi-GPU inference using FSDP + DeepSpeed Ulysses
```bash
python -m torch.distributed.run --nnodes 1 --nproc_per_node 8 generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/animate/process_results/ --refert_num 1 --dit_fsdp --t5_fsdp --ulysses_size 8
```
* Diffusers Pipeline
```python
import torch
from diffusers import WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video

device = "cuda:0"
dtype = torch.bfloat16
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
pipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)
seed = 42
prompt = "People in the video are doing actions."
# Animation
image = load_image("/path/to/animate/reference/image/src_ref.png")
pose_video = load_video("/path/to/animate/pose/video/src_pose.mp4")
face_video = load_video("/path/to/animate/face/video/src_face.mp4")
animate_video = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
mode="animate",
segment_frame_length=77, # clip_len in original code
prev_segment_conditioning_frames=1, # refert_num in original code
guidance_scale=1.0,
num_inference_steps=20,
generator=torch.Generator(device=device).manual_seed(seed),
).frames[0]
export_to_video(animate_video, "diffusers_animate.mp4", fps=30)
```
##### (3) Run in replacement mode
* Single-GPU inference
```bash
python generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/replace/process_results/ --refert_num 1 --replace_flag --use_relighting_lora
```
* Multi-GPU inference using FSDP + DeepSpeed Ulysses
```bash
python -m torch.distributed.run --nnodes 1 --nproc_per_node 8 generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/replace/process_results/ --refert_num 1 --replace_flag --use_relighting_lora --dit_fsdp --t5_fsdp --ulysses_size 8
```
* Diffusers Pipeline
```python
# create pipeline as in the Animation code ☝️
# Replacement
image = load_image("/path/to/replace/reference/image/src_ref.png")
pose_video = load_video("/path/to/replace/pose/video/src_pose.mp4")
face_video = load_video("/path/to/replace/face/video/src_face.mp4")
background_video = load_video("/path/to/replace/background/video/src_bg.mp4")
mask_video = load_video("/path/to/replace/mask/video/src_mask.mp4")
replace_video = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
background_video=background_video,
mask_video=mask_video,
prompt=prompt,
mode="replace",
segment_frame_length=77, # clip_len in original code
prev_segment_conditioning_frames=1, # refert_num in original code
guidance_scale=1.0,
num_inference_steps=20,
generator=torch.Generator(device=device).manual_seed(seed),
).frames[0]
export_to_video(replace_video, "diffusers_replace.mp4", fps=30)
```
> 💡 If you're using **Wan-Animate**, we do not recommend using LoRA models trained on `Wan2.2`, since weight changes during training may lead to unexpected behavior.
## Computational Efficiency on Different GPUs
We test the computational efficiency of different **Wan2.2** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.
<div align="center">
<img src="assets/comp_effic.png" alt="" style="width: 80%;" />
</div>
> The parameter settings for the tests presented in this table are as follows:
> (1) Multi-GPU: 14B: `--ulysses_size 4/8 --dit_fsdp --t5_fsdp`, 5B: `--ulysses_size 4/8 --offload_model True --convert_model_dtype --t5_cpu`; Single-GPU: 14B: `--offload_model True --convert_model_dtype`, 5B: `--offload_model True --convert_model_dtype --t5_cpu`
> (`--convert_model_dtype` converts model parameter types to `config.param_dtype`);
> (2) The distributed testing utilizes the built-in FSDP and Ulysses implementations, with FlashAttention3 deployed on Hopper architecture GPUs;
> (3) Tests were run without the `--use_prompt_extend` flag;
> (4) Reported results are the average of multiple samples taken after the warm-up phase.
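
The flag combinations in note (1) can be written as a small lookup. This is only a sketch for reproducing the benchmark settings; `benchmark_flags` is a hypothetical helper, not something `generate.py` provides.

```python
def benchmark_flags(model: str, num_gpus: int) -> list[str]:
    """Return the generate.py flags listed in note (1) for a benchmark run.

    model is "14B" or "5B"; num_gpus is 1 (single GPU) or 4/8 (multi-GPU).
    """
    if model not in ("14B", "5B"):
        raise ValueError(f"unknown model size: {model}")
    if num_gpus not in (1, 4, 8):
        raise ValueError(f"note (1) covers 1, 4 or 8 GPUs, got {num_gpus}")

    offload = ["--offload_model", "True", "--convert_model_dtype"]
    if num_gpus == 1:
        # Single GPU: offloading for both sizes, plus T5 on CPU for 5B.
        return offload + (["--t5_cpu"] if model == "5B" else [])
    flags = ["--ulysses_size", str(num_gpus)]
    if model == "14B":
        # Multi-GPU 14B: shard DiT and T5 with FSDP.
        return flags + ["--dit_fsdp", "--t5_fsdp"]
    # Multi-GPU 5B: sequence parallelism plus offloading and T5 on CPU.
    return flags + offload + ["--t5_cpu"]


print(" ".join(benchmark_flags("14B", 8)))
```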
-------
## Introduction of Wan2.2
**Wan2.2** builds on the foundation of Wan2.1 with notable improvements in generation quality and model capability. This upgrade is driven by a series of key technical innovations, mainly including the Mixture-of-Experts (MoE) architecture, upgraded training data, and high-compression video generation.
##### (1) Mixture-of-Experts (MoE) Architecture
Wan2.2 introduces the Mixture-of-Experts (MoE) architecture into video generation diffusion models. MoE has been widely validated in large language models as an efficient approach to increase total model parameters while keeping inference cost nearly unchanged. In Wan2.2, the A14B model series adopts a two-expert design tailored to the denoising process of diffusion models: a high-noise expert for the early stages, focusing on overall layout; and a low-noise expert for the later stages, refining video details. Each expert model has about 14B parameters, resulting in a total of 27B parameters but only 14B active parameters per step, keeping inference computation and GPU memory nearly unchanged.
<div align="center">
<img src="assets/moe_arch.png" alt="" style="width: 90%;" />
</div>
The transition point between the two experts is determined by the signal-to-noise ratio (SNR), a metric that decreases monotonically as the denoising step $t$ increases. At the beginning of the denoising process, $t$ is large and the noise level is high, so the SNR is at its minimum, denoted as ${SNR}_{min}$. In this stage, the high-noise expert is activated. We define a threshold step ${t}_{moe}$ corresponding to half of the ${SNR}_{min}$, and switch to the low-noise expert when $t<{t}_{moe}$.
<div align="center">
<img src="assets/moe_2.png" alt="" style="width: 90%;" />
</div>
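
The switching rule described above can be sketched as a per-step choice between the two experts. The function names and the threshold value used here are illustrative assumptions; the actual ${t}_{moe}$ is derived from the SNR schedule and is not stated in this section.

```python
def select_expert(t: float, t_moe: float) -> str:
    """Pick the expert for denoising step t: low-noise once t < t_moe."""
    return "low_noise" if t < t_moe else "high_noise"


def expert_schedule(timesteps: list[float], t_moe: float) -> list[str]:
    """Map a descending list of timesteps to the expert active at each one."""
    return [select_expert(t, t_moe) for t in timesteps]


# Hypothetical schedule and threshold, for illustration only.
steps = [1000, 800, 600, 400, 200]
print(expert_schedule(steps, t_moe=500))
```

Only one expert runs at any step, which is why the active parameter count stays at about 14B even though both experts together hold about 27B.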
To validate the effectiveness of the MoE architecture, four settings are compared based on their validation loss curves. The baseline **Wan2.1** model does not employ the MoE architecture. Among the MoE-based variants, **Wan2.1 & High-Noise Expert** reuses the Wan2.1 model as the low-noise expert and uses Wan2.2's high-noise expert, while **Wan2.1 & Low-Noise Expert** uses Wan2.1 as the high-noise expert and employs Wan2.2's low-noise expert. **Wan2.2 (MoE)** (our final version) achieves the lowest validation loss, indicating that its generated video distribution is closest to the ground truth and that it converges best.
##### (2) Efficient High-Definition Hybrid TI2V
To enable more efficient deployment, Wan2.2 also explores a high-compression design. In addition to the 27B MoE models, a 5B dense model, i.e., TI2V-5B, is released. It is supported by a high-compression Wan2.2-VAE, which achieves a $T\times H\times W$ compression ratio of $4\times16\times16$, increasing the overall compression rate to 64 while maintaining high-quality video reconstruction. With an additional patchification layer, the total compression ratio of TI2V-5B reaches $4\times32\times32$. Without specific optimization, TI2V-5B can generate a 5-second 720P video in under 9 minutes on a single consumer-grade GPU, ranking among the fastest 720P@24fps video generation models. This model also natively supports both text-to-video and image-to-video tasks within a single unified framework, covering both academic research and practical applications.
<div align="center">
<img src="assets/vae.png" alt="" style="width: 80%;" />
</div>
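
To make the compression ratios concrete, the sketch below computes the token grid for a given video size. Several things here are assumptions rather than facts from this section: the `(1, 2, 2)` patch size is inferred from the step from $4\times16\times16$ to $4\times32\times32$, the "keep the first frame, then group by 4" rule is inferred from the `4n+1` frame-count requirement of `--frame_num`, and 121 frames is used as a stand-in for roughly 5 seconds at 24 fps.

```python
def latent_grid(frames: int, height: int, width: int,
                vae_ratio: tuple[int, int, int] = (4, 16, 16),
                patch: tuple[int, int, int] = (1, 2, 2)) -> tuple[int, int, int]:
    """Return the (T, H, W) token grid after VAE compression and patchification."""
    t_ratio, h_ratio, w_ratio = (v * p for v, p in zip(vae_ratio, patch))
    if height % h_ratio or width % w_ratio:
        raise ValueError(f"{width}x{height} is not divisible by {w_ratio}x{h_ratio}")
    if (frames - 1) % t_ratio:
        raise ValueError(f"frame count must be {t_ratio}n+1, got {frames}")
    # Assumption: the first frame is kept and the rest are grouped by t_ratio.
    return (frames - 1) // t_ratio + 1, height // h_ratio, width // w_ratio


print(latent_grid(121, 704, 1280))  # (31, 22, 40)
```

Under these assumptions a height of 720 is rejected because it is not a multiple of 32, which is consistent with the TI2V task using `1280*704` as its 720P size.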
##### Comparisons to SOTAs
We compared Wan2.2 with leading closed-source commercial models on our new Wan-Bench 2.0, evaluating performance across multiple crucial dimensions. The results demonstrate that Wan2.2 achieves superior performance compared to these leading models.
<div align="center">
<img src="assets/performance.png" alt="" style="width: 90%;" />
</div>
## Citation
If you find our work helpful, please cite us.
```
@article{wan2025,
title={Wan: Open and Advanced Large-Scale Video Generative Models},
author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
journal = {arXiv preprint arXiv:2503.20314},
year={2025}
}
```
## License Agreement
The models in this repository are licensed under the Apache 2.0 License. We claim no rights over your generated content, granting you the freedom to use it while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
## Acknowledgements
We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
## Contact Us
If you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!
================================================
FILE: generate.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import merge_video_audio, save_video, str2bool
EXAMPLE_PROMPT = {
"t2v-A14B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"i2v-A14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
"ti2v-5B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"animate-14B": {
"prompt": "视频中的人在做动作",
"video": "",
"pose": "",
"mask": "",
},
"s2v-14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
"audio":
"examples/talk.wav",
"tts_prompt_audio":
"examples/zero_shot_prompt.wav",
"tts_prompt_text":
"希望你以后能够做的比我还好呦。",
"tts_text":
"收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
args.image = EXAMPLE_PROMPT[args.task]["image"]
if args.audio is None and args.enable_tts is False and "audio" in EXAMPLE_PROMPT[args.task]:
args.audio = EXAMPLE_PROMPT[args.task]["audio"]
if (args.tts_prompt_audio is None or args.tts_text is None) and args.enable_tts is True and "audio" in EXAMPLE_PROMPT[args.task]:
args.tts_prompt_audio = EXAMPLE_PROMPT[args.task]["tts_prompt_audio"]
args.tts_prompt_text = EXAMPLE_PROMPT[args.task]["tts_prompt_text"]
args.tts_text = EXAMPLE_PROMPT[args.task]["tts_text"]
if args.task == "i2v-A14B":
assert args.image is not None, "Please specify the image path for i2v."
cfg = WAN_CONFIGS[args.task]
if args.sample_steps is None:
args.sample_steps = cfg.sample_steps
if args.sample_shift is None:
args.sample_shift = cfg.sample_shift
if args.sample_guide_scale is None:
args.sample_guide_scale = cfg.sample_guide_scale
if args.frame_num is None:
args.frame_num = cfg.frame_num
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
    if 's2v' not in args.task:
        assert args.size in SUPPORTED_SIZES[args.task], (
            f"Unsupported size {args.size} for task {args.task}, "
            f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")
def _parse_args():
parser = argparse.ArgumentParser(
        description="Generate an image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-A14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames of video are generated. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5.")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU.")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the video from.")
parser.add_argument(
"--use_prompt_extend",
action="store_true",
default=False,
help="Whether to use prompt extend.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--prompt_extend_target_lang",
type=str,
default="zh",
choices=["zh", "en"],
help="The target language of prompt extend.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=None,
help="Classifier free guidance scale.")
parser.add_argument(
"--convert_model_dtype",
action="store_true",
default=False,
        help="Whether to convert model parameters dtype.")
# animate
parser.add_argument(
"--src_root_path",
type=str,
default=None,
        help="The path to the preprocessing output directory. Default None.")
parser.add_argument(
"--refert_num",
type=int,
default=77,
help="How many frames used for temporal guidance. Recommended to be 1 or 5."
)
parser.add_argument(
"--replace_flag",
action="store_true",
default=False,
help="Whether to use replace.")
parser.add_argument(
"--use_relighting_lora",
action="store_true",
default=False,
help="Whether to use relighting lora.")
    # the following args only work for s2v
parser.add_argument(
"--num_clip",
type=int,
default=None,
help="Number of video clips to generate, the whole video will not exceed the length of audio."
)
parser.add_argument(
"--audio",
type=str,
default=None,
help="Path to the audio file, e.g. wav, mp3")
parser.add_argument(
"--enable_tts",
action="store_true",
default=False,
        help="Use CosyVoice to synthesize audio")
parser.add_argument(
"--tts_prompt_audio",
type=str,
default=None,
        help="Path to the tts prompt audio file, e.g. wav, mp3. Sample rate must be greater than 16 kHz, and duration must be between 5s and 15s.")
parser.add_argument(
"--tts_prompt_text",
type=str,
default=None,
        help="Transcript of the tts prompt audio. If provided, must exactly match tts_prompt_audio")
parser.add_argument(
"--tts_text",
type=str,
default=None,
        help="Text to synthesize")
parser.add_argument(
"--pose_video",
type=str,
default=None,
        help="Path to a DW-Pose sequence video for pose-driven generation")
parser.add_argument(
"--start_from_ref",
action="store_true",
default=False,
        help="Whether to set the reference image as the starting point for generation"
)
parser.add_argument(
"--infer_frames",
type=int,
default=80,
help="Number of frames per clip, 48 or 80 or others (must be multiple of 4) for 14B s2v"
)
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = local_rank
_init_logging(rank)
if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=rank,
world_size=world_size)
else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1
        ), "Sequence parallelism is not supported in non-distributed environments."
if args.ulysses_size > 1:
        assert args.ulysses_size == world_size, "ulysses_size must be equal to the world size."
init_distributed_group()
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model,
task=args.task,
is_vl=args.image is not None)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
task=args.task,
is_vl=args.image is not None,
device=rank)
else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
if args.ulysses_size > 1:
assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
if dist.is_initialized():
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
logging.info(f"Input prompt: {args.prompt}")
img = None
if args.image is not None:
img = Image.open(args.image).convert("RGB")
logging.info(f"Input image: {args.image}")
# prompt extend
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
image=img,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
            if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
if "t2v" in args.task:
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "ti2v" in args.task:
logging.info("Creating WanTI2V pipeline.")
wan_ti2v = wan.WanTI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_ti2v.generate(
args.prompt,
img=img,
size=SIZE_CONFIGS[args.size],
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "animate" in args.task:
logging.info("Creating Wan-Animate pipeline.")
wan_animate = wan.WanAnimate(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
use_relighting_lora=args.use_relighting_lora
)
logging.info(f"Generating video ...")
video = wan_animate.generate(
src_root_path=args.src_root_path,
replace_flag=args.replace_flag,
            refert_num=args.refert_num,
clip_len=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "s2v" in args.task:
logging.info("Creating WanS2V pipeline.")
wan_s2v = wan.WanS2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_s2v.generate(
input_prompt=args.prompt,
ref_image_path=args.image,
audio_path=args.audio,
enable_tts=args.enable_tts,
tts_prompt_audio=args.tts_prompt_audio,
tts_prompt_text=args.tts_prompt_text,
tts_text=args.tts_text,
num_repeat=args.num_clip,
pose_video=args.pose_video,
max_area=MAX_AREA_CONFIGS[args.size],
infer_frames=args.infer_frames,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model,
init_first_frame=args.start_from_ref,
)
else:
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_sp=(args.ulysses_size > 1),
t5_cpu=args.t5_cpu,
convert_model_dtype=args.convert_model_dtype,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
suffix = '.mp4'
args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
logging.info(f"Saving generated video to {args.save_file}")
save_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
if "s2v" in args.task:
if args.enable_tts is False:
merge_video_audio(video_path=args.save_file, audio_path=args.audio)
else:
merge_video_audio(video_path=args.save_file, audio_path="tts.wav")
del video
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "wan"
version = "2.2.0"
description = "Wan: Open and Advanced Large-Scale Video Generative Models"
authors = [
{ name = "Wan Team", email = "wan.ai@alibabacloud.com" }
]
license = { file = "LICENSE.txt" }
readme = "README.md"
requires-python = ">=3.10,<4.0"
dependencies = [
"torch>=2.4.0",
"torchvision>=0.19.0",
"opencv-python>=4.9.0.80",
"diffusers>=0.31.0",
"transformers>=4.49.0",
"tokenizers>=0.20.3",
"accelerate>=1.1.1",
"tqdm",
"imageio",
"easydict",
"ftfy",
"dashscope",
"imageio-ffmpeg",
"flash_attn",
"numpy>=1.23.5,<2"
]
[project.optional-dependencies]
dev = [
"pytest",
"black",
"flake8",
"isort",
"mypy",
"huggingface-hub[cli]"
]
[project.urls]
homepage = "https://wanxai.com"
documentation = "https://github.com/Wan-Video/Wan2.2"
repository = "https://github.com/Wan-Video/Wan2.2"
huggingface = "https://huggingface.co/Wan-AI/"
modelscope = "https://modelscope.cn/organization/Wan-AI"
discord = "https://discord.gg/p5XbdQV7"
[tool.setuptools]
packages = ["wan"]
[tool.setuptools.package-data]
"wan" = ["**/*.py"]
[tool.black]
line-length = 88
[tool.isort]
profile = "black"
[tool.mypy]
strict = true
================================================
FILE: requirements.txt
================================================
torch>=2.4.0
torchvision>=0.19.0
torchaudio
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0,<=4.51.3
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio[ffmpeg]
easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
numpy>=1.23.5,<2
================================================
FILE: requirements_animate.txt
================================================
decord
peft
onnxruntime
pandas
matplotlib
-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2
loguru
sentencepiece
================================================
FILE: requirements_s2v.txt
================================================
openai-whisper
HyperPyYAML
onnxruntime
inflect
wetext
omegaconf
conformer
hydra-core
lightning
rich
gdown
matplotlib
wget
pyarrow
pyworld
librosa
decord
modelscope
GitPython
================================================
FILE: tests/README.md
================================================
Place all model checkpoints (Wan2.2-T2V-A14B, Wan2.2-I2V-A14B, Wan2.2-TI2V-5B) in a single directory, then run the test script with that directory and the maximum number of GPUs to use:
```bash
bash ./tests/test.sh <local model dir> <gpu number>
```
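Before launching the script, it can help to confirm that the model directory has the layout `tests/test.sh` expects. The sketch below is illustrative and not part of the repository: the helper name `missing_checkpoints` is hypothetical, and the sub-directory names are taken from the checkpoint paths used in `test.sh`.

```python
from pathlib import Path

# Sub-directory names as referenced by tests/test.sh (CKPT_DIR="$MODEL_DIR/<name>").
EXPECTED_MODELS = ("Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B")


def missing_checkpoints(model_dir):
    """Return the expected model sub-directories that are absent from model_dir.

    Hypothetical pre-flight helper: it only checks that each directory exists,
    not that the checkpoint files inside it are complete.
    """
    root = Path(model_dir)
    return [name for name in EXPECTED_MODELS if not (root / name).is_dir()]
```

Calling `missing_checkpoints("<local model dir>")` returns an empty list when all three directories are present; any names it returns are the ones `test.sh` would fail to find.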
================================================
FILE: tests/test.sh
================================================
#!/bin/bash
set -x
unset NCCL_DEBUG
if [ "$#" -eq 2 ]; then
MODEL_DIR=$(realpath "$1")
GPUS=$2
else
echo "Usage: $0 <local model dir> <gpu number>"
exit 1
fi
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1
PY_FILE=./generate.py
function t2v_A14B() {
CKPT_DIR="$MODEL_DIR/Wan2.2-T2V-A14B"
# # 1-GPU Test
# echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B 1-GPU Test: "
# python $PY_FILE --task t2v-A14B --size 480*832 --ckpt_dir $CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 1280*720 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 480*832 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function i2v_A14B() {
CKPT_DIR="$MODEL_DIR/Wan2.2-I2V-A14B"
# echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_A14B 1-GPU Test: "
# python $PY_FILE --task i2v-A14B --size 832*480 --ckpt_dir $CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_A14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_A14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_A14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 1280*720 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_A14B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 480*832 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function ti2v_5B() {
CKPT_DIR="$MODEL_DIR/Wan2.2-TI2V-5B"
# # 1-GPU Test
# echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v 1-GPU Test: "
# python $PY_FILE --task ti2v-5B --size 1280*704 --ckpt_dir $CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 1280*704 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 704*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B i2v Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 704*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --image "examples/i2v_input.JPG"
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B i2v Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 1280*704 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang 'en' --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --image "examples/i2v_input.JPG"
}
t2v_A14B
i2v_A14B
ti2v_5B
================================================
FILE: wan/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from . import configs, distributed, modules
from .image2video import WanI2V
from .speech2video import WanS2V
from .text2video import WanT2V
from .textimage2video import WanTI2V
from .animate import WanAnimate
================================================
FILE: wan/animate.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import os
import cv2
import types
from copy import deepcopy
from functools import partial
from einops import rearrange
import numpy as np
import torch
import torch.distributed as dist
from peft import set_peft_model_state_dict
from decord import VideoReader
from tqdm import tqdm
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.animate import WanAnimateModel
from .modules.animate import CLIPModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .modules.animate.animate_utils import TensorList, get_loraconfig
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanAnimate:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
use_relighting_lora=False
):
r"""
Initializes the generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_sp (`bool`, *optional*, defaults to False):
Enable distribution strategy of sequence parallel.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
use_relighting_lora (`bool`, *optional*, defaults to False):
Whether to use relighting lora for character replacement.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.init_on_cpu = init_on_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
if t5_fsdp or dit_fsdp or use_sp:
self.init_on_cpu = False
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
)
self.clip = CLIPModel(
dtype=torch.float16,
device=self.device,
checkpoint_path=os.path.join(checkpoint_dir,
config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
self.vae = Wan2_1_VAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanAnimate from {checkpoint_dir}")
if not dit_fsdp:
self.noise_model = WanAnimateModel.from_pretrained(
checkpoint_dir,
torch_dtype=self.param_dtype,
device_map=self.device)
else:
self.noise_model = WanAnimateModel.from_pretrained(
checkpoint_dir, torch_dtype=self.param_dtype)
self.noise_model = self._configure_model(
model=self.noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype,
use_lora=use_relighting_lora,
checkpoint_dir=checkpoint_dir,
config=config
)
if use_sp:
self.sp_size = get_world_size()
else:
self.sp_size = 1
self.sample_neg_prompt = config.sample_neg_prompt
self.sample_prompt = config.prompt
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype, use_lora, checkpoint_dir, config):
"""
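Configures a model object. This includes setting evaluation mode,
patching attention for sequence parallelism, optionally loading the
relighting LoRA, applying FSDP sharding, and handling device placement.
Args:
model (torch.nn.Module):
The model instance to configure.
use_sp (`bool`):
Enable the sequence-parallel distribution strategy.
dit_fsdp (`bool`):
Enable FSDP sharding for DiT model.
shard_fn (callable):
The function to apply FSDP sharding.
convert_model_dtype (`bool`):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
use_lora (`bool`):
Load the relighting LoRA weights into the model.
checkpoint_dir (`str`):
Path to directory containing model checkpoints.
config (EasyDict):
Model config; `config.lora_checkpoint` names the LoRA file inside `checkpoint_dir`.
Returns:
torch.nn.Module:
The configured model.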
"""
model.eval().requires_grad_(False)
if use_sp:
for block in model.blocks:
block.self_attn.forward = types.MethodType(
sp_attn_forward, block.self_attn)
model.use_context_parallel = True
if dist.is_initialized():
dist.barrier()
if use_lora:
logging.info("Loading Relighting Lora. ")
lora_config = get_loraconfig(
transformer=model,
rank=128,
alpha=128
)
model.add_adapter(lora_config)
lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint)
peft_state_dict = torch.load(lora_path)["state_dict"]
set_peft_model_state_dict(model, peft_state_dict)
if dit_fsdp:
model = shard_fn(model, use_lora=use_lora)
else:
if convert_model_dtype:
model.to(self.param_dtype)
if not self.init_on_cpu:
model.to(self.device)
return model
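    def inputs_padding(self, array, target_len):
        """
        Extend `array` to `target_len` items by ping-pong (back-and-forth) replay.
        Frames are read forward to the last index, then backward to the first,
        and so on, without repeating the end frames. Assumes `len(array) >= 2`
        whenever `target_len > 1`; a single-item input would index out of range.
        """
        idx = 0
        flip = False
        target_array = []
        while len(target_array) < target_len:
            target_array.append(deepcopy(array[idx]))
            if flip:
                idx -= 1
            else:
                idx += 1
            # Reverse direction whenever either end of the sequence is reached.
            if idx == 0 or idx == len(array) - 1:
                flip = not flip
        return target_array[:target_len]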
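    def get_valid_len(self, real_len, clip_len=81, overlap=1):
        """
        Round `real_len` up so the frames tile exactly into clips of `clip_len`
        frames, where consecutive clips share `overlap` frames.
        E.g. real_len=100, clip_len=77, overlap=1 gives 153 (= 1 + 2 * 76).
        """
        real_clip_len = clip_len - overlap
        last_clip_num = (real_len - overlap) % real_clip_len
        if last_clip_num == 0:
            extra = 0
        else:
            extra = real_clip_len - last_clip_num
        target_len = real_len + extra
        return target_len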
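    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        """
        Build the conditioning mask in latent layout, shape (4, lat_t, lat_h, lat_w).
        The first `mask_len` frames are marked as known (1). The first frame is
        repeated 4 times so that the frame axis, of length (lat_t - 1) * 4 + 1,
        can be folded into groups of 4 matching the temporal stride of the VAE.
        """
        if mask_pixel_values is None:
            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
        else:
            msk = mask_pixel_values.clone()
        msk[:, :mask_len] = 1
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        return msk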
def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
ori_height = img_ori.shape[0]
ori_width = img_ori.shape[1]
channel = img_ori.shape[2]
img_pad = np.zeros((height, width, channel))
if channel == 1:
img_pad[:, :, 0] = padding_color[0]
else:
img_pad[:, :, 0] = padding_color[0]
img_pad[:, :, 1] = padding_color[1]
img_pad[:, :, 2] = padding_color[2]
if (ori_height / ori_width) > (height / width):
new_width = int(height / ori_height * ori_width)
img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
padding = int((width - new_width) / 2)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img_pad[:, padding: padding + new_width, :] = img
else:
new_height = int(width / ori_width * ori_height)
img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
padding = int((height - new_height) / 2)
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img_pad[padding: padding + new_height, :, :] = img
img_pad = np.uint8(img_pad)
return img_pad
def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
pose_video_reader = VideoReader(src_pose_path)
pose_len = len(pose_video_reader)
pose_idxs = list(range(pose_len))
cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()
face_video_reader = VideoReader(src_face_path)
face_len = len(face_video_reader)
face_idxs = list(range(face_len))
face_images = face_video_reader.get_batch(face_idxs).asnumpy()
height, width = cond_images[0].shape[:2]
refer_images = cv2.imread(src_ref_path)[..., ::-1]
refer_images = self.padding_resize(refer_images, height=height, width=width)
return cond_images, face_images, refer_images
def prepare_source_for_replace(self, src_bg_path, src_mask_path):
bg_video_reader = VideoReader(src_bg_path)
bg_len = len(bg_video_reader)
bg_idxs = list(range(bg_len))
bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()
mask_video_reader = VideoReader(src_mask_path)
mask_len = len(mask_video_reader)
mask_idxs = list(range(mask_len))
mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
mask_images = mask_images[:, :, :, 0] / 255
return bg_images, mask_images
def generate(
self,
src_root_path,
replace_flag=False,
clip_len=77,
refert_num=1,
shift=5.0,
sample_solver='dpm++',
sampling_steps=20,
guide_scale=1,
input_prompt="",
n_prompt="",
seed=-1,
offload_model=True,
):
r"""
Generates video frames from input image using diffusion process.
Args:
src_root_path (`str`):
Path to the preprocessing output directory. It must contain `src_pose.mp4`,
`src_face.mp4` and `src_ref.png`, plus `src_bg.mp4` and `src_mask.mp4` when
`replace_flag` is True.
replace_flag (`bool`, *optional*, defaults to False):
Whether to run in character-replacement mode.
clip_len (`int`, *optional*, defaults to 77):
Number of frames generated per clip. Should be of the form 4n+1.
refert_num (`int`, *optional*, defaults to 1):
Number of frames carried over from the previous clip for temporal guidance. Must be 1 or 5.
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter.
sample_solver (`str`, *optional*, defaults to 'dpm++'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 20):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 1.0):
Classifier-free guidance scale, applied only when greater than 1. It is used only for expression control.
In most cases it is unnecessary, and generation is faster without it.
Consider enabling it when expression adjustments are needed.
input_prompt (`str`, *optional*, defaults to ""):
Text prompt for content generation. If empty, `config.prompt` is used.
Custom prompts work but are not recommended.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames
- H: Frame height
- W: Frame width
"""
assert refert_num == 1 or refert_num == 5, "refert_num should be 1 or 5."
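seed_g = torch.Generator(device=self.device)
# A negative seed (the default, -1) requests a fresh non-deterministic seed, as documented above.
seed_g.manual_seed(seed if seed >= 0 else seed_g.seed())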
if n_prompt == "":
n_prompt = self.sample_neg_prompt
if input_prompt == "":
input_prompt = self.sample_prompt
src_pose_path = os.path.join(src_root_path, "src_pose.mp4")
src_face_path = os.path.join(src_root_path, "src_face.mp4")
src_ref_path = os.path.join(src_root_path, "src_ref.png")
cond_images, face_images, refer_images = self.prepare_source(src_pose_path=src_pose_path, src_face_path=src_face_path, src_ref_path=src_ref_path)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
real_frame_len = len(cond_images)
target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
logging.info('real frames: {} target frames: {}'.format(real_frame_len, target_len))
cond_images = self.inputs_padding(cond_images, target_len)
face_images = self.inputs_padding(face_images, target_len)
if replace_flag:
src_bg_path = os.path.join(src_root_path, "src_bg.mp4")
src_mask_path = os.path.join(src_root_path, "src_mask.mp4")
bg_images, mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
bg_images = self.inputs_padding(bg_images, target_len)
mask_images = self.inputs_padding(mask_images, target_len)
height, width = refer_images.shape[:2]
start = 0
end = clip_len
all_out_frames = []
while True:
if start + refert_num >= len(cond_images):
break
if start == 0:
mask_reft_len = 0
else:
mask_reft_len = refert_num
batch = {
"conditioning_pixel_values": torch.zeros(1, 3, clip_len, height, width),
"bg_pixel_values": torch.zeros(1, 3, clip_len, height, width),
"mask_pixel_values": torch.zeros(1, 1, clip_len, height, width),
"face_pixel_values": torch.zeros(1, 3, clip_len, 512, 512),
"refer_pixel_values": torch.zeros(1, 3, height, width),
"refer_t_pixel_values": torch.zeros(refert_num, 3, height, width)
}
batch["conditioning_pixel_values"] = rearrange(
torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1),
"t h w c -> 1 c t h w",
)
batch["face_pixel_values"] = rearrange(
torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1),
"t h w c -> 1 c t h w",
)
batch["refer_pixel_values"] = rearrange(
torch.tensor(refer_images / 127.5 - 1), "h w c -> 1 c h w"
)
if start > 0:
batch["refer_t_pixel_values"] = rearrange(
out_frames[0, :, -refert_num:].clone().detach(),
"c t h w -> t c h w",
)
batch["refer_t_pixel_values"] = rearrange(batch["refer_t_pixel_values"],
"t c h w -> 1 c t h w",
)
if replace_flag:
batch["bg_pixel_values"] = rearrange(
torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1),
"t h w c -> 1 c t h w",
)
batch["mask_pixel_values"] = rearrange(
torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]),
"t h w c -> 1 t c h w",
)
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch[key] = value.to(device=self.device, dtype=torch.bfloat16)
ref_pixel_values = batch["refer_pixel_values"]
refer_t_pixel_values = batch["refer_t_pixel_values"]
conditioning_pixel_values = batch["conditioning_pixel_values"]
face_pixel_values = batch["face_pixel_values"]
B, _, H, W = ref_pixel_values.shape
T = clip_len
lat_h = H // 8
lat_w = W // 8
lat_t = T // 4 + 1
target_shape = [lat_t + 1, lat_h, lat_w]
noise = [
torch.randn(
16,
target_shape[0],
target_shape[1],
target_shape[2],
dtype=torch.float32,
device=self.device,
generator=seed_g,
)
]
max_seq_len = int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size
if max_seq_len % self.sp_size != 0:
raise ValueError(f"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}")
with (
torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True),
torch.no_grad()
):
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
latents = noise
pose_latents_no_ref = self.vae.encode(conditioning_pixel_values.to(torch.bfloat16))
pose_latents_no_ref = torch.stack(pose_latents_no_ref)
pose_latents = torch.cat([pose_latents_no_ref], dim=2)
ref_pixel_values = rearrange(ref_pixel_values, "t c h w -> 1 c t h w")
ref_latents = self.vae.encode(ref_pixel_values.to(torch.bfloat16))
ref_latents = torch.stack(ref_latents)
mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device)
y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=self.device)
img = ref_pixel_values[0, :, 0]
clip_context = self.clip.visual([img[:, None, :, :]]).to(dtype=torch.bfloat16, device=self.device)
if mask_reft_len > 0:
if replace_flag:
bg_pixel_values = batch["bg_pixel_values"]
y_reft = self.vae.encode(
[
torch.concat([refer_t_pixel_values[0, :, :mask_reft_len], bg_pixel_values[0, :, mask_reft_len:]], dim=1).to(self.device)
]
)[0]
mask_pixel_values = 1 - batch["mask_pixel_values"]
mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
else:
y_reft = self.vae.encode(
[
torch.concat(
[
torch.nn.functional.interpolate(refer_t_pixel_values[0, :, :mask_reft_len].cpu(),
size=(H, W), mode="bicubic"),
torch.zeros(3, T - mask_reft_len, H, W),
],
dim=1,
).to(self.device)
]
)[0]
msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
else:
if replace_flag:
bg_pixel_values = batch["bg_pixel_values"]
mask_pixel_values = 1 - batch["mask_pixel_values"]
mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
y_reft = self.vae.encode(
[
torch.concat(
[
bg_pixel_values[0],
],
dim=1,
).to(self.device)
]
)[0]
msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
else:
y_reft = self.vae.encode(
[
torch.concat(
[
torch.zeros(3, T - mask_reft_len, H, W),
],
dim=1,
).to(self.device)
]
)[0]
msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=self.device)
y = torch.concat([y_ref, y_reft], dim=1)
arg_c = {
"context": context,
"seq_len": max_seq_len,
"clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
"y": [y],
"pose_latents": pose_latents,
"face_pixel_values": face_pixel_values,
}
if guide_scale > 1:
face_pixel_values_uncond = face_pixel_values * 0 - 1
arg_null = {
"context": context_null,
"seq_len": max_seq_len,
"clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
"y": [y],
"pose_latents": pose_latents,
"face_pixel_values": face_pixel_values_uncond,
}
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
noise_pred_cond = TensorList(
self.noise_model(TensorList(latent_model_input), t=timestep, **arg_c)
)
if guide_scale > 1:
noise_pred_uncond = TensorList(
self.noise_model(
TensorList(latent_model_input), t=timestep, **arg_null
)
)
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond
)
else:
noise_pred = noise_pred_cond
temp_x0 = sample_scheduler.step(
noise_pred[0].unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g,
)[0]
latents[0] = temp_x0.squeeze(0)
x0 = latents
x0 = [x.to(dtype=torch.float32) for x in x0]
out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]]))
if start != 0:
out_frames = out_frames[:, :, refert_num:]
all_out_frames.append(out_frames.cpu())
start += clip_len - refert_num
end += clip_len - refert_num
videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
return videos[0] if self.rank == 0 else None
================================================
FILE: wan/configs/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from .wan_i2v_A14B import i2v_A14B
from .wan_s2v_14B import s2v_14B
from .wan_t2v_A14B import t2v_A14B
from .wan_ti2v_5B import ti2v_5B
from .wan_animate_14B import animate_14B
WAN_CONFIGS = {
't2v-A14B': t2v_A14B,
'i2v-A14B': i2v_A14B,
'ti2v-5B': ti2v_5B,
'animate-14B': animate_14B,
's2v-14B': s2v_14B,
}
SIZE_CONFIGS = {
'720*1280': (720, 1280),
'1280*720': (1280, 720),
'480*832': (480, 832),
'832*480': (832, 480),
'704*1280': (704, 1280),
'1280*704': (1280, 704),
'1024*704': (1024, 704),
'704*1024': (704, 1024),
}
MAX_AREA_CONFIGS = {
'720*1280': 720 * 1280,
'1280*720': 1280 * 720,
'480*832': 480 * 832,
'832*480': 832 * 480,
'704*1280': 704 * 1280,
'1280*704': 1280 * 704,
'1024*704': 1024 * 704,
'704*1024': 704 * 1024,
}
SUPPORTED_SIZES = {
't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
'ti2v-5B': ('704*1280', '1280*704'),
's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704',
'704*1024', '704*1280', '1280*704'),
'animate-14B': ('720*1280', '1280*720')
}
================================================
FILE: wan/configs/shared_config.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()
# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512
# transformer
wan_shared_cfg.param_dtype = torch.bfloat16
# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
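# English gloss of the negative prompt (the Chinese string is what the model receives):
# "vivid tones, overexposed, static, blurry details, subtitles, style, artwork, painting, picture, still,
# overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers,
# poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers,
# motionless frame, cluttered background, three legs, many people in the background, walking backwards"
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'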
wan_shared_cfg.frame_num = 81
================================================
FILE: wan/configs/wan_animate_14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan animate 14B ------------------------#
animate_14B = EasyDict(__name__='Config: Wan animate 14B')
animate_14B.update(wan_shared_cfg)
animate_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
animate_14B.t5_tokenizer = 'google/umt5-xxl'
animate_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
animate_14B.clip_tokenizer = 'xlm-roberta-large'
animate_14B.lora_checkpoint = 'relighting_lora.ckpt'
# vae
animate_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
animate_14B.vae_stride = (4, 8, 8)
# transformer
animate_14B.patch_size = (1, 2, 2)
animate_14B.dim = 5120
animate_14B.ffn_dim = 13824
animate_14B.freq_dim = 256
animate_14B.num_heads = 40
animate_14B.num_layers = 40
animate_14B.window_size = (-1, -1)
animate_14B.qk_norm = True
animate_14B.cross_attn_norm = True
animate_14B.eps = 1e-6
animate_14B.use_face_encoder = True
animate_14B.motion_encoder_dim = 512
# inference
animate_14B.sample_shift = 5.0
animate_14B.sample_steps = 20
animate_14B.sample_guide_scale = 1.0
animate_14B.frame_num = 77
animate_14B.sample_fps = 30
# English gloss of the default prompt: "The person in the video is performing an action."
animate_14B.prompt = '视频中的人在做动作'
================================================
FILE: wan/configs/wan_i2v_A14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan I2V A14B ------------------------#
i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)
i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)
# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'
# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise
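# Illustrative sketch (not part of the original config): how `boundary` and the
# two-element `sample_guide_scale` appear to be consumed at sampling time,
# mirroring the selection logic in wan/image2video.py. The helper name is
# hypothetical, and the default of 1000 training timesteps is an assumption;
# the real value comes from `config.num_train_timesteps`.
def _example_select_expert(t,
                           boundary=0.900,
                           guide_scale=(3.5, 3.5),
                           num_train_timesteps=1000):
    """Return (model_name, guidance_scale) for a scalar timestep `t`."""
    threshold = boundary * num_train_timesteps
    if t >= threshold:
        # Early, noisy steps: high-noise expert with guide_scale[1].
        return 'high_noise_model', guide_scale[1]
    # Later, cleaner steps: low-noise expert with guide_scale[0].
    return 'low_noise_model', guide_scale[0]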
================================================
FILE: wan/configs/wan_s2v_14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan S2V 14B ------------------------#
s2v_14B = EasyDict(__name__='Config: Wan S2V 14B')
s2v_14B.update(wan_shared_cfg)
# t5
s2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
s2v_14B.t5_tokenizer = 'google/umt5-xxl'
# vae
s2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
s2v_14B.vae_stride = (4, 8, 8)
# wav2vec
s2v_14B.wav2vec = "wav2vec2-large-xlsr-53-english"
s2v_14B.num_heads = 40
# transformer
s2v_14B.transformer = EasyDict(
__name__="Config: Transformer config for WanModel_S2V")
s2v_14B.transformer.patch_size = (1, 2, 2)
s2v_14B.transformer.dim = 5120
s2v_14B.transformer.ffn_dim = 13824
s2v_14B.transformer.freq_dim = 256
s2v_14B.transformer.num_heads = 40
s2v_14B.transformer.num_layers = 40
s2v_14B.transformer.window_size = (-1, -1)
s2v_14B.transformer.qk_norm = True
s2v_14B.transformer.cross_attn_norm = True
s2v_14B.transformer.eps = 1e-6
s2v_14B.transformer.enable_adain = True
s2v_14B.transformer.adain_mode = "attn_norm"
s2v_14B.transformer.audio_inject_layers = [
0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39
]
s2v_14B.transformer.zero_init = True
s2v_14B.transformer.zero_timestep = True
s2v_14B.transformer.enable_motioner = False
s2v_14B.transformer.add_last_motion = True
s2v_14B.transformer.trainable_token = False
s2v_14B.transformer.enable_tsm = False
s2v_14B.transformer.enable_framepack = True
s2v_14B.transformer.framepack_drop_mode = 'padd'
s2v_14B.transformer.audio_dim = 1024
s2v_14B.transformer.motion_frames = 73
s2v_14B.transformer.cond_dim = 16
# inference
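# Default negative prompt (kept in Chinese, as the model expects). English gloss:
# blurry picture, worst quality, blurry picture, unclear details, intense emotional
# agitation, rapidly shaking hands, subtitles, ugly, mutilated, extra fingers,
# poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs,
# fused fingers, motionless frame, cluttered background, three legs,
# crowded background, walking backwards.
s2v_14B.sample_neg_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"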
s2v_14B.drop_first_motion = True
s2v_14B.sample_shift = 3
s2v_14B.sample_steps = 40
s2v_14B.sample_guide_scale = 4.5
================================================
FILE: wan/configs/wan_t2v_A14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan T2V A14B ------------------------#
t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)
# t5
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_A14B.vae_stride = (4, 8, 8)
# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'
# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise
================================================
FILE: wan/configs/wan_ti2v_5B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict
from .shared_config import wan_shared_cfg
#------------------------ Wan TI2V 5B ------------------------#
ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)
# t5
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
# vae
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)
# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6
# inference
ti2v_5B.sample_fps = 24
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121
================================================
FILE: wan/distributed/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
================================================
FILE: wan/distributed/fsdp.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
def shard_model(
model,
device_id,
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
process_group=None,
sharding_strategy=ShardingStrategy.FULL_SHARD,
sync_module_states=True,
use_lora=False
):
model = FSDP(
module=model,
process_group=process_group,
sharding_strategy=sharding_strategy,
auto_wrap_policy=partial(
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
mixed_precision=MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype),
device_id=device_id,
sync_module_states=sync_module_states,
use_orig_params=True if use_lora else False)
return model
def free_model(model):
for m in model.modules():
if isinstance(m, FSDP):
_free_storage(m._handle.flat_param.data)
del model
gc.collect()
torch.cuda.empty_cache()
================================================
FILE: wan/distributed/sequence_parallel.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from ..modules.model import sinusoidal_embedding_1d
from .ulysses import distributed_attention
from .util import gather_forward, get_rank, get_world_size
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_world_size()
sp_rank = get_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
def sp_dit_forward(
self,
x,
t,
context,
seq_len,
y=None,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
if self.model_type == 'i2v':
assert y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
])
# time embeddings
if t.dim() == 1:
t = t.expand(t.size(0), seq_len)
with torch.amp.autocast('cuda', dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t).unflatten(0, (bt, seq_len)).float())
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# Context Parallel
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# Context Parallel
x = gather_forward(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
x = distributed_attention(
half(q),
half(k),
half(v),
seq_lens,
window_size=self.window_size,
)
# output
x = x.flatten(2)
x = self.o(x)
return x
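# Illustrative sketch (not part of the original module): the index arithmetic
# behind `rope_apply` above, written in plain Python so it can be checked
# without a GPU. Both helper names are hypothetical.
def _example_rope_split(head_dim):
    """Sizes of the frame/height/width frequency groups for one head.

    `rope_apply` works on c = head_dim // 2 complex pairs and splits the
    frequency table as [c - 2 * (c // 3), c // 3, c // 3].
    """
    c = head_dim // 2
    return [c - 2 * (c // 3), c // 3, c // 3]


def _example_rank_window(local_len, sp_rank):
    """Half-open [start, stop) row range of the padded freqs read by `sp_rank`.

    Each rank holds `local_len` tokens; the frequencies are padded to
    local_len * sp_size rows and rank r reads rows
    [r * local_len, (r + 1) * local_len).
    """
    return sp_rank * local_len, (sp_rank + 1) * local_len


# With dim = 5120 and num_heads = 40 the head size is 128, so the split is
# 22 pairs for frames and 21 each for height and width.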
================================================
FILE: wan/distributed/ulysses.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
from ..modules.attention import flash_attention
from .util import all_to_all
def distributed_attention(
q,
k,
v,
seq_lens,
window_size=(-1, -1),
):
"""
Performs distributed attention based on the DeepSpeed Ulysses attention mechanism.
Please refer to https://arxiv.org/pdf/2309.14509
Here p denotes the sequence-parallel world size.
Args:
q: [B, Lq // p, Nq, C1].
k: [B, Lk // p, Nk, C1].
v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
seq_lens: [B], length of each sequence in the batch.
window_size: (left, right). If not (-1, -1), apply sliding window local attention.
"""
if not dist.is_initialized():
raise ValueError("distributed group should be initialized.")
b = q.shape[0]
# gather q/k/v sequence
q = all_to_all(q, scatter_dim=2, gather_dim=1)
k = all_to_all(k, scatter_dim=2, gather_dim=1)
v = all_to_all(v, scatter_dim=2, gather_dim=1)
# apply attention
x = flash_attention(
q,
k,
v,
k_lens=seq_lens,
window_size=window_size,
)
# scatter the attention output back along the sequence dim and gather heads
x = all_to_all(x, scatter_dim=1, gather_dim=2)
return x
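# Illustrative sketch (not part of the original module): shape bookkeeping for
# the all-to-all calls in `distributed_attention`, assuming the scattered
# dimension divides evenly by the world size. The helper name is hypothetical
# and no communication is performed here.
def _example_all_to_all_shape(shape, scatter_dim, gather_dim, world_size):
    """Per-rank shape after `all_to_all(x, scatter_dim, gather_dim)`."""
    assert shape[scatter_dim] % world_size == 0
    out = list(shape)
    out[scatter_dim] //= world_size  # each rank keeps one chunk of this dim
    out[gather_dim] *= world_size  # chunks from all ranks are concatenated
    return tuple(out)


# Example with p = 4 ranks and a local q of shape [B, L // p, N, C]:
#   before attention: (1, 100, 40, 128) -> (1, 400, 10, 128)  (full sequence, 10 heads)
#   after attention:  (1, 400, 10, 128) -> (1, 100, 40, 128)  (local sequence, all heads)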
================================================
FILE: wan/distributed/util.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
def init_distributed_group():
"""Initialize the sequence parallel process group.
"""
if not dist.is_initialized():
dist.init_process_group(backend='nccl')
def get_rank():
return dist.get_rank()
def get_world_size():
return dist.get_world_size()
def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
"""
`scatter` along one dimension and `gather` along another.
"""
world_size = get_world_size()
if world_size > 1:
inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
outputs = [torch.empty_like(u) for u in inputs]
dist.all_to_all(outputs, inputs, group=group, **kwargs)
x = torch.cat(outputs, dim=gather_dim).contiguous()
return x
def all_gather(tensor):
world_size = dist.get_world_size()
if world_size == 1:
return [tensor]
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, tensor)
return tensor_list
def gather_forward(input, dim):
# skip if world_size == 1
world_size = dist.get_world_size()
if world_size == 1:
return input
# gather sequence
output = all_gather(input)
return torch.cat(output, dim=dim).contiguous()
================================================
FILE: wan/image2video.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanI2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
):
r"""
Initializes the image-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed inference
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_sp (`bool`, *optional*, defaults to False):
Enable distribution strategy of sequence parallel.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.init_on_cpu = init_on_cpu
self.num_train_timesteps = config.num_train_timesteps
self.boundary = config.boundary
self.param_dtype = config.param_dtype
if t5_fsdp or dit_fsdp or use_sp:
self.init_on_cpu = False
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = Wan2_1_VAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.low_noise_model = WanModel.from_pretrained(
checkpoint_dir, subfolder=config.low_noise_checkpoint)
self.low_noise_model = self._configure_model(
model=self.low_noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
self.high_noise_model = WanModel.from_pretrained(
checkpoint_dir, subfolder=config.high_noise_checkpoint)
self.high_noise_model = self._configure_model(
model=self.high_noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
if use_sp:
self.sp_size = get_world_size()
else:
self.sp_size = 1
self.sample_neg_prompt = config.sample_neg_prompt
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype):
"""
Configures a model object. This includes setting evaluation modes,
applying distributed parallel strategy, and handling device placement.
Args:
model (torch.nn.Module):
The model instance to configure.
use_sp (`bool`):
Enable distribution strategy of sequence parallel.
dit_fsdp (`bool`):
Enable FSDP sharding for DiT model.
shard_fn (callable):
The function to apply FSDP sharding.
convert_model_dtype (`bool`):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
Returns:
torch.nn.Module:
The configured model.
"""
model.eval().requires_grad_(False)
if use_sp:
for block in model.blocks:
block.self_attn.forward = types.MethodType(
sp_attn_forward, block.self_attn)
model.forward = types.MethodType(sp_dit_forward, model)
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
model = shard_fn(model)
else:
if convert_model_dtype:
model.to(self.param_dtype)
if not self.init_on_cpu:
model.to(self.device)
return model
def _prepare_model_for_timestep(self, t, boundary, offload_model):
r"""
Prepares and returns the required model for the current timestep.
Args:
t (torch.Tensor):
current timestep.
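boundary (`float`):
The timestep threshold, i.e. `config.boundary * num_train_timesteps`.
If `t` is at or above this value, the `high_noise_model` is the
required model; otherwise the `low_noise_model` is.
offload_model (`bool`):
If True (or if `init_on_cpu` is set), move the inactive model to CPU
and the required model to the target device.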
Returns:
torch.nn.Module:
The active model on the target device for the current timestep.
"""
if t.item() >= boundary:
required_model_name = 'high_noise_model'
offload_model_name = 'low_noise_model'
else:
required_model_name = 'low_noise_model'
offload_model_name = 'high_noise_model'
if offload_model or self.init_on_cpu:
if next(getattr(
self,
offload_model_name).parameters()).device.type == 'cuda':
getattr(self, offload_model_name).to('cpu')
if next(getattr(
self,
required_model_name).parameters()).device.type == 'cpu':
getattr(self, required_model_name).to(self.device)
return getattr(self, required_model_name)
def generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from an input image and text prompt using a diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image. It is converted internally to a tensor of shape [3, H, W].
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
If tuple, the first guide_scale will be used for low noise model and
the second guide_scale will be used for high noise model.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (`frame_num`)
- H: Frame height (derived from `max_area`)
- W: Frame width (derived from `max_area`)
"""
# preprocess
# accept ints as well as floats so that e.g. guide_scale=5 is also expanded
guide_scale = (guide_scale, guide_scale) if isinstance(
guide_scale, (int, float)) else guide_scale
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // self.vae_stride[0] + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# encode the positive and negative prompts
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
noop_no_sync)
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
noop_no_sync)
# evaluation mode
with (
torch.amp.autocast('cuda', dtype=self.param_dtype),
torch.no_grad(),
no_sync_low_noise(),
no_sync_high_noise(),
):
boundary = self.boundary * self.num_train_timesteps
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'seq_len': max_seq_len,
'y': [y],
}
arg_null = {
'context': context_null,
'seq_len': max_seq_len,
'y': [y],
}
if offload_model:
torch.cuda.empty_cache()
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
model = self._prepare_model_for_timestep(
t, boundary, offload_model)
sample_guide_scale = guide_scale[1] if t.item(
) >= boundary else guide_scale[0]
noise_pred_cond = model(
latent_model_input, t=timestep, **arg_c)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = model(
latent_model_input, t=timestep, **arg_null)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + sample_guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent]
del latent_model_input, timestep
if offload_model:
self.low_noise_model.cpu()
self.high_noise_model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent, x0
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
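# Illustrative sketch (not part of the original module): the latent-grid and
# sequence-length arithmetic from `WanI2V.generate`, reproduced in plain Python.
# The helper name is hypothetical, `math.sqrt` stands in for `np.sqrt`, and the
# default stride and patch size are copied from the I2V A14B config.
import math


def _example_latent_geometry(img_h,
                             img_w,
                             max_area=720 * 1280,
                             frame_num=81,
                             vae_stride=(4, 8, 8),
                             patch_size=(1, 2, 2),
                             sp_size=1):
    """Return the latent grid, output resolution and padded token count."""
    aspect_ratio = img_h / img_w
    # Latent height/width: rescale to `max_area`, then round down to a
    # multiple of the patch size in latent space.
    lat_h = round(
        math.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] *
        patch_size[1])
    lat_w = round(
        math.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] *
        patch_size[2])
    # frame_num = 4n + 1 pixel frames map to n + 1 latent frames.
    lat_f = (frame_num - 1) // vae_stride[0] + 1
    max_seq_len = lat_f * lat_h * lat_w // (patch_size[1] * patch_size[2])
    # Pad the token count up to a multiple of the sequence-parallel size.
    max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size
    return {
        'lat_f': lat_f,
        'lat_h': lat_h,
        'lat_w': lat_w,
        'height': lat_h * vae_stride[1],
        'width': lat_w * vae_stride[2],
        'max_seq_len': max_seq_len,
    }


# For a 720x1280 input with the defaults this gives a 21 x 90 x 160 latent
# grid and 75600 tokens per sample.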
================================================
FILE: wan/modules/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE
from .vae2_2 import Wan2_2_VAE
__all__ = [
'Wan2_1_VAE',
'Wan2_2_VAE',
'WanModel',
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
'HuggingfaceTokenizer',
'flash_attention',
]
================================================
FILE: wan/modules/animate/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model_animate import WanAnimateModel
from .clip import CLIPModel
__all__ = ['WanAnimateModel', 'CLIPModel']
================================================
FILE: wan/modules/animate/animate_utils.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
target_modules = []
for name, module in transformer.named_modules():
if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
target_modules.append(name)
transformer_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
init_lora_weights=init_lora_weights,
target_modules=target_modules,
)
return transformer_lora_config
class TensorList(object):
def __init__(self, tensors):
"""
tensors: a list of torch.Tensor objects. No need to have uniform shape.
"""
assert isinstance(tensors, (list, tuple))
assert all(isinstance(u, torch.Tensor) for u in tensors)
assert len(set([u.ndim for u in tensors])) == 1
assert len(set([u.dtype for u in tensors])) == 1
assert len(set([u.device for u in tensors])) == 1
self.tensors = tensors
def to(self, *args, **kwargs):
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
def size(self, dim):
assert dim == 0, 'only dim=0 (the number of tensors) is supported'
return len(self.tensors)
def pow(self, *args, **kwargs):
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
def squeeze(self, dim):
assert dim != 0
if dim > 0:
dim -= 1
return TensorList([u.squeeze(dim) for u in self.tensors])
def type(self, *args, **kwargs):
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
def type_as(self, other):
assert isinstance(other, (torch.Tensor, TensorList))
if isinstance(other, torch.Tensor):
return TensorList([u.type_as(other) for u in self.tensors])
else:
return TensorList([u.type(other.dtype) for u in self.tensors])
@property
def dtype(self):
return self.tensors[0].dtype
@property
def device(self):
return self.tensors[0].device
@property
def ndim(self):
return 1 + self.tensors[0].ndim
def __getitem__(self, index):
return self.tensors[index]
def __len__(self):
return len(self.tensors)
def __add__(self, other):
return self._apply(other, lambda u, v: u + v)
def __radd__(self, other):
return self._apply(other, lambda u, v: v + u)
def __sub__(self, other):
return self._apply(other, lambda u, v: u - v)
def __rsub__(self, other):
return self._apply(other, lambda u, v: v - u)
def __mul__(self, other):
return self._apply(other, lambda u, v: u * v)
def __rmul__(self, other):
return self._apply(other, lambda u, v: v * u)
def __floordiv__(self, other):
return self._apply(other, lambda u, v: u // v)
def __truediv__(self, other):
return self._apply(other, lambda u, v: u / v)
def __rfloordiv__(self, other):
return self._apply(other, lambda u, v: v // u)
def __rtruediv__(self, other):
return self._apply(other, lambda u, v: v / u)
def __pow__(self, other):
return self._apply(other, lambda u, v: u ** v)
def __rpow__(self, other):
return self._apply(other, lambda u, v: v ** u)
def __neg__(self):
return TensorList([-u for u in self.tensors])
def __iter__(self):
for tensor in self.tensors:
yield tensor
def __repr__(self):
return 'TensorList: \n' + repr(self.tensors)
def _apply(self, other, op):
if isinstance(other, (list, tuple, TensorList)) or (
isinstance(other, torch.Tensor) and (
other.numel() > 1 or other.ndim > 1
)
):
assert len(other) == len(self.tensors)
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
elif isinstance(other, numbers.Number) or (
isinstance(other, torch.Tensor) and (
other.numel() == 1 and other.ndim <= 1
)
):
return TensorList([op(u, other) for u in self.tensors])
else:
raise TypeError(
f'unsupported operand types: "TensorList" and "{type(other)}"'
)
================================================
FILE: wan/modules/animate/clip.py
================================================
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from ..attention import flash_attention
from ..tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
'CLIPModel',
]
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
],
dim=1)
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
causal=False,
attn_dropout=0.0,
proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5):
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
activation='gelu',
proj_dropout=0.0,
norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = flash_attention(q, k, v, version=2)
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type='token',
pre_norm=True,
post_norm=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
if image_size % patch_size != 0:
print(
'[WARNING] image_size is not divisible by patch_size',
flush=True)
assert pool_type in ('token', 'token_fc', 'attn_pool')
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size)**2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3,
dim,
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm)
if pool_type in ('token', 'token_fc'):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(
1, self.num_patches +
(1 if pool_type in ('token', 'token_fc') else 0), dim))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == 'token_fc':
self.head = nn.Linear(dim, out_dim)
elif pool_type == 'attn_pool':
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ('token', 'token_fc'):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
x = self.transformer[:-1](x)
return x
else:
x = self.transformer(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop('out_dim')
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
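# Illustrative sketch of the masked average pooling in XLMRobertaWithHead.forward:
# positions whose id equals pad_id are excluded from the mean. The helper name
# masked_mean and the sample ids and features are made up for illustration.

```python
def masked_mean(features, ids, pad_id=1):
    """features: [L][C] per-token vectors; ids: [L] token ids."""
    kept = [vec for vec, tok in zip(features, ids) if tok != pad_id]
    return [sum(col) / len(kept) for col in zip(*kept)]


ids = [0, 87, 2, 1, 1]  # the last two positions are padding
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0], [9.0, 9.0]]
pooled = masked_mean(features, ids)
print(pooled)
```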
class XLMRobertaCLIP(nn.Module):
def __init__(self,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
vision_pre_norm=True,
vision_post_norm=False,
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay': 0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups
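# Illustrative sketch of the name test used by param_groups: parameters whose name
# contains 'norm' or ends with 'bias' go into the group with weight_decay=0.0.
# The helper split_by_decay and the sample parameter names are assumptions.

```python
def split_by_decay(names):
    """Return (no_decay, decay) name lists using the same predicate as param_groups."""
    no_decay = [n for n in names if 'norm' in n or n.endswith('bias')]
    decay = [n for n in names if not ('norm' in n or n.endswith('bias'))]
    return no_decay, decay


names = [
    'visual.pre_norm.weight',
    'visual.transformer.0.attn.proj.bias',
    'visual.transformer.0.attn.to_qkv.weight',
    'log_scale',
]
no_decay, decay = split_by_decay(names)
print(no_decay)
print(decay)
```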
def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std (pretrained_name defaults to None, so guard before lower())
        if pretrained_name is not None and 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)
    return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(
pretrained=False,
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
**kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
dtype=dtype,
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu'))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path,
seq_len=self.model.max_text_len - 2,
clean='whitespace')
def visual(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([
F.interpolate(
u.transpose(0, 1),
size=size,
mode='bicubic',
align_corners=False) for u in videos
])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.cuda.amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out
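# Illustrative sketch of the value mapping in CLIPModel.visual: frames are rescaled
# with mul_(0.5).add_(0.5) and then passed to the last transform (T.Normalize).
# This assumes the input frames are in [-1, 1], which the rescaling suggests but
# the code does not state. The helper to_clip_range is hypothetical.

```python
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def to_clip_range(rgb):
    """rgb: three channel values in [-1, 1]; returns the normalized values."""
    unit = [v * 0.5 + 0.5 for v in rgb]  # [-1, 1] -> [0, 1]
    return [(u - m) / s for u, m, s in zip(unit, CLIP_MEAN, CLIP_STD)]


black = to_clip_range([-1.0, -1.0, -1.0])
white = to_clip_range([1.0, 1.0, 1.0])
print(black)
print(white)
```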
================================================
FILE: wan/modules/animate/face_blocks.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from ...distributed.util import gather_forward, get_rank, get_world_size
try:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
except ImportError:
flash_attn_func = None
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform multi-head attention with the selected backend.

    No layout conversion is applied to the inputs: 'flash' expects [b, s, a, d],
    while 'torch' and 'vanilla' expect q, k, v already transposed to [b, a, s, d].

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map; not used in 'flash' mode. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention; not used in 'flash' mode. (default: False)
        max_seqlen_q (int): Query sequence length, used to reshape the 'flash' output.
        batch_size (int): Batch size, used to reshape the 'flash' output. (default: 1)

    Returns:
        torch.Tensor: Output tensor after attention with shape [b, s, ad]
    """
    _, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "flash":
        if flash_attn_func is None:
            raise ImportError("flash_attn is required for mode='flash'")
        # drop_rate, attn_mask and causal are not forwarded to flash_attn_func
        x = flash_attn_func(
            q,
            k,
            v,
        )
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        # dropout is applied whenever drop_rate > 0, regardless of training mode
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
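# Illustrative sketch of what CausalConv1d does on a single channel: the input is
# padded on the left only (kernel_size - 1 samples, replicating the first value),
# so each output depends on the current and earlier time steps. Assumptions made
# here for brevity: one channel, no bias, dilation 1. The helper causal_conv1d is
# hypothetical and uses cross-correlation (no kernel flip), like nn.Conv1d.

```python
def causal_conv1d(signal, kernel, stride=1):
    """Left-pad with the first sample, then slide the kernel with the given stride."""
    k = len(kernel)
    padded = [signal[0]] * (k - 1) + list(signal)
    return [
        sum(w * padded[start + j] for j, w in enumerate(kernel))
        for start in range(0, len(padded) - k + 1, stride)
    ]


out_full = causal_conv1d([1, 2, 3, 4], [1, 1, 1])
out_strided = causal_conv1d([1, 2, 3, 4], [1, 1, 1], stride=2)
print(out_full, out_strided)
```

# With stride 1 the output keeps the input length; with stride 2 and kernel size 3
# a length-t input yields (t - 1) // 2 + 1 outputs.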
class FaceEncoder(nn.Module):

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        # parameter-free layer norms over the 1024 channels
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        # x: [b, t, in_dim] -> [b, in_dim, t] for Conv1d
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        # one 1024-channel stream per head, folded into the batch dimension
        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)

        # two stride-2 causal convolutions downsample the time axis
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)

        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)

        # project to hidden_dim and unfold the heads: [b, t', n, hidden_dim]
        x = self.out_proj(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)

        # append one learned padding token per time step: [b, t', n + 1, hidden_dim]
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
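# Illustrative sketch of the RMSNorm computation on a plain Python list: each value
# is divided by the root mean square of the vector (plus eps), then optionally
# scaled by a per-element weight. The helper rms_norm is a hypothetical scalar
# stand-in for the tensor code above.

```python
import math


def rms_norm(x, weight=None, eps=1e-6):
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v * inv_rms for v in x]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    return out


y = rms_norm([3.0, 4.0])
mean_square = sum(v * v for v in y) / len(y)
print(y, mean_square)
```

# After normalization the mean of the squared values is close to 1, which is the
# property the layer enforces before the learnable weight is applied.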
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
use_context_parallel=False,
) -> torch.Tensor:
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
if use_context_parallel:
q = gather_forward(q, dim=1)
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
# Compute attention.
attn = attention(
q,
k,
v,
max_seqlen_q=q.shape[1],
batch_size=q.shape[0],
)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
if use_context_parallel:
attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output
================================================
FILE: wan/modules/animate/model_animate.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import PeftAdapterMixin
from ...distributed.sequence_parallel import (
distributed_attention,
gather_forward,
get_rank,
get_world_size,
)
from ..model import (
Head,
WanAttentionBlock,
WanLayerNorm,
WanRMSNorm,
WanModel,
WanSelfAttention,
flash_attention,
rope_params,
sinusoidal_embedding_1d,
rope_apply
)
from .face_blocks import FaceEncoder, FaceAdapter
from .motion_encoder import Generator
class HeadAnimate(Head):
def forward(self, x, e):
"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class WanAnimateSelfAttention(WanSelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Tensor): Shape [B, L, C]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAnimateCrossAttention(WanSelfAttention):
def __init__(
self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
use_img_emb=True
):
super().__init__(
dim,
num_heads,
window_size,
qk_norm,
eps
)
self.use_img_emb = use_img_emb
if use_img_emb:
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
if self.use_img_emb:
context_img = context[:, :257]
context = context[:, 257:]
else:
context = context
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
if self.use_img_emb:
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
if self.use_img_emb:
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
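# Illustrative sketch of the context split in WanAnimateCrossAttention.forward: the
# first 257 tokens are treated as image tokens and the rest as text tokens. The
# count presumably corresponds to one class token plus 16 x 16 patches of a 224 px
# image with 14 px patches; that link, the helper split_context, and the 512-token
# text length are assumptions for illustration.

```python
IMAGE_SIZE, PATCH_SIZE = 224, 14
NUM_IMAGE_TOKENS = 1 + (IMAGE_SIZE // PATCH_SIZE) ** 2  # class token + patch grid


def split_context(context):
    """Split a token sequence into (image_tokens, text_tokens)."""
    return context[:NUM_IMAGE_TOKENS], context[NUM_IMAGE_TOKENS:]


context = ['img'] * NUM_IMAGE_TOKENS + ['txt'] * 512
context_img, context_txt = split_context(context)
print(NUM_IMAGE_TOKENS, len(context_img), len(context_txt))
```

# The two parts are attended to separately and the two attention outputs are summed
# before the output projection.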
class WanAnimateAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
use_img_emb=True):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(
dim, eps, elementwise_affine=True
) if cross_attn_norm else nn.Identity()
self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim)
)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
"""
Args:
x(Tensor): Shape [B, L, C]
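e(Tensor): Shape [B, 6, C], float32 modulation (shift, scale, gate for self-attention, then for the FFN)
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
context(Tensor): Shape [B, L2, C], conditioning tokens for cross-attention
context_lens(Tensor or None): Shape [B], valid length of each context sequence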
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
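# Illustrative sketch of the modulation pattern used twice in the block above:
# the normalized input is scaled and shifted, passed through a branch (attention
# or FFN), and the branch output is gated before the residual add. Scalars stand
# in for tensors, and normalization is omitted; the helper names are hypothetical.

```python
def identity(v):
    return v


def modulated_residual(x, branch, shift, scale, gate):
    """Compute x + gate * branch(x * (1 + scale) + shift)."""
    return x + gate * branch(x * (1.0 + scale) + shift)


closed = modulated_residual(2.0, identity, shift=0.0, scale=0.0, gate=0.0)
opened = modulated_residual(2.0, identity, shift=0.5, scale=0.5, gate=1.0)
print(closed, opened)
```

# With gate = 0 the branch contributes nothing and the block reduces to the
# residual path.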
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim),
torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(),
torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim),
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_no_split_modules = ['WanAttentionBlock']
@register_to_config
def __init__(self,
patch_size=(1, 2, 2),
text_len=512,
in_dim=36,
dim=5120,
ffn_dim=13824,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=40,
num_layers=40,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
motion_encoder_dim=512,
use_context_parallel=False,
use_img_emb=True):
super().__init__()
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.motion_encoder_dim = motion_encoder_dim
self.use_context_parallel = use_context_parallel
self.use_img_emb = use_img_emb
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
])
# head
self.head = HeadAnimate(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], dim=1)
self.img_emb = MLPProj(1280, dim)
# initialize weights
self.init_weights()
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
)
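# Illustrative sketch of how the per-head dimension is divided between the three
# rope_params calls in __init__. The three parts presumably serve the frame,
# height and width axes; that reading, and the helper rope_split, are assumptions.

```python
def rope_split(head_dim):
    """Channel counts handed to the three rope_params calls."""
    spatial = 2 * (head_dim // 6)
    return head_dim - 2 * spatial, spatial, spatial


dim, num_heads = 5120, 40  # the defaults of WanAnimateModel
parts = rope_split(dim // num_heads)
print(parts, sum(parts))
```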
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
for x_, pose_latents_ in zip(x, pose_latents):
x_[:, :, 1:] += pose_latents_
b,c,T,h,w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
return x, motion_vec
def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
if block_idx % 5 == 0:
adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
x = residual_out + x
return x
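# Illustrative sketch of the schedule implied by after_transformer_block: a face
# adapter block runs after every fifth transformer block, and block idx uses
# fuser block idx // 5. The helper adapter_schedule is hypothetical.

```python
def adapter_schedule(num_layers, every=5):
    """Map transformer block index -> face adapter (fuser block) index."""
    return {idx: idx // every for idx in range(num_layers) if idx % every == 0}


schedule = adapter_schedule(40)
print(schedule)
```

# For the default 40 layers this uses 8 adapter blocks, matching
# num_adapter_layers = num_layers // 5 in __init__.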
def forward(
self,
x,
t,
clip_fea,
context,
seq_len,
y=None,
pose_latents=None,
face_pixel_values=None
):
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float()
)
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if self.use_img_emb:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
if self.use_context_parallel:
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
for idx, block in enumerate(self.blocks):
x = block(x, **kwargs)
x = self.after_transformer_block(idx, x, motion_vec)
# head
x = self.head(x, e)
if self.use_context_parallel:
x = gather_forward(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
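# Illustrative sketch of the shape bookkeeping in unpatchify for one sample. The
# grid size (21, 30, 52) is a made-up example; the patch size and out_dim are the
# defaults of WanAnimateModel. The helper unpatchify_shape is hypothetical.

```python
import math


def unpatchify_shape(grid_size, patch_size, out_dim):
    """Return (tokens, token_width, output_shape) for one sample."""
    tokens = math.prod(grid_size)
    token_width = out_dim * math.prod(patch_size)
    output_shape = (out_dim, *[g * p for g, p in zip(grid_size, patch_size)])
    return tokens, token_width, output_shape


tokens, width, shape = unpatchify_shape((21, 30, 52), (1, 2, 2), 16)
print(tokens, width, shape)
```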
    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
================================================
FILE: wan/modules/animate/motion_encoder.py
================================================
# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
def custom_qr(input_tensor):
    original_dtype = input_tensor.dtype
    if original_dtype == torch.bfloat16:
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
        return q.to(original_dtype), r.to(original_dtype)
    return torch.linalg.qr(input_tensor)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    return F.leaky_relu(input + bias, negative_slope) * scale
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, minor, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, minor, in_h, 1, in_w, 1)
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    out = out.view(-1, minor, in_h * up_y, in_w * up_x)

    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
              max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]

    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
                      in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
    return out[:, :, ::down_y, ::down_x]


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
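# Illustrative sketch (not part of the repository): a 1-D, pure-Python analogue
# of the upsample / pad / FIR-filter / downsample steps in `upfirdn2d_native`.
# The name `upfirdn1d` is hypothetical. It is meant only to make the order of
# operations and the output-length formula
# len(x) * up + pad0 + pad1 - len(kernel) + 1 (before downsampling) easier to follow.

```python
def upfirdn1d(x, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by zero insertion, pad or crop, FIR-filter, then downsample."""
    # 1) Zero insertion: each sample is followed by (up - 1) zeros.
    upsampled = []
    for value in x:
        upsampled.append(value)
        upsampled.extend([0.0] * (up - 1))

    # 2) Positive pads add zeros; negative pads crop samples.
    pad0, pad1 = pad
    padded = [0.0] * max(pad0, 0) + upsampled + [0.0] * max(pad1, 0)
    padded = padded[max(-pad0, 0): len(padded) - max(-pad1, 0)]

    # 3) Correlating with the flipped kernel is a true convolution,
    #    mirroring torch.flip(kernel) followed by F.conv2d in the source.
    flipped = kernel[::-1]
    n_out = len(padded) - len(kernel) + 1
    filtered = [
        sum(padded[i + j] * flipped[j] for j in range(len(kernel)))
        for i in range(n_out)
    ]

    # 4) Keep every `down`-th sample.
    return filtered[::down]
```

# For example, upfirdn1d([1, 2], [1.0, 1.0], up=2, pad=(1, 0)) repeats each
# sample twice (nearest-neighbour upsampling).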
def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k
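# Illustrative sketch (not part of the repository): what `make_kernel` computes
# for a 1-D tap list, written with plain Python lists. The name
# `make_kernel_2d` is hypothetical. The 2-D kernel is the outer product of the
# taps with themselves, normalised so that all entries sum to 1.

```python
def make_kernel_2d(taps):
    """Outer product of `taps` with itself, normalised to unit sum."""
    # The sum of an outer product a * a^T equals sum(a) ** 2.
    total = sum(taps) ** 2
    return [[a * b / total for b in taps] for a in taps]


kernel = make_kernel_2d([1, 3, 3, 1])
```

# For the default blur taps [1, 3, 3, 1] the normaliser is 8 ** 2 = 64, so the
# corner entries are 1/64 and the four centre entries are 9/64.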
class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)
class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
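# Illustrative sketch (not part of the repository): the runtime weight scales
# used by `EqualConv2d` and `EqualLinear`, reduced to scalar arithmetic. The
# helper names are hypothetical. Both layers draw weights from a unit normal
# and multiply by 1 / sqrt(fan_in) in `forward`. In `EqualLinear` the division
# by `lr_mul` at initialisation and the multiplication by `lr_mul` in the scale
# cancel, so the effective initial standard deviation does not depend on `lr_mul`.

```python
import math


def equal_conv_scale(in_channel, kernel_size):
    """Scale applied to EqualConv2d weights: 1 / sqrt(fan_in)."""
    return 1 / math.sqrt(in_channel * kernel_size ** 2)


def equal_linear_scale(in_dim, lr_mul=1):
    """Scale applied to EqualLinear weights, including the lr multiplier."""
    return (1 / math.sqrt(in_dim)) * lr_mul


def effective_init_std(in_dim, lr_mul=1):
    """Std of (weight * scale) at init, with weight ~ N(0, 1) / lr_mul."""
    return (1 / lr_mul) * equal_linear_scale(in_dim, lr_mul)
```

# For instance, a 3x3 convolution over 512 input channels has fan-in 4608.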
class ConvLayer(nn.Sequential):
    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            downsample=False,
            blur_kernel=[1, 3, 3, 1],
            bias=True,
            activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
                                  bias=bias and not activate))

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)
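# Illustrative sketch (not part of the repository): the size arithmetic of the
# `downsample=True` branch of `ConvLayer`. The helper names are hypothetical.
# The blur pads are chosen so that blurring followed by a stride-2 convolution
# with no padding halves an even input size, for both the 3x3 main path and the
# 1x1 skip path used by `ResBlock`.

```python
def blur_pads(kernel_size, blur_len=4, factor=2):
    """(pad0, pad1) as computed in ConvLayer for the Blur layer."""
    p = (blur_len - factor) + (kernel_size - 1)
    return (p + 1) // 2, p // 2


def downsample_output_size(in_size, kernel_size, blur_len=4, factor=2):
    """Spatial size after Blur followed by a stride-`factor` convolution."""
    pad0, pad1 = blur_pads(kernel_size, blur_len, factor)
    # upfirdn2d output size with up=1, down=1.
    blurred = in_size + pad0 + pad1 - blur_len + 1
    # Standard convolution size formula with padding=0.
    return (blurred - kernel_size) // factor + 1
```

# With the default 4-tap blur kernel, a 512-pixel side becomes 256 for both
# kernel_size=3 and kernel_size=1.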
class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
class EncoderApp(nn.Module):
    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        res = []
        h = x
        for conv in self.convs:
            h = conv(h)
            res.append(h)

        return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
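# Illustrative sketch (not part of the repository): the (resolution, channels)
# schedule that `EncoderApp.__init__` builds, traced without constructing any
# layers. `encoder_schedule` is a hypothetical helper. It uses
# `size.bit_length() - 1` as an integer log2, which is intended to match
# `int(math.log(size, 2))` for the power-of-two sizes in the channel table.

```python
CHANNELS = {4: 512, 8: 512, 16: 512, 32: 512, 64: 256,
            128: 128, 256: 64, 512: 32, 1024: 16}


def encoder_schedule(size):
    """List of (resolution, channels) after the stem and after each ResBlock."""
    log_size = size.bit_length() - 1  # integer log2 for a power-of-two size
    stages = [(size, CHANNELS[size])]  # 1x1 stem convolution keeps resolution
    for i in range(log_size, 2, -1):
        res = 2 ** (i - 1)  # each ResBlock halves the resolution
        stages.append((res, CHANNELS[res]))
    return stages


schedule_512 = encoder_schedule(512)
```

# For an assumed input size of 512 the stem is followed by seven ResBlocks that
# end at a 4x4 map with 512 channels. The final 4x4 `EqualConv2d` with
# padding=0 then reduces that map to 1x1, which `forward` squeezes into a
# `w_dim`-sized vector.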
class Encoder(nn.Module):
    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # appearance network
        self.net_app = EncoderApp(size, dim)

        # motion network
        fc = [EqualLinear(dim, dim)]
        for i in range(3):
            fc.append(EqualLinear(dim, dim))

        fc.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*fc)

    def enc_app(self, x):
        h_source = self.net_app(x)
        return h_source

    def enc_motion(self, x):
        h, _ = self.net_app(x)
        h_motion = self.fc(h)
        return h_motion
class Direction(nn.Module):
    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        weight = self.weight + 1e-8
        Q, R = custom_qr(weight)
        if input is None:
            return Q
        else:
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
            out = torch.matmul(input_diag, Q.T)
            out = torch.sum(out, dim=1)
            return out
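# Illustrative sketch (not part of the repository): what the diag_embed /
# matmul / sum sequence in `Direction.forward` computes for a single sample,
# using plain Python lists. `combine_directions` and `matvec` are hypothetical
# helpers. Row i of diag(alpha) @ Q^T is alpha_i times the i-th column of Q, so
# summing the rows gives the linear combination Q @ alpha of the direction
# vectors.

```python
def combine_directions(alpha, Q):
    """alpha: m coefficients; Q: n x m matrix whose columns are directions."""
    n, m = len(Q), len(alpha)
    # diag(alpha) @ Q^T: an m x n matrix, row i equals alpha[i] * Q[:, i].
    scaled = [[alpha[i] * Q[j][i] for j in range(n)] for i in range(m)]
    # Sum over the m rows (dim=1 in the batched source code).
    return [sum(row[j] for row in scaled) for j in range(n)]


def matvec(Q, alpha):
    """Reference result: the ordinary matrix-vector product Q @ alpha."""
    return [sum(Q[j][i] * alpha[i] for i in range(len(alpha)))
            for j in range(len(Q))]
```

# In the source, Q comes from a QR decomposition of the learned weight, so its
# columns are orthonormal. The identity sketched here holds for any matrix.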
class Synthesis(nn.Module):
    def __init__(self, motion_dim):
        super(Synthesis, self).__init__()
        self.direction = Direction(motion_dim)


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()

        self.enc = Encoder(size, style_dim, motion_dim)
SYMBOL INDEX (655 symbols across 42 files)
FILE: generate.py
function _validate_args (line 62) | def _validate_args(args):
function _parse_args (line 105) | def _parse_args():
function _init_logging (line 303) | def _init_logging(rank):
function generate (line 315) | def generate(args):
FILE: wan/animate.py
class WanAnimate (line 36) | class WanAnimate:
method __init__ (line 38) | def __init__(
method _configure_model (line 144) | def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
method inputs_padding (line 201) | def inputs_padding(self, array, target_len):
method get_valid_len (line 215) | def get_valid_len(self, real_len, clip_len=81, overlap=1):
method get_i2v_mask (line 226) | def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_val...
method padding_resize (line 237) | def padding_resize(self, img_ori, height=512, width=512, padding_color...
method prepare_source (line 269) | def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
method prepare_source_for_replace (line 284) | def prepare_source_for_replace(self, src_bg_path, src_mask_path):
method generate (line 297) | def generate(
FILE: wan/distributed/fsdp.py
function shard_model (line 12) | def shard_model(
function free_model (line 39) | def free_model(model):
FILE: wan/distributed/sequence_parallel.py
function pad_freqs (line 10) | def pad_freqs(original_tensor, target_len):
function rope_apply (line 24) | def rope_apply(x, grid_sizes, freqs):
function sp_dit_forward (line 64) | def sp_dit_forward(
function sp_attn_forward (line 147) | def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bf...
FILE: wan/distributed/ulysses.py
function distributed_attention (line 9) | def distributed_attention(
FILE: wan/distributed/util.py
function init_distributed_group (line 6) | def init_distributed_group():
function get_rank (line 13) | def get_rank():
function get_world_size (line 17) | def get_world_size():
function all_to_all (line 21) | def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
function all_gather (line 34) | def all_gather(tensor):
function gather_forward (line 43) | def gather_forward(input, dim):
FILE: wan/image2video.py
class WanI2V (line 33) | class WanI2V:
method __init__ (line 35) | def __init__(
method _configure_model (line 128) | def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
method _prepare_model_for_timestep (line 172) | def _prepare_model_for_timestep(self, t, boundary, offload_model):
method generate (line 206) | def generate(self,
FILE: wan/modules/animate/animate_utils.py
function get_loraconfig (line 7) | def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="...
class TensorList (line 23) | class TensorList(object):
method __init__ (line 25) | def __init__(self, tensors):
method to (line 36) | def to(self, *args, **kwargs):
method size (line 39) | def size(self, dim):
method pow (line 43) | def pow(self, *args, **kwargs):
method squeeze (line 46) | def squeeze(self, dim):
method type (line 52) | def type(self, *args, **kwargs):
method type_as (line 55) | def type_as(self, other):
method dtype (line 63) | def dtype(self):
method device (line 67) | def device(self):
method ndim (line 71) | def ndim(self):
method __getitem__ (line 74) | def __getitem__(self, index):
method __len__ (line 77) | def __len__(self):
method __add__ (line 80) | def __add__(self, other):
method __radd__ (line 83) | def __radd__(self, other):
method __sub__ (line 86) | def __sub__(self, other):
method __rsub__ (line 89) | def __rsub__(self, other):
method __mul__ (line 92) | def __mul__(self, other):
method __rmul__ (line 95) | def __rmul__(self, other):
method __floordiv__ (line 98) | def __floordiv__(self, other):
method __truediv__ (line 101) | def __truediv__(self, other):
method __rfloordiv__ (line 104) | def __rfloordiv__(self, other):
method __rtruediv__ (line 107) | def __rtruediv__(self, other):
method __pow__ (line 110) | def __pow__(self, other):
method __rpow__ (line 113) | def __rpow__(self, other):
method __neg__ (line 116) | def __neg__(self):
method __iter__ (line 119) | def __iter__(self):
method __repr__ (line 123) | def __repr__(self):
method _apply (line 126) | def _apply(self, other, op):
FILE: wan/modules/animate/clip.py
function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
class QuickGELU (line 41) | class QuickGELU(nn.Module):
method forward (line 43) | def forward(self, x):
class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
method forward (line 49) | def forward(self, x):
class SelfAttention (line 53) | class SelfAttention(nn.Module):
method __init__ (line 55) | def __init__(self,
method forward (line 74) | def forward(self, x):
class SwiGLU (line 94) | class SwiGLU(nn.Module):
method __init__ (line 96) | def __init__(self, dim, mid_dim):
method forward (line 106) | def forward(self, x):
class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
method __init__ (line 114) | def __init__(self,
method forward (line 146) | def forward(self, x):
class AttentionPool (line 156) | class AttentionPool(nn.Module):
method __init__ (line 158) | def __init__(self,
method forward (line 186) | def forward(self, x):
class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
method __init__ (line 211) | def __init__(self,
method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
method __init__ (line 305) | def __init__(self, **kwargs):
method forward (line 315) | def forward(self, ids):
class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
method __init__ (line 330) | def __init__(self,
method forward (line 406) | def forward(self, imgs, txt_ids):
method param_groups (line 418) | def param_groups(self):
function _clip (line 434) | def _clip(pretrained=False,
function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
class CLIPModel (line 501) | class CLIPModel:
method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
method visual (line 527) | def visual(self, videos):
FILE: wan/modules/animate/face_blocks.py
function attention (line 32) | def attention(
class CausalConv1d (line 112) | class CausalConv1d(nn.Module):
method __init__ (line 114) | def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilatio...
method forward (line 123) | def forward(self, x):
class FaceEncoder (line 129) | class FaceEncoder(nn.Module):
method __init__ (line 130) | def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=...
method forward (line 150) | def forward(self, x):
class RMSNorm (line 180) | class RMSNorm(nn.Module):
method __init__ (line 181) | def __init__(
method _norm (line 207) | def _norm(self, x):
method forward (line 220) | def forward(self, x):
function get_norm_layer (line 237) | def get_norm_layer(norm_layer):
class FaceAdapter (line 255) | class FaceAdapter(nn.Module):
method __init__ (line 256) | def __init__(
method forward (line 284) | def forward(
class FaceBlock (line 297) | class FaceBlock(nn.Module):
method __init__ (line 298) | def __init__(
method forward (line 334) | def forward(
FILE: wan/modules/animate/model_animate.py
class HeadAnimate (line 39) | class HeadAnimate(Head):
method forward (line 41) | def forward(self, x, e):
class WanAnimateSelfAttention (line 54) | class WanAnimateSelfAttention(WanSelfAttention):
method forward (line 56) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanAnimateCrossAttention (line 88) | class WanAnimateCrossAttention(WanSelfAttention):
method __init__ (line 89) | def __init__(
method forward (line 112) | def forward(self, x, context, context_lens):
class WanAnimateAttentionBlock (line 149) | class WanAnimateAttentionBlock(nn.Module):
method __init__ (line 150) | def __init__(self,
method forward (line 188) | def forward(
class MLPProj (line 230) | class MLPProj(torch.nn.Module):
method __init__ (line 231) | def __init__(self, in_dim, out_dim):
method forward (line 242) | def forward(self, image_embeds):
class WanAnimateModel (line 246) | class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
method __init__ (line 250) | def __init__(self,
method after_patch_embedding (line 340) | def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, f...
method after_transformer_block (line 364) | def after_transformer_block(self, block_idx, x, motion_vec, motion_mas...
method forward (line 372) | def forward(
method unpatchify (line 453) | def unpatchify(self, x, grid_sizes):
method init_weights (line 478) | def init_weights(self):
FILE: wan/modules/animate/motion_encoder.py
function custom_qr (line 8) | def custom_qr(input_tensor):
function fused_leaky_relu (line 15) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
function upfirdn2d_native (line 19) | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, ...
function upfirdn2d (line 39) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
function make_kernel (line 43) | def make_kernel(k):
class FusedLeakyReLU (line 51) | class FusedLeakyReLU(nn.Module):
method __init__ (line 52) | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
method forward (line 58) | def forward(self, input):
class Blur (line 63) | class Blur(nn.Module):
method __init__ (line 64) | def __init__(self, kernel, pad, upsample_factor=1):
method forward (line 76) | def forward(self, input):
class ScaledLeakyReLU (line 80) | class ScaledLeakyReLU(nn.Module):
method __init__ (line 81) | def __init__(self, negative_slope=0.2):
method forward (line 86) | def forward(self, input):
class EqualConv2d (line 90) | class EqualConv2d(nn.Module):
method __init__ (line 91) | def __init__(self, in_channel, out_channel, kernel_size, stride=1, pad...
method forward (line 105) | def forward(self, input):
method __repr__ (line 109) | def __repr__(self):
class EqualLinear (line 116) | class EqualLinear(nn.Module):
method __init__ (line 117) | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, ...
method forward (line 132) | def forward(self, input):
method __repr__ (line 142) | def __repr__(self):
class ConvLayer (line 146) | class ConvLayer(nn.Sequential):
method __init__ (line 147) | def __init__(
class ResBlock (line 186) | class ResBlock(nn.Module):
method __init__ (line 187) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
method forward (line 195) | def forward(self, input):
class EncoderApp (line 205) | class EncoderApp(nn.Module):
method __init__ (line 206) | def __init__(self, size, w_dim=512):
method forward (line 235) | def forward(self, x):
class Encoder (line 246) | class Encoder(nn.Module):
method __init__ (line 247) | def __init__(self, size, dim=512, dim_motion=20):
method enc_app (line 261) | def enc_app(self, x):
method enc_motion (line 265) | def enc_motion(self, x):
class Direction (line 271) | class Direction(nn.Module):
method __init__ (line 272) | def __init__(self, motion_dim):
method forward (line 276) | def forward(self, input):
class Synthesis (line 289) | class Synthesis(nn.Module):
method __init__ (line 290) | def __init__(self, motion_dim):
class Generator (line 295) | class Generator(nn.Module):
method __init__ (line 296) | def __init__(self, size, style_dim=512, motion_dim=20):
method get_motion (line 302) | def get_motion(self, img):
FILE: wan/modules/animate/preprocess/human_visualization.py
function draw_handpose (line 14) | def draw_handpose(canvas, keypoints, hand_score_th=0.6):
function draw_handpose_new (line 93) | def draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_scor...
function draw_ellipse_by_2kp (line 175) | def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6):
function split_pose2d_kps_to_aa (line 193) | def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]:
function draw_aapose_by_meta (line 211) | def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_widt...
function draw_aapose_by_meta_new (line 218) | def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickw...
function draw_hand_by_meta (line 226) | def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_...
function draw_aaface_by_meta (line 234) | def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_widt...
function draw_aanose_by_meta (line 242) | def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_widt...
function gen_face_motion_seq (line 250) | def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, sti...
function draw_M (line 255) | def draw_M(
function draw_nose (line 422) | def draw_nose(
function draw_aapose (line 586) | def draw_aapose(
function draw_aapose_new (line 750) | def draw_aapose_new(
function draw_bbox (line 921) | def draw_bbox(img, bbox, color=(255, 0, 0)):
function draw_kp2ds (line 928) | def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None...
function draw_mask (line 1022) | def draw_mask(img, mask, background=0, return_rgba=False):
function draw_pcd (line 1032) | def draw_pcd(pcd_list, save_path=None):
function load_image (line 1051) | def load_image(img, reverse=False):
function draw_skeleten (line 1061) | def draw_skeleten(meta):
function draw_skeleten_with_pncc (line 1085) | def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray:
function draw_face_kp (line 1147) | def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE):
function draw_traj (line 1168) | def draw_traj(metas: List[AAPoseMeta], threshold=0.6):
FILE: wan/modules/animate/preprocess/pose2d.py
class SimpleOnnxInference (line 20) | class SimpleOnnxInference(object):
method __init__ (line 21) | def __init__(self, checkpoint, device='cuda', reverse_input=False, **k...
method __call__ (line 45) | def __call__(self, *args, **kwargs):
method get_output_names (line 49) | def get_output_names(self):
method set_device (line 56) | def set_device(self, device):
class Yolo (line 68) | class Yolo(SimpleOnnxInference):
method __init__ (line 69) | def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, thr...
method preprocess (line 89) | def preprocess(self, input_image):
method postprocess (line 112) | def postprocess(self, output, shape_raw, cat_id=[1]):
method process_results (line 207) | def process_results(self, results, shape_raw, cat_id=[1], single_perso...
method postprocess_threading (line 283) | def postprocess_threading(self, outputs, shape_raw, person_results, i,...
method forward (line 290) | def forward(self, img, shape_raw, **kwargs):
class ViTPose (line 309) | class ViTPose(SimpleOnnxInference):
method __init__ (line 310) | def __init__(self, checkpoint, device='cuda', **kwargs):
method forward (line 313) | def forward(self, img, center, scale, **kwargs):
method preprocess (line 324) | def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1....
class Pose2d (line 346) | class Pose2d:
method __init__ (line 347) | def __init__(self, checkpoint, detector_checkpoint=None, device='cuda'...
method load_images (line 357) | def load_images(self, inputs):
method __call__ (line 396) | def __call__(
FILE: wan/modules/animate/preprocess/pose2d_utils.py
function box_convert_simple (line 9) | def box_convert_simple(box, convert_type='xyxy2xywh'):
function read_img (line 19) | def read_img(image, convert='RGB', check_exist=False):
class AAPoseMeta (line 39) | class AAPoseMeta:
method __init__ (line 40) | def __init__(self, meta=None, kp2ds=None):
method is_valid (line 60) | def is_valid(self, kp, p, threshold):
method get_bbox (line 67) | def get_bbox(self, kp, kp_p, threshold=0.5):
method crop (line 75) | def crop(self, x0, y0, x1, y1):
method resize (line 85) | def resize(self, width, height):
method get_kps_body_with_p (line 98) | def get_kps_body_with_p(self, normalize=False):
method from_kps_face (line 106) | def from_kps_face(kps_face: np.ndarray, height: int, width: int):
method from_kps_body (line 119) | def from_kps_body(kps_body: np.ndarray, height: int, width: int):
method from_humanapi_meta (line 128) | def from_humanapi_meta(meta):
method load_from_meta (line 144) | def load_from_meta(self, meta, norm_body=True, norm_hand=False):
method load_from_kp2ds (line 170) | def load_from_kp2ds(kp2ds: List[np.ndarray], width: int, height: int):
method from_dwpose (line 199) | def from_dwpose(dwpose_det_res, height, width):
method save_json (line 219) | def save_json(self):
method draw_aapose (line 222) | def draw_aapose(self, img, threshold=0.5, stick_width_norm=200, draw_h...
method translate (line 227) | def translate(self, x0, y0):
method scale (line 234) | def scale(self, sx, sy):
method padding_resize2 (line 241) | def padding_resize2(self, height=512, width=512):
function transform_preds (line 279) | def transform_preds(coords, center, scale, output_size, use_udp=False):
function _calc_distances (line 326) | def _calc_distances(preds, targets, mask, normalize):
function _distance_acc (line 358) | def _distance_acc(distances, thr=0.5):
function _get_max_preds (line 379) | def _get_max_preds(heatmaps):
function _get_max_preds_3d (line 414) | def _get_max_preds_3d(heatmaps):
function pose_pck_accuracy (line 452) | def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None):
function keypoint_pck_accuracy (line 495) | def keypoint_pck_accuracy(pred, gt, mask, thr, normalize):
function keypoint_auc (line 534) | def keypoint_auc(pred, gt, mask, normalize, num_step=20):
function keypoint_nme (line 566) | def keypoint_nme(pred, gt, mask, normalize_factor):
function keypoint_epe (line 589) | def keypoint_epe(pred, gt, mask):
function _taylor (line 614) | def _taylor(heatmap, coord):
function post_dark_udp (line 651) | def post_dark_udp(coords, batch_heatmaps, kernel=3):
function _gaussian_blur (line 715) | def _gaussian_blur(heatmaps, kernel=11):
function keypoints_from_regression (line 757) | def keypoints_from_regression(regression_preds, center, scale, img_size):
function keypoints_from_heatmaps (line 790) | def keypoints_from_heatmaps(heatmaps,
function keypoints_from_heatmaps3d (line 941) | def keypoints_from_heatmaps3d(heatmaps, center, scale):
function multilabel_classification_accuracy (line 974) | def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
function get_transform (line 1004) | def get_transform(center, scale, res, rot=0):
function transform (line 1034) | def transform(pt, center, scale, res, invert=0, rot=0):
function bbox_from_detector (line 1044) | def bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25):
function crop (line 1069) | def crop(img, center, scale, res):
function split_kp2ds_for_aa (line 1102) | def split_kp2ds_for_aa(kp2ds, ret_face=False):
function load_pose_metas_from_kp2ds_seq_list (line 1111) | def load_pose_metas_from_kp2ds_seq_list(kp2ds_seq, width, height):
function load_pose_metas_from_kp2ds_seq (line 1137) | def load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height):
FILE: wan/modules/animate/preprocess/preprocess_data.py
function _parse_args (line 7) | def _parse_args():
FILE: wan/modules/animate/preprocess/process_pipepline.py
class ProcessPipeline (line 28) | class ProcessPipeline():
method __init__ (line 29) | def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_ch...
method __call__ (line 38) | def __call__(self, video_path, refer_image_path, output_path, resoluti...
method get_editing_prompts (line 237) | def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):
method get_mask (line 280) | def get_mask(self, frames, th_step, kp2ds_all):
method convert_list_to_array (line 345) | def convert_list_to_array(self, metas):
FILE: wan/modules/animate/preprocess/retarget_pose.py
class Keypoint (line 51) | class Keypoint(NamedTuple):
function get_length (line 60) | def get_length(skeleton, limb):
function get_handpose_meta (line 80) | def get_handpose_meta(keypoints, delta, src_H, src_W):
function deal_hand_keypoints (line 106) | def deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th = 0.5):
function get_scaled_pose (line 159) | def get_scaled_pose(canvas, src_canvas, keypoints, keypoints_hand, bone_...
function rescale_skeleton (line 309) | def rescale_skeleton(H, W, keypoints, bone_ratio_list):
function fix_lack_keypoints_use_sym (line 369) | def fix_lack_keypoints_use_sym(skeleton):
function rescale_shorten_skeleton (line 450) | def rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list):
function check_full_body (line 481) | def check_full_body(keypoints, threshold = 0.4):
function check_full_body_both (line 501) | def check_full_body_both(flag1, flag2):
function write_to_poses (line 520) | def write_to_poses(data_to_json, none_idx, dst_shape, bone_ratio_list, d...
function calculate_scale_ratio (line 551) | def calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag):
function retarget_pose (line 571) | def retarget_pose(src_skeleton, dst_skeleton, all_src_skeleton, src_skel...
function get_retarget_pose (line 760) | def get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, t...
FILE: wan/modules/animate/preprocess/sam_utils.py
function _load_img_v2_as_tensor (line 23) | def _load_img_v2_as_tensor(img, image_size):
function load_video_frames (line 34) | def load_video_frames(
function load_video_frames_v2 (line 89) | def load_video_frames_v2(
function build_sam2_video_predictor (line 122) | def build_sam2_video_predictor(
FILE: wan/modules/animate/preprocess/utils.py
function get_mask_boxes (line 8) | def get_mask_boxes(mask):
function get_aug_mask (line 25) | def get_aug_mask(body_mask, w_len=10, h_len=20):
function get_mask_body_img (line 44) | def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1):
function get_face_bboxes (line 52) | def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug):
function calculate_new_size (line 87) | def calculate_new_size(orig_w, orig_h, target_area, divisor=64):
function resize_by_area (line 136) | def resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=6...
function padding_resize (line 158) | def padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, ...
function get_frame_indices (line 191) | def get_frame_indices(frame_num, video_fps, clip_length, train_fps):
function get_face_bboxes (line 201) | def get_face_bboxes(kp2ds, scale, image_shape):
FILE: wan/modules/animate/preprocess/video_predictor.py
class SAM2VideoPredictor (line 14) | class SAM2VideoPredictor(_SAM2VideoPredictor):
method __init__ (line 15) | def __init__(self, *args, **kwargs):
method init_state (line 20) | def init_state(
method init_state_v2 (line 90) | def init_state_v2(
FILE: wan/modules/animate/xlm_roberta.py
class SelfAttention (line 10) | class SelfAttention(nn.Module):
method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
method forward (line 27) | def forward(self, x, mask):
class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
method forward (line 66) | def forward(self, x, mask):
class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
method __init__ (line 81) | def __init__(self,
method forward (line 118) | def forward(self, ids):
function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,
FILE: wan/modules/attention.py
function flash_attention (line 24) | def flash_attention(
function attention (line 133) | def attention(
FILE: wan/modules/model.py
function sinusoidal_embedding_1d (line 14) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 28) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 39) | def rope_apply(x, grid_sizes, freqs):
class WanRMSNorm (line 69) | class WanRMSNorm(nn.Module):
method __init__ (line 71) | def __init__(self, dim, eps=1e-5):
method forward (line 77) | def forward(self, x):
method _norm (line 84) | def _norm(self, x):
class WanLayerNorm (line 88) | class WanLayerNorm(nn.LayerNorm):
method __init__ (line 90) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 93) | def forward(self, x):
class WanSelfAttention (line 101) | class WanSelfAttention(nn.Module):
method __init__ (line 103) | def __init__(self,
method forward (line 126) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanCrossAttention (line 158) | class WanCrossAttention(WanSelfAttention):
method forward (line 160) | def forward(self, x, context, context_lens):
class WanAttentionBlock (line 183) | class WanAttentionBlock(nn.Module):
method __init__ (line 185) | def __init__(self,
method forward (line 219) | def forward(
class Head (line 262) | class Head(nn.Module):
method __init__ (line 264) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 279) | def forward(self, x, e):
class WanModel (line 294) | class WanModel(ModelMixin, ConfigMixin):
method __init__ (line 305) | def __init__(self,
method forward (line 410) | def forward(
method unpatchify (line 499) | def unpatchify(self, x, grid_sizes):
method init_weights (line 524) | def init_weights(self):
FILE: wan/modules/s2v/audio_encoder.py
function get_sample_indices (line 11) | def get_sample_indices(original_fps,
function linear_interpolation (line 38) | def linear_interpolation(features, input_fps, output_fps, output_len=None):
class AudioEncoder (line 55) | class AudioEncoder():
method __init__ (line 57) | def __init__(self, device='cpu', model_id="facebook/wav2vec2-base-960h"):
method extract_audio_feat (line 66) | def extract_audio_feat(self,
method get_audio_embed_bucket (line 91) | def get_audio_embed_bucket(self,
method get_audio_embed_bucket_fps (line 136) | def get_audio_embed_bucket_fps(self,
FILE: wan/modules/s2v/audio_utils.py
class CausalAudioEncoder (line 14) | class CausalAudioEncoder(nn.Module):
method __init__ (line 16) | def __init__(self,
method forward (line 34) | def forward(self, features):
class AudioCrossAttention (line 47) | class AudioCrossAttention(WanCrossAttention):
method __init__ (line 49) | def __init__(self, *args, **kwargs):
class AudioInjector_WAN (line 53) | class AudioInjector_WAN(nn.Module):
method __init__ (line 55) | def __init__(self,
FILE: wan/modules/s2v/auxi_blocks.py
function attention (line 35) | def attention(
class CausalConv1d (line 121) | class CausalConv1d(nn.Module):
method __init__ (line 123) | def __init__(self,
method forward (line 145) | def forward(self, x):
class MotionEncoder_tc (line 150) | class MotionEncoder_tc(nn.Module):
method __init__ (line 152) | def __init__(self,
method forward (line 199) | def forward(self, x):
FILE: wan/modules/s2v/model_s2v.py
function zero_module (line 35) | def zero_module(module):
function torch_dfs (line 44) | def torch_dfs(model: nn.Module, parent_name='root'):
function rope_apply (line 62) | def rope_apply(x, grid_sizes, freqs, start=None):
function rope_apply_usp (line 80) | def rope_apply_usp(x, grid_sizes, freqs):
function sp_attn_forward_s2v (line 98) | def sp_attn_forward_s2v(self,
class Head_S2V (line 135) | class Head_S2V(Head):
method forward (line 137) | def forward(self, x, e):
class WanS2VSelfAttention (line 150) | class WanS2VSelfAttention(WanSelfAttention):
method forward (line 152) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanS2VAttentionBlock (line 184) | class WanS2VAttentionBlock(WanAttentionBlock):
method __init__ (line 186) | def __init__(self,
method forward (line 199) | def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_...
class WanModel_S2V (line 247) | class WanModel_S2V(ModelMixin, ConfigMixin):
method __init__ (line 255) | def __init__(
method zero_init_weights (line 441) | def zero_init_weights(self):
method process_motion (line 455) | def process_motion(self, motion_latents, drop_motion_frames=False):
method process_motion_frame_pack (line 484) | def process_motion_frame_pack(self,
method process_motion_transformer_motioner (line 496) | def process_motion_transformer_motioner(self,
method inject_motion (line 561) | def inject_motion(self,
method after_transformer_block (line 601) | def after_transformer_block(self, block_idx, hidden_states):
method forward (line 650) | def forward(
method unpatchify (line 859) | def unpatchify(self, x, grid_sizes):
method init_weights (line 884) | def init_weights(self):
FILE: wan/modules/s2v/motioner.py
function sinusoidal_embedding_1d (line 17) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 31) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 42) | def rope_apply(x, grid_sizes, freqs, start=None):
class RMSNorm (line 114) | class RMSNorm(nn.Module):
method __init__ (line 116) | def __init__(self, dim, eps=1e-5):
method forward (line 122) | def forward(self, x):
method _norm (line 125) | def _norm(self, x):
class LayerNorm (line 129) | class LayerNorm(nn.LayerNorm):
method __init__ (line 131) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 134) | def forward(self, x):
class SelfAttention (line 138) | class SelfAttention(nn.Module):
method __init__ (line 140) | def __init__(self,
method forward (line 163) | def forward(self, x, seq_lens, grid_sizes, freqs):
class SwinSelfAttention (line 188) | class SwinSelfAttention(SelfAttention):
method forward (line 190) | def forward(self, x, seq_lens, grid_sizes, freqs):
class CasualSelfAttention (line 246) | class CasualSelfAttention(SelfAttention):
method forward (line 248) | def forward(self, x, seq_lens, grid_sizes, freqs):
class MotionerAttentionBlock (line 328) | class MotionerAttentionBlock(nn.Module):
method __init__ (line 330) | def __init__(self,
method forward (line 365) | def forward(
class Head (line 380) | class Head(nn.Module):
method __init__ (line 382) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 394) | def forward(self, x):
class MotionerTransformers (line 399) | class MotionerTransformers(nn.Module, PeftAdapterMixin):
method __init__ (line 401) | def __init__(
method after_patch_embedding (line 492) | def after_patch_embedding(self, x):
method forward (line 495) | def forward(
method unpatchify (line 621) | def unpatchify(self, x, grid_sizes):
method init_weights (line 631) | def init_weights(self):
class FramePackMotioner (line 643) | class FramePackMotioner(nn.Module):
method __init__ (line 645) | def __init__(
method forward (line 679) | def forward(self, motion_latents, add_last_motion=2):
function sample_indices (line 765) | def sample_indices(N, stride, expand_ratio, c):
FILE: wan/modules/s2v/s2v_utils.py
function rope_precompute (line 6) | def rope_precompute(x, grid_sizes, freqs, start=None):
FILE: wan/modules/t5.py
function fp16_clamp (line 20) | def fp16_clamp(x):
function init_weights (line 27) | def init_weights(m):
class GELU (line 46) | class GELU(nn.Module):
method forward (line 48) | def forward(self, x):
class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
method forward (line 61) | def forward(self, x):
class T5Attention (line 69) | class T5Attention(nn.Module):
method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
method forward (line 136) | def forward(self, x):
class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
method __init__ (line 146) | def __init__(self,
method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
method __init__ (line 180) | def __init__(self,
method forward (line 206) | def forward(self,
class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
method forward (line 233) | def forward(self, lq, lk):
method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
class T5Encoder (line 267) | class T5Encoder(nn.Module):
method __init__ (line 269) | def __init__(self,
method forward (line 303) | def forward(self, ids, mask=None):
class T5Decoder (line 315) | class T5Decoder(nn.Module):
method __init__ (line 317) | def __init__(self,
method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
class T5Model (line 372) | class T5Model(nn.Module):
method __init__ (line 374) | def __init__(self,
method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
function _t5 (line 415) | def _t5(name,
function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
class T5EncoderModel (line 472) | class T5EncoderModel:
method __init__ (line 474) | def __init__(
method __call__ (line 506) | def __call__(self, texts, device):
FILE: wan/modules/tokenizers.py
function basic_clean (line 12) | def basic_clean(text):
function whitespace_clean (line 18) | def whitespace_clean(text):
function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
method __call__ (line 49) | def __call__(self, sequence, **kwargs):
method _clean (line 75) | def _clean(self, text):
FILE: wan/modules/vae2_1.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 28) | def forward(self, x, cache_x=None):
class RMS_norm (line 39) | class RMS_norm(nn.Module):
method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 51) | def forward(self, x):
class Upsample (line 57) | class Upsample(nn.Upsample):
method forward (line 59) | def forward(self, x):
class Resample (line 66) | class Resample(nn.Module):
method __init__ (line 68) | def __init__(self, dim, mode):
method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 162) | def init_weight(self, conv):
method init_weight2 (line 174) | def init_weight2(self, conv):
class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
method __init__ (line 228) | def __init__(self, dim):
method forward (line 240) | def forward(self, x):
class Encoder3d (line 265) | class Encoder3d(nn.Module):
method __init__ (line 267) | def __init__(self,
method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 369) | class Decoder3d(nn.Module):
method __init__ (line 371) | def __init__(self,
method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
function count_conv3d (line 475) | def count_conv3d(model):
class WanVAE_ (line 483) | class WanVAE_(nn.Module):
method __init__ (line 485) | def __init__(self,
method forward (line 510) | def forward(self, x):
method encode (line 516) | def encode(self, x, scale):
method decode (line 544) | def decode(self, z, scale):
method reparameterize (line 570) | def reparameterize(self, mu, log_var):
method sample (line 575) | def sample(self, imgs, deterministic=False):
method clear_cache (line 582) | def clear_cache(self):
function _video_vae (line 592) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
class Wan2_1_VAE (line 619) | class Wan2_1_VAE:
method __init__ (line 621) | def __init__(self,
method encode (line 647) | def encode(self, videos):
method decode (line 657) | def decode(self, zs):
FILE: wan/modules/vae2_2.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 34) | def forward(self, x, cache_x=None):
class RMS_norm (line 45) | class RMS_norm(nn.Module):
method __init__ (line 47) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 57) | def forward(self, x):
class Upsample (line 62) | class Upsample(nn.Upsample):
method forward (line 64) | def forward(self, x):
class Resample (line 71) | class Resample(nn.Module):
method __init__ (line 73) | def __init__(self, dim, mode):
method forward (line 112) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 171) | def init_weight(self, conv):
method init_weight2 (line 182) | def init_weight2(self, conv):
class ResidualBlock (line 193) | class ResidualBlock(nn.Module):
method __init__ (line 195) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 214) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 238) | class AttentionBlock(nn.Module):
method __init__ (line 243) | def __init__(self, dim):
method forward (line 255) | def forward(self, x):
function patchify (line 280) | def patchify(x, patch_size):
function unpatchify (line 299) | def unpatchify(x, patch_size):
class AvgDown3D (line 316) | class AvgDown3D(nn.Module):
method __init__ (line 318) | def __init__(
method forward (line 335) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class DupUp3D (line 370) | class DupUp3D(nn.Module):
method __init__ (line 372) | def __init__(
method forward (line 390) | def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
class Down_ResidualBlock (line 415) | class Down_ResidualBlock(nn.Module):
method __init__ (line 417) | def __init__(self,
method forward (line 447) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Up_ResidualBlock (line 455) | class Up_ResidualBlock(nn.Module):
method __init__ (line 457) | def __init__(self,
method forward (line 489) | def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
class Encoder3d (line 500) | class Encoder3d(nn.Module):
method __init__ (line 502) | def __init__(
method forward (line 559) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 616) | class Decoder3d(nn.Module):
method __init__ (line 618) | def __init__(
method forward (line 672) | def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
function count_conv3d (line 726) | def count_conv3d(model):
class WanVAE_ (line 734) | class WanVAE_(nn.Module):
method __init__ (line 736) | def __init__(
method forward (line 778) | def forward(self, x, scale=[0, 1]):
method encode (line 783) | def encode(self, x, scale):
method decode (line 812) | def decode(self, z, scale):
method reparameterize (line 841) | def reparameterize(self, mu, log_var):
method sample (line 846) | def sample(self, imgs, deterministic=False):
method clear_cache (line 853) | def clear_cache(self):
function _video_vae (line 863) | def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **...
class Wan2_2_VAE (line 888) | class Wan2_2_VAE:
method __init__ (line 890) | def __init__(
method encode (line 1024) | def encode(self, videos):
method decode (line 1038) | def decode(self, zs):
FILE: wan/speech2video.py
function load_safetensors (line 39) | def load_safetensors(path):
class WanS2V (line 47) | class WanS2V:
method __init__ (line 49) | def __init__(
method _configure_model (line 146) | def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
method get_size_less_than_area (line 189) | def get_size_less_than_area(self,
method prepare_default_cond_input (line 252) | def prepare_default_cond_input(self,
method encode_audio (line 283) | def encode_audio(self, audio_path, infer_frames):
method read_last_n_frames (line 297) | def read_last_n_frames(self,
method load_pose_cond (line 336) | def load_pose_cond(self, pose_video, num_repeat, infer_frames, size):
method get_gen_size (line 378) | def get_gen_size(self, size, max_area, ref_image_path, pre_video_path):
method generate (line 392) | def generate(
method tts (line 681) | def tts(self, tts_prompt_audio, tts_prompt_text, tts_text):
method load_tts (line 697) | def load_tts(self):
FILE: wan/text2video.py
class WanT2V (line 31) | class WanT2V:
method __init__ (line 33) | def __init__(
method _configure_model (line 125) | def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
method _prepare_model_for_timestep (line 169) | def _prepare_model_for_timestep(self, t, boundary, offload_model):
method generate (line 203) | def generate(self,
FILE: wan/textimage2video.py
class WanTI2V (line 34) | class WanTI2V:
method __init__ (line 36) | def __init__(
method _configure_model (line 118) | def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
method generate (line 162) | def generate(self,
method t2v (line 239) | def t2v(self,
method i2v (line 413) | def i2v(self,
FILE: wan/utils/fm_solvers.py
function get_sampling_sigmas (line 24) | def get_sampling_sigmas(sampling_steps, shift):
function retrieve_timesteps (line 31) | def retrieve_timesteps(
class FlowDPMSolverMultistepScheduler (line 71) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 131) | def __init__(
method step_index (line 204) | def step_index(self):
method begin_index (line 211) | def begin_index(self):
method set_begin_index (line 218) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 228) | def set_timesteps(
method _threshold_sample (line 294) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 332) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 335) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 339) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 343) | def convert_model_output(
method dpm_solver_first_order_update (line 417) | def dpm_solver_first_order_update(
method multistep_dpm_solver_second_order_update (line 488) | def multistep_dpm_solver_second_order_update(
method multistep_dpm_solver_third_order_update (line 598) | def multistep_dpm_solver_third_order_update(
method index_for_timestep (line 681) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 695) | def _init_step_index(self, timestep):
method step (line 708) | def step(
method scale_model_input (line 802) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 817) | def add_noise(
method __len__ (line 858) | def __len__(self):
FILE: wan/utils/fm_solvers_unipc.py
class FlowUniPCMultistepScheduler (line 22) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 79) | def __init__(
method step_index (line 137) | def step_index(self):
method begin_index (line 144) | def begin_index(self):
method set_begin_index (line 151) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 162) | def set_timesteps(
method _threshold_sample (line 232) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 271) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 274) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 278) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 281) | def convert_model_output(
method multistep_uni_p_bh_update (line 352) | def multistep_uni_p_bh_update(
method multistep_uni_c_bh_update (line 488) | def multistep_uni_c_bh_update(
method index_for_timestep (line 630) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 645) | def _init_step_index(self, timestep):
method step (line 657) | def step(self,
method scale_model_input (line 743) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 760) | def add_noise(
method __len__ (line 801) | def __len__(self):
FILE: wan/utils/prompt_extend.py
class PromptOutput (line 53) | class PromptOutput(object):
method add_custom_field (line 60) | def add_custom_field(self, key: str, value) -> None:
class PromptExpander (line 64) | class PromptExpander:
method __init__ (line 66) | def __init__(self, model_name, task, is_vl=False, device=0, **kwargs):
method extend_with_img (line 72) | def extend_with_img(self,
method extend (line 81) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method decide_system_prompt (line 84) | def decide_system_prompt(self, tar_lang="zh", prompt=None):
method __call__ (line 95) | def __call__(self,
class DashScopePromptExpander (line 117) | class DashScopePromptExpander(PromptExpander):
method __init__ (line 119) | def __init__(self,
method extend (line 158) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 194) | def extend_with_img(self,
class QwenPromptExpander (line 262) | class QwenPromptExpander(PromptExpander):
method __init__ (line 271) | def __init__(self,
method extend (line 337) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 368) | def extend_with_img(self,
function test (line 441) | def test(method,
FILE: wan/utils/qwen_vl_utils.py
function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
function smart_resize (line 54) | def smart_resize(height: int,
function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
function smart_nframes (line 133) | def smart_nframes(
function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
function is_decord_available (line 215) | def is_decord_available() -> bool:
function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
function fetch_video (line 274) | def fetch_video(
function extract_vision_info (line 328) | def extract_vision_info(
function process_vision_info (line 344) | def process_vision_info(
FILE: wan/utils/utils.py
function rand_name (line 17) | def rand_name(length=8, suffix=''):
function merge_video_audio (line 26) | def merge_video_audio(video_path: str, audio_path: str):
function save_video (line 90) | def save_video(tensor,
function save_image (line 123) | def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-...
function str2bool (line 145) | def str2bool(v):
function masks_like (line 172) | def masks_like(tensor, zero=False, generator=None, p=0.2):
function best_output_size (line 202) | def best_output_size(w, h, dw, dh, expected_area):
function download_cosyvoice_repo (line 228) | def download_cosyvoice_repo(repo_path):
function download_cosyvoice_model (line 236) | def download_cosyvoice_model(model_name, model_path):
Condensed preview of 69 files, each showing path, character count, and a content snippet. The full structured content is 795K chars.
[
{
"path": ".gitignore",
"chars": 80,
"preview": "__pycache__/\n.DS_Store\n.vscode*\ntmp_examples*\nnew_checkpoint*\nbatch_test*\nnohup*"
},
{
"path": "INSTALL.md",
"chars": 1158,
"preview": "# Installation Guide\n\n## Install with pip\n\n```bash\npip install .\npip install .[dev] # Also installs the dev tools\n"
},
{
"path": "LICENSE.txt",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "Makefile",
"chars": 80,
"preview": ".PHONY: format\n\nformat:\n\tisort generate.py wan\n\tyapf -i -r *.py generate.py wan\n"
},
{
"path": "README.md",
"chars": 35036,
"preview": "# Wan2.2\n\n<p align=\"center\">\n <img src=\"assets/logo.png\" width=\"400\"/>\n<p>\n\n<p align=\"center\">\n 💜 <a href=\"https:/"
},
{
"path": "generate.py",
"chars": 20392,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport logging\nimport os\nimport"
},
{
"path": "pyproject.toml",
"chars": 1318,
"preview": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"wan\"\nversion ="
},
{
"path": "requirements.txt",
"chars": 240,
"preview": "torch>=2.4.0\ntorchvision>=0.19.0\ntorchaudio\nopencv-python>=4.9.0.80\ndiffusers>=0.31.0\ntransformers>=4.49.0,<=4.51.3\ntoke"
},
{
"path": "requirements_animate.txt",
"chars": 167,
"preview": "decord\npeft\nonnxruntime\npandas\nmatplotlib\n-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c"
},
{
"path": "requirements_s2v.txt",
"chars": 173,
"preview": "openai-whisper\nHyperPyYAML\nonnxruntime\ninflect\nwetext\nomegaconf\nconformer\nhydra-core\nlightning\nrich\ngdown\nmatplotlib\nwge"
},
{
"path": "tests/README.md",
"chars": 197,
"preview": "\nPut all your models (Wan2.2-T2V-A14B, Wan2.2-I2V-A14B, Wan2.2-TI2V-5B) in a folder and specify the max GPU number you w"
},
{
"path": "tests/test.sh",
"chars": 5774,
"preview": "#!/bin/bash\nset -x\n\nunset NCCL_DEBUG\n\nif [ \"$#\" -eq 2 ]; then\n MODEL_DIR=$(realpath \"$1\")\n GPUS=$2\nelse\n echo \"Usage:"
},
{
"path": "wan/__init__.py",
"chars": 281,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom . import configs, distributed, modules\nfro"
},
{
"path": "wan/animate.py",
"chars": 27456,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\nimport math\nimport os\nimport cv2"
},
{
"path": "wan/configs/__init__.py",
"chars": 1327,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport copy\nimport os\n\nos.environ['TOKENIZERS_P"
},
{
"path": "wan/configs/shared_config.py",
"chars": 679,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n#--"
},
{
"path": "wan/configs/wan_animate_14B.py",
"chars": 1270,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/configs/wan_i2v_A14B.py",
"chars": 1030,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
},
{
"path": "wan/configs/wan_s2v_14B.py",
"chars": 1947,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/configs/wan_t2v_A14B.py",
"chars": 1023,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/configs/wan_ti2v_5B.py",
"chars": 891,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/distributed/__init__.py",
"chars": 73,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n"
},
{
"path": "wan/distributed/fsdp.py",
"chars": 1379,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nfrom functools import partial\n\nimport"
},
{
"path": "wan/distributed/sequence_parallel.py",
"chars": 5189,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\n\nfrom"
},
{
"path": "wan/distributed/ulysses.py",
"chars": 1322,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.distributed as dist\n\n"
},
{
"path": "wan/distributed/util.py",
"chars": 1379,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.distributed as dist\n\n"
},
{
"path": "wan/image2video.py",
"chars": 16479,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/modules/__init__.py",
"chars": 498,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .attention import flash_attention\nfrom .mo"
},
{
"path": "wan/modules/animate/__init__.py",
"chars": 186,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .model_animate import WanAnimateModel\nfrom"
},
{
"path": "wan/modules/animate/animate_utils.py",
"chars": 4534,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport numbers\nfrom peft import Lo"
},
{
"path": "wan/modules/animate/clip.py",
"chars": 16849,
"preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
},
{
"path": "wan/modules/animate/face_blocks.py",
"chars": 12202,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom torch import nn\nimport torch\nfrom typing i"
},
{
"path": "wan/modules/animate/model_animate.py",
"chars": 15898,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nimport types\nfrom copy import deepc"
},
{
"path": "wan/modules/animate/motion_encoder.py",
"chars": 8096,
"preview": "# Modified from ``https://github.com/wyhsirius/LIA``\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights rese"
},
{
"path": "wan/modules/animate/preprocess/UserGuider.md",
"chars": 4806,
"preview": "# Wan-animate Preprocessing User Guider\n\n## 1. Introductions\n\n\nWan-animate offers two generation modes: `animation` and "
},
{
"path": "wan/modules/animate/preprocess/__init__.py",
"chars": 167,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .process_pipepline import ProcessPipeline\n"
},
{
"path": "wan/modules/animate/preprocess/human_visualization.py",
"chars": 44228,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nimport time\nimport math\nim"
},
{
"path": "wan/modules/animate/preprocess/pose2d.py",
"chars": 17641,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nfrom typing import Union, "
},
{
"path": "wan/modules/animate/preprocess/pose2d_utils.py",
"chars": 41786,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport warnings\nimport cv2\nimport numpy as np\nf"
},
{
"path": "wan/modules/animate/preprocess/preprocess_data.py",
"chars": 4623,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport argparse\nfrom process_pipeplin"
},
{
"path": "wan/modules/animate/preprocess/process_pipepline.py",
"chars": 17839,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport numpy as np\nimport shutil\nimpo"
},
{
"path": "wan/modules/animate/preprocess/retarget_pose.py",
"chars": 36891,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nimport numpy as np\nimport "
},
{
"path": "wan/modules/animate/preprocess/sam_utils.py",
"chars": 5916,
"preview": "# Copyright (c) 2025. Your modifications here.\n# This file wraps and extends sam2.utils.misc for custom modifications.\n\n"
},
{
"path": "wan/modules/animate/preprocess/utils.py",
"chars": 7263,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nimport math\nimport random\n"
},
{
"path": "wan/modules/animate/preprocess/video_predictor.py",
"chars": 8314,
"preview": "# Copyright (c) 2025. Your modifications here.\n# A wrapper for sam2 functions\nfrom collections import OrderedDict\nimport"
},
{
"path": "wan/modules/animate/xlm_roberta.py",
"chars": 4864,
"preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
},
{
"path": "wan/modules/attention.py",
"chars": 5435,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n import flash_attn_interf"
},
{
"path": "wan/modules/model.py",
"chars": 18097,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn"
},
{
"path": "wan/modules/s2v/__init__.py",
"chars": 193,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .audio_encoder import AudioEncoder\nfrom .m"
},
{
"path": "wan/modules/s2v/audio_encoder.py",
"chars": 7299,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport librosa\nimport numpy as np\n"
},
{
"path": "wan/modules/s2v/audio_utils.py",
"chars": 3544,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nfrom typing import Tuple, Union\n\nim"
},
{
"path": "wan/modules/s2v/auxi_blocks.py",
"chars": 7923,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport importlib.metadata\nimport math\nfrom typi"
},
{
"path": "wan/modules/s2v/model_s2v.py",
"chars": 34511,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nimport types\nfrom copy import deepc"
},
{
"path": "wan/modules/s2v/motioner.py",
"chars": 28434,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nfrom typing import Any, Dict, List,"
},
{
"path": "wan/modules/s2v/s2v_utils.py",
"chars": 3139,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport numpy as np\nimport torch\n\n\ndef rope_prec"
},
{
"path": "wan/modules/t5.py",
"chars": 16910,
"preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
},
{
"path": "wan/modules/tokenizers.py",
"chars": 2431,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
},
{
"path": "wan/modules/vae2_1.py",
"chars": 23143,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
},
{
"path": "wan/modules/vae2_2.py",
"chars": 31846,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
},
{
"path": "wan/speech2video.py",
"chars": 29064,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/text2video.py",
"chars": 14679,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/textimage2video.py",
"chars": 24031,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/utils/__init__.py",
"chars": 402,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .fm_solvers import (\n FlowDPMSolverMult"
},
{
"path": "wan/utils/fm_solvers.py",
"chars": 40142,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
},
{
"path": "wan/utils/fm_solvers_unipc.py",
"chars": 32557,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
},
{
"path": "wan/utils/prompt_extend.py",
"chars": 19739,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport logging\nimport math\nimport o"
},
{
"path": "wan/utils/qwen_vl_utils.py",
"chars": 13060,
"preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
},
{
"path": "wan/utils/system_prompt.py",
"chars": 12292,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n\nT2V_A14B_ZH_SYS_PROMPT = \\\n''' 你是一位电影导演,旨在为用户输"
},
{
"path": "wan/utils/utils.py",
"chars": 7274,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport logging\n"
}
]
About this extraction
This page contains the full source code of the Wan-Video/Wan2.2 GitHub repository, extracted and formatted as plain text. The extraction includes 69 files (748.5 KB), approximately 196.3k tokens, and a symbol index with 655 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract, built by Nikandr Surkov.