Full Code of bytedance/ATI

Repository: bytedance/ATI
Branch: main
Commit: 1a002caf7bb5
Files: 59
Total size: 461.8 KB

Directory structure:
ATI/

├── .gitignore
├── INSTALL.md
├── LICENSE.txt
├── Makefile
├── README.md
├── examples/
│   ├── test.yaml
│   └── tracks/
│       ├── bear.pth
│       ├── deco.pth
│       ├── fish.pth
│       ├── giraffe.pth
│       ├── human.pth
│       └── sea.pth
├── generate.py
├── gradio/
│   ├── fl2v_14B_singleGPU.py
│   ├── i2v_14B_singleGPU.py
│   ├── t2i_14B_singleGPU.py
│   ├── t2v_1.3B_singleGPU.py
│   ├── t2v_14B_singleGPU.py
│   └── vace.py
├── pyproject.toml
├── requirements.txt
├── run_example.sh
├── tests/
│   ├── README.md
│   └── test.sh
├── tools/
│   ├── get_track_from_videos.py
│   ├── plot_user_inputs.py
│   ├── trajectory_editor/
│   │   ├── app.py
│   │   └── templates/
│   │       └── index.html
│   └── visualize_trajectory.py
└── wan/
    ├── __init__.py
    ├── configs/
    │   ├── __init__.py
    │   ├── shared_config.py
    │   ├── wan_i2v_14B.py
    │   ├── wan_t2v_14B.py
    │   └── wan_t2v_1_3B.py
    ├── distributed/
    │   ├── __init__.py
    │   ├── fsdp.py
    │   └── xdit_context_parallel.py
    ├── first_last_frame2video.py
    ├── image2video.py
    ├── modules/
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── clip.py
    │   ├── model.py
    │   ├── motion_patch.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── vace_model.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    ├── utils/
    │   ├── __init__.py
    │   ├── fm_solvers.py
    │   ├── fm_solvers_unipc.py
    │   ├── motion.py
    │   ├── prompt_extend.py
    │   ├── qwen_vl_utils.py
    │   ├── utils.py
    │   └── vace_processor.py
    └── vace.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.*
*.py[cod]
# *.jpg
*.jpeg
# *.png
*.gif
*.bmp
*.mp4
*.mov
*.mkv
*.log
*.zip
*.pt
*.pth
*.ckpt
*.safetensors
*.json
# *.txt
*.backup
*.pkl
*.html
*.pdf
*.whl
cache
__pycache__/
storage/
samples/
samples_motion_transfer/
outputs_motion_transfer/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
poetry.lock

!assets/examples/*.gif
!assets/examples/*.jpg
!examples/tracks/*.pth
!assets/Teaser.mp4
!tools/
!examples/
!tools/trajectory_editor/templates/index.html
!examples/motion_transfer/0.mp4


================================================
FILE: INSTALL.md
================================================
# Installation Guide

## Install with pip

```bash
pip install .
pip install .[dev]  # Also installs the dev tools
```

## Install with Poetry

Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.

To install all dependencies:

```bash
poetry install
```

### Handling `flash-attn` Installation Issues

If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.

#### No-Build-Isolation Installation (Recommended)
```bash
poetry run pip install --upgrade pip setuptools wheel
poetry run pip install flash-attn --no-build-isolation
poetry install
```

#### Install from Git (Alternative)
```bash
poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
```

---

### Running the Model

Once the installation is complete, you can run **Wan2.1** using:

```bash
poetry run python generate.py --task t2v-14B --size '1280x720' --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

#### Test
```bash
pytest tests/
```
#### Format
```bash
black .
isort .
```


================================================
FILE: LICENSE.txt
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2025 ByteDance Ltd. and/or its affiliates.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

   # Part of source code from: https://github.com/Wan-Video/Wan2.1
   # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
   # SPDX-License-Identifier: Apache-2.0


================================================
FILE: Makefile
================================================
.PHONY: format

format:
	isort generate.py gradio wan
	yapf -i -r *.py generate.py gradio wan


================================================
FILE: README.md
================================================
# ATI: Any Trajectory Instruction for Controllable Video Generation

<div align="center">
  
[![arXiv](https://img.shields.io/badge/arXiv%20paper-2505.22944-b31b1b.svg)](https://arxiv.org/pdf/2505.22944)&nbsp;
[![project page](https://img.shields.io/badge/Project_page-ATI-green)](https://anytraj.github.io/)&nbsp;
<a href="https://huggingface.co/bytedance-research/ATI/"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
</div>


> [**ATI: Any Trajectory Instruction for Controllable Video Generation**](https://anytraj.github.io/)<br>
> [Angtian Wang](https://angtianwang.github.io/), [Haibin Huang](https://brotherhuang.github.io/), Jacob Zhiyuan Fang, [Yiding Yang](https://ihollywhy.github.io/), [Chongyang Ma](http://www.chongyangma.com/)
> <br>Intelligent Creation Team, ByteDance<br>

**Highlight: ATI motion transfer tools and a demo have been added. Scroll down to see the updates.**

[![Watch the video](assets/thumbnail.jpg)](https://youtu.be/76jjPT0f8Hs)

This is the repo for Wan2.1 ATI (Any Trajectory Instruction for Controllable Video Generation), a trajectory-based motion control framework that unifies object, local and camera movements in video generation. This repo is based on the [official Wan2.1 implementation](https://github.com/Wan-Video/Wan2.1).

Compared with the original Wan2.1, we add the following files:
- wan/modules/motion_patch.py          | Trajectory instruction kernel module
- wan/utils/motion.py                  | Inference dataloader utils
- tools/plot_user_inputs.py            | Visualizer for user input trajectory
- tools/visualize_trajectory.py        | Visualizer for generated video
- tools/trajectory_editor/             | Interactive trajectory editor
- tools/get_track_from_videos.py       | Motion extraction tools for ATI motion transfer 
- examples/                            | Test examples
- run_example.sh                       | Easy launch script

We modified the following files:
- wan/image2video.py                   | Add blocks to load and parse trajectory  #L256
- wan/configs/__init__.py              | Configure ATI, etc.
- generate.py                          | Add an entry to load yaml format inference examples

## Community Works
### ComfyUI
Thanks to Kijai for developing the ComfyUI nodes for ATI:
[https://github.com/kijai/ComfyUI-WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper)

FP8-quantized Hugging Face model: [https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan2_1-I2V-ATI-14B_fp8_e4m3fn.safetensors](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan2_1-I2V-ATI-14B_fp8_e4m3fn.safetensors)

### Guideline
Guideline by Benji: [https://www.youtube.com/watch?v=UM35z2L1XbI](https://www.youtube.com/watch?v=UM35z2L1XbI)

## Install

ATI requires the same environment as the official Wan2.1. Follow the instructions in INSTALL.md (Wan2.1).

```
git clone https://github.com/bytedance/ATI.git
cd ATI
```

Install packages

```
pip install .
```

First, download the original Wan2.1 I2V-14B (480P) model.

```
huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./Wan2.1-I2V-14B-480P
```

Then download the ATI-Wan model from our Hugging Face repo.

```
huggingface-cli download bytedance-research/ATI --local-dir ./Wan2.1-ATI-14B-480P
```

Finally, copy the VAE, T5, and other miscellaneous checkpoints from the original Wan2.1 folder to the ATI checkpoint location:

```
cp ./Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth ./Wan2.1-ATI-14B-480P/
cp ./Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth ./Wan2.1-ATI-14B-480P/
cp ./Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth ./Wan2.1-ATI-14B-480P/
cp -r ./Wan2.1-I2V-14B-480P/xlm-roberta-large ./Wan2.1-ATI-14B-480P/
cp -r ./Wan2.1-I2V-14B-480P/google ./Wan2.1-ATI-14B-480P/
```

## Run

We provide a demo script to run ATI.

```
bash run_example.sh -p examples/test.yaml -c ./Wan2.1-ATI-14B-480P -o samples
```
where `-p` is the path to the config file, `-c` is the path to the checkpoint, `-o` is the path to the output directory, and `-g` sets the number of GPUs to use (if unspecified, all available GPUs will be used; if `1` is given, the script runs in single-process mode).
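For instance, using only the flags documented above, a single-process run on one GPU with the bundled example config would look like:

```
bash run_example.sh -p examples/test.yaml -c ./Wan2.1-ATI-14B-480P -o samples -g 1
```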

Once finished, you can expect to find:
- `samples/outputs` for the raw output videos.
- `samples/images_tracks` shows the input image together with the user-specified trajectories.
- `samples/outputs_vis` shows the output videos together with the user-specified trajectories.

Expected results:


<table style="width: 100%; border-collapse: collapse; text-align: center; border: 1px solid #ccc;">
  <tr>
    <th style="text-align: center;">
      <strong>Input Image & Trajectory</strong>
    </th>
    <th style="text-align: center;">
      <strong>Generated Videos (Superimposed Trajectories)</strong>
    </th>
  </tr>

  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/00.jpg" alt="Image 0" style="height: 240px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/00.gif" alt="Image 0" style="height: 240px;">
    </td>
  </tr>

  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/01.jpg" alt="Image 1" style="height: 240px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/01.gif" alt="Image 1" style="height: 240px;">
    </td>
  </tr>

  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/02.jpg" alt="Image 2" style="height: 160px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/02.gif" alt="Image 2" style="height: 160px;">
    </td>
  </tr>

  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/03.jpg" alt="Image 3" style="height: 220px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/03.gif" alt="Image 3" style="height: 220px;">
    </td>
  </tr>

  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/04.jpg" alt="Image 4" style="height: 240px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/04.gif" alt="Image 4" style="height: 240px;">
    </td>
  </tr>

  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/05.jpg" alt="Image 5" style="height: 160px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/05.gif" alt="Image 5" style="height: 160px;">
    </td>
  </tr>
</table>

## Motion Transfer

![Motion Transfer](assets/MotionTransfer.jpg)
ATI can mimic a reference video by extracting its motion dynamics along with its first-frame image. Moreover, by leveraging powerful image-editing tools on that first frame, it also enables "video-editing" capabilities.

First, extract motions from videos using the following script:
```
python3 tools/get_track_from_videos.py --source_folder examples/motion_transfer/ --save_folder samples_motion_transfer/
```

Then run ATI inference
```
bash run_example.sh -p samples_motion_transfer/test.yaml -c ./Wan2.1-ATI-14B-480P -o outputs_motion_transfer
```

Expected result

<table style="width: 100%; border-collapse: collapse; text-align: center; border: 1px solid #ccc;">
  <tr>
    <th style="text-align: center;">
      <strong>Reference Video (for Extracting Motion)</strong>
    </th>
    <th style="text-align: center;">
      <strong>First Frame Image</strong>
    </th>
    <th style="text-align: center;">
      <strong>Generated Video</strong>
    </th>
  </tr>


  <tr>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/RV_0.gif" alt="Motion Transfer Video" style="height: 160px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/RI_0.png" alt="Motion Transfer Image" style="height: 160px;">
    </td>
    <td style="text-align: center; vertical-align: middle;">
      <img src="assets/examples/G_0.gif" alt="Motion Transfer Output" style="height: 160px;">
    </td>
  </tr>
</table>


## Create Your Own Trajectory

We provide an interactive tool that allows users to draw and edit trajectories on their images.
Important note: **app.py** should only be run on **localhost**, as running it on a remote server may pose security risks.

1. First run:
```
cd tools/trajectory_editor
python3 app.py
```
then open the URL [localhost:5000](http://localhost:5000/) in your browser.

2. Get the interface shown below, then click **Choose File** to open a local image.  
![Interface Screenshot](assets/editor0.PNG)

3. Available trajectory functions:  
![Trajectory Functions](assets/editor1.PNG)

   a. **Free Trajectory**: Click and then drag with the mouse directly on the image.  
   b. **Circular (Camera Control)**:  
      - Place a circle on the image, then drag to set its size for frame 0.  
      - Place a few (3–4 recommended) track points on the circle.  
      - Drag the radius control to achieve zoom-in/zoom-out effects.  

   c. **Static Point**: A point that remains stationary over time.  

   *Note:* Pay attention to the progress bar in the box to control motion speed.  
   ![Progress Control](assets/editor2.PNG)

4. **Trajectory Editing**: Select a trajectory here, then delete, edit, or copy it. In edit mode, drag the trajectory directly on the image. The selected trajectory is highlighted by color.  
![Trajectory Editing](assets/editor3.PNG)

5. **Camera Pan Control**: Enter horizontal (X) or vertical (Y) speed (pixels per frame). Positive X moves right; negative X moves left. Positive Y moves down; negative Y moves up. Click **Add to Selected** to apply to the current trajectory, or **Add to All** to apply to all trajectories. The selected points will gain a constant pan motion on top of their existing movement.  
![Camera Pan Control](assets/editor4.PNG)

6. **Important:** After editing, click **Store Tracks** to save. Each image (not each trajectory) must be saved separately after drawing all trajectories.  
![Store Tracks](assets/editor5.PNG)

7. Once all edits are complete, the saved tracks can be found in the `videos_example` folder inside the **Trajectory Editor** directory.

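The example trajectories shipped with the repo are stored as PyTorch `.pth` files under `examples/tracks/`. As a quick sanity check before running inference, you can load a track file and inspect what it contains. The snippet below is a minimal sketch (not part of the repo) and makes no assumption about the exact tensor layout beyond the file being loadable with `torch.load`:

```
# inspect_track.py -- minimal sketch for peeking at a saved trajectory file
import torch

track_path = "examples/tracks/fish.pth"  # any stored track file
data = torch.load(track_path, map_location="cpu")

# A track file may hold a single tensor or a container of tensors.
if isinstance(data, torch.Tensor):
    print(f"tensor: shape={tuple(data.shape)}, dtype={data.dtype}")
elif isinstance(data, dict):
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")
        else:
            print(f"{key}: {type(value).__name__}")
else:
    print(type(data).__name__)
```

At inference time, `generate.py` passes the track file to `wan.utils.motion.get_tracks_inference(tracks, height, width)` together with the input image's height and width before handing the result to the ATI pipeline.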

## Citation
Please cite our paper if you find our work useful:
```
@article{wang2025ati,
  title={{ATI}: Any Trajectory Instruction for Controllable Video Generation},
  author={Wang, Angtian and Huang, Haibin and Fang, Zhiyuan and Yang, Yiding and Ma, Chongyang},
  journal={arXiv preprint},
  volume={arXiv:2505.22944},
  year={2025}
}
```



================================================
FILE: examples/test.yaml
================================================
- image: examples/images/fish.jpg
  text: "A tranquil koi pond edged by mossy stone, with lily pads drifting on the surface and several orange\u2011and\u2011white koi fish gliding beneath."
  track: examples/tracks/fish.pth
- image: examples/images/human.jpg  
  text: "An human facing the camera in an cyberbank style dress."
  track: examples/tracks/human.pth
- image: examples/images/sea.png
  text: Surreal scene of a colossal ocean wave curling inside an opulent vaulted gallery, two tiny surfers riding its emerald face.
  track: examples/tracks/sea.pth
- image: examples/images/deco.png
  text: A gleaming gold necklace with elongated links gently frames a U-shaped pendant encrusted with delicate, shimmering stones. The pendant’s bold, modern design contrasts beautifully with the fine details of its sparkling accents. Set against a gradient background of tranquil blues, the piece exudes both sophistication and understated luxury.
  track: examples/tracks/deco.pth
- image: examples/images/giraffe.jpg
  text: "A close-up portrait of a giraffe’s head and long neck against a soft-focus woodland backdrop."  
  track: examples/tracks/giraffe.pth
- image: examples/images/bear.jpg
  text: "A brown bear lying in the shade beside a rock, resting on a bed of grass."
  track: examples/tracks/bear.pth

================================================
FILE: generate.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

from wan.utils.motion import get_tracks_inference
from wan.utils.utils import cache_video, cache_image, str2bool
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
import wan
from PIL import Image
import torch.distributed as dist
import torch
import random
import argparse
import logging
import os
import sys
import warnings
import yaml
from datetime import datetime

warnings.filterwarnings('ignore')


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40

    if args.sample_shift is None:
        args.sample_shift = 5.0
        # if args.size in ["832*480", "480*832"]:
        #     args.sample_shift = 3.0

    # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[args.task], (
        f"Unsupported size {args.size} for task {args.task}, "
        f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="ati-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="832*480",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from a image or video. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--src_video",
        type=str,
        default=None,
        help="The file of the source video. Default None.")
    parser.add_argument(
        "--src_mask",
        type=str,
        default=None,
        help="The file of the source mask. Default None.")
    parser.add_argument(
        "--src_ref_images",
        type=str,
        default=None,
        help="The file list of the source reference images. Separated by ','. Default None."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extend.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extend.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="[image to video] The image to generate the video from.")
    parser.add_argument(
        "--track",
        type=str,
        default=None,
        help="The stored point trajectory to generate the video.")
    parser.add_argument(
        "--first_frame",
        type=str,
        default=None,
        help="[first-last frame to video] The image (first frame) to generate the video from."
    )
    parser.add_argument(
        "--last_frame",
        type=str,
        default=None,
        help="[first-last frame to video] The image (last frame) to generate the video from."
    )
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")

    args = parser.parse_args()

    _validate_args(args)

    return args


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def generate(args):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "context parallel is not supported in non-distributed environments."

    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * \
            args.ring_size == world_size, "The product of ulysses_size and ring_size should equal the world size."
        from xfuser.core.distributed import (
            init_distributed_environment,
            initialize_model_parallel,
        )
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                is_vl=True)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl=True,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupport prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    if args.prompt.endswith('.yaml'):
        inputs_ = []

        fl_list = yaml.safe_load(open(args.prompt))
        for line in fl_list:
            inputs_.append(
                (line['image'], line['text'].strip(), line['track']))
    else:
        inputs_ = [(args.image, args.prompt, args.track)]

    logging.info("Creating WanATI pipeline.")
    wan_ati = wan.WanATI(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=device,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
    )

    for ii, input_ in enumerate(inputs_):
        if args.save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")

            if args.prompt.endswith(".yaml"):
                formatted_prompt = f"{ii:02d}"
            else:
                formatted_prompt = args.prompt.replace(" ", "_").replace("/",
                                                                         "_")[:50]
            suffix = '.mp4'
            args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix

        if '%' in args.save_file:
            save_file = args.save_file % ii
        else:
            save_file = args.save_file

        if os.path.exists(save_file):
            logging.info(f"File {save_file} already exists, skipping.")
            continue

        image, prompt, tracks = input_
        logging.info(f"Input prompt: {prompt}")
        logging.info(f"Input image: {image}")

        img = Image.open(image).convert("RGB")

        width, height = img.size
        tracks = get_tracks_inference(tracks, height, width)

        if args.use_prompt_extend:
            logging.info("Extending prompt ...")
            if rank == 0:
                prompt_output = prompt_expander(
                    prompt,
                    tar_lang=args.prompt_extend_target_lang,
                    image=img,
                    seed=args.base_seed)
                if prompt_output.status == False:
                    logging.info(
                        f"Extending prompt failed: {prompt_output.message}")
                    logging.info("Falling back to original prompt.")
                    input_prompt = prompt
                else:
                    input_prompt = prompt_output.prompt
                input_prompt = [input_prompt]
            else:
                input_prompt = [None]
            if dist.is_initialized():
                dist.broadcast_object_list(input_prompt, src=0)
            prompt = input_prompt[0]
            logging.info(f"Extended prompt: {prompt}")

        logging.info("Generating video ...")
        video = wan_ati.generate(
            prompt,
            img,
            tracks,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

        if rank == 0:
            logging.info(f"Saving generated video to {save_file}")
            cache_video(
                tensor=video[None],
                save_file=save_file,
                fps=cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)


================================================
FILE: gradio/fl2v_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os
import os.path as osp
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(
    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_flf2v_720P = None


# Button Func
def load_model(value):
    global wan_flf2v_720P

    if value == '------':
        print("No model loaded")
        return '------'

    if value == '720P':
        if args.ckpt_dir_720p is None:
            print("Please specify the checkpoint directory for 720P model")
            return '------'
        if wan_flf2v_720P is not None:
            pass
        else:
            gc.collect()

            print("load 14B-720P flf2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['flf2v-14B']
            wan_flf2v_720P = wan.WanFLF2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_720p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
            print("done", flush=True)
            return '720P'
    return value


def prompt_enc(prompt, img_first, img_last, tar_lang):
    print('prompt extend...')
    if img_first is None or img_last is None:
        print('Please upload the first and last frames')
        return prompt
    global prompt_expander
    prompt_output = prompt_expander(
        prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
                     resolution, sd_steps, guide_scale, shift_scale, seed,
                     n_prompt):

    if resolution == '------':
        print(
            'Please select a resolution and make sure its checkpoint directory is specified')
        return None

    else:
        if resolution == '720P':
            global wan_flf2v_720P
            video = wan_flf2v_720P.generate(
                flf2vid_prompt,
                flf2vid_image_first,
                flf2vid_image_last,
                max_area=MAX_AREA_CONFIGS['720*1280'],
                shift=shift_scale,
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=True)
            pass
        else:
            print('Sorry, currently only 720P is supported.')
            return None

        cache_video(
            tensor=video[None],
            save_file="example.mp4",
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

        return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (FLF2V-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                resolution = gr.Dropdown(
                    label='Resolution',
                    choices=['------', '720P'],
                    value='------')
                flf2vid_image_first = gr.Image(
                    type="pil",
                    label="Upload First Frame",
                    elem_id="image_upload",
                )
                flf2vid_image_last = gr.Image(
                    type="pil",
                    label="Upload Last Frame",
                    elem_id="image_upload",
                )
                flf2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_flf2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        resolution.input(
            fn=load_model, inputs=[resolution], outputs=[resolution])

        run_p_button.click(
            fn=prompt_enc,
            inputs=[
                flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
                tar_lang
            ],
            outputs=[flf2vid_prompt])

        run_flf2v_button.click(
            fn=flf2v_generation,
            inputs=[
                flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
                resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()
    assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory."

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)


================================================
FILE: gradio/i2v_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os
import os.path as osp
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(
    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None


# Button Func
def load_model(value):
    global wan_i2v_480P, wan_i2v_720P

    if value == '------':
        print("No model loaded")
        return '------'

    if value == '720P':
        if args.ckpt_dir_720p is None:
            print("Please specify the checkpoint directory for 720P model")
            return '------'
        if wan_i2v_720P is not None:
            pass
        else:
            del wan_i2v_480P
            gc.collect()
            wan_i2v_480P = None

            print("load 14B-720P i2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['i2v-14B']
            wan_i2v_720P = wan.WanI2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_720p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
            print("done", flush=True)
            return '720P'

    if value == '480P':
        if args.ckpt_dir_480p is None:
            print("Please specify the checkpoint directory for 480P model")
            return '------'
        if wan_i2v_480P is not None:
            pass
        else:
            del wan_i2v_720P
            gc.collect()
            wan_i2v_720P = None

            print("load 14B-480P i2v model...", end='', flush=True)
            cfg = WAN_CONFIGS['i2v-14B']
            wan_i2v_480P = wan.WanI2V(
                config=cfg,
                checkpoint_dir=args.ckpt_dir_480p,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
            print("done", flush=True)
            return '480P'
    return value


def prompt_enc(prompt, img, tar_lang):
    print('prompt extend...')
    if img is None:
        print('Please upload an image')
        return prompt
    global prompt_expander
    prompt_output = prompt_expander(
        prompt, image=img, tar_lang=tar_lang.lower())
    if prompt_output.status == False:
        return prompt
    else:
        return prompt_output.prompt


def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
                   guide_scale, shift_scale, seed, n_prompt):
    # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    if resolution == '------':
        print(
            'Please specify at least one checkpoint directory, or select a resolution'
        )
        return None

    else:
        if resolution == '720P':
            global wan_i2v_720P
            video = wan_i2v_720P.generate(
                img2vid_prompt,
                img2vid_image,
                max_area=MAX_AREA_CONFIGS['720*1280'],
                shift=shift_scale,
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=True)
        else:
            global wan_i2v_480P
            video = wan_i2v_480P.generate(
                img2vid_prompt,
                img2vid_image,
                max_area=MAX_AREA_CONFIGS['480*832'],
                shift=shift_scale,
                sampling_steps=sd_steps,
                guide_scale=guide_scale,
                n_prompt=n_prompt,
                seed=seed,
                offload_model=True)

        cache_video(
            tensor=video[None],
            save_file="example.mp4",
            fps=16,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

        return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (I2V-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                resolution = gr.Dropdown(
                    label='Resolution',
                    choices=['------', '720P', '480P'],
                    value='------')

                img2vid_image = gr.Image(
                    type="pil",
                    label="Upload Input Image",
                    elem_id="image_upload",
                )
                img2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_i2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        resolution.input(
            fn=load_model, inputs=[resolution], outputs=[resolution])

        run_p_button.click(
            fn=prompt_enc,
            inputs=[img2vid_prompt, img2vid_image, tar_lang],
            outputs=[img2vid_prompt])

        run_i2v_button.click(
            fn=i2v_generation,
            inputs=[
                img2vid_prompt, img2vid_image, resolution, sd_steps,
                guide_scale, shift_scale, seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--ckpt_dir_480p",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()
    assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)


================================================
FILE: gradio/t2i_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(
    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image

# Global Var
prompt_expander = None
wan_t2i = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if not prompt_output.status:
        return prompt
    else:
        return prompt_output.prompt


def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2i
    # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2i.generate(
        txt2img_prompt,
        size=(W, H),
        frame_num=1,
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_image(
        tensor=video.squeeze(1)[None],
        save_file="example.png",
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.png"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2I-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                txt2img_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2i_button = gr.Button("Generate Image")

            with gr.Column():
                result_gallery = gr.Image(
                    label='Generated Image', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2img_prompt, tar_lang],
            outputs=[txt2img_prompt])

        run_t2i_button.click(
            fn=t2i_generation,
            inputs=[
                txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a image from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2i model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2i-14B']
    wan_t2i = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)


================================================
FILE: gradio/t2v_1.3B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(
    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_t2v = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if not prompt_output.status:
        return prompt
    else:
        return prompt_output.prompt


def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2v
    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(W, H),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2V-1.3B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '480*832',
                            '832*480',
                            '624*624',
                            '704*544',
                            '544*704',
                        ],
                        value='480*832')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=6.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=20,
                            value=8.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2vid_prompt, tar_lang],
            outputs=[txt2vid_prompt])

        run_t2v_button.click(
            fn=t2v_generation,
            inputs=[
                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 1.3B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-1.3B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)


================================================
FILE: gradio/t2v_14B_singleGPU.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
sys.path.insert(
    0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
prompt_expander = None
wan_t2v = None


# Button Func
def prompt_enc(prompt, tar_lang):
    global prompt_expander
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    if not prompt_output.status:
        return prompt
    else:
        return prompt_output.prompt


def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    global wan_t2v
    # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")

    W = int(resolution.split("*")[0])
    H = int(resolution.split("*")[1])
    video = wan_t2v.generate(
        txt2vid_prompt,
        size=(W, H),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                        Wan2.1 (T2V-14B)
                    </div>
                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
                        Wan: Open and Advanced Large-Scale Video Generative Models.
                    </div>
                    """)

        with gr.Row():
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["ZH", "EN"],
                    label="Target language of prompt enhance",
                    value="ZH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=[
                            '720*1280', '1280*720', '960*960', '1088*832',
                            '832*1088', '480*832', '832*480', '624*624',
                            '704*544', '544*704'
                        ],
                        value='720*1280')

                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_t2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        run_p_button.click(
            fn=prompt_enc,
            inputs=[txt2vid_prompt, tar_lang],
            outputs=[txt2vid_prompt])

        run_t2v_button.click(
            fn=t2v_generation,
            inputs=[
                txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
                seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)


================================================
FILE: gradio/vace.py
================================================
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import argparse
import datetime
import os
import sys

import imageio
import numpy as np
import torch

import gradio as gr

sys.path.insert(
    0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS


class FixedSizeQueue:

    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []

    def add(self, item):
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()

    def get(self):
        return self.queue

    def __repr__(self):
        return str(self.queue)


class VACEInference:

    def __init__(self,
                 cfg,
                 skip_load=False,
                 gallery_share=True,
                 gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            if not cfg.mp:
                self.pipe = WanVace(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    device_id=0,
                    rank=0,
                    t5_fsdp=False,
                    dit_fsdp=False,
                    use_usp=False,
                )
            else:
                self.pipe = WanVaceMP(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    use_usp=True,
                    ulysses_size=cfg.ulysses_size,
                    ring_size=cfg.ring_size)

    def create_ui(self, *args, **kwargs):
        gr.Markdown("""
                    <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
                        <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
                    </div>
                    """)
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(
                        label='src_ref_image_1',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_1",
                        format='png')
                    self.src_ref_image_2 = gr.Image(
                        label='src_ref_image_2',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_2",
                        format='png')
                    self.src_ref_image_3 = gr.Image(
                        label='src_ref_image_3',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_3",
                        format='png')
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value=self.pipe.config.sample_neg_prompt,
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=100.0,
                        step=1.0,
                        value=16.0,
                        interactive=True)
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        interactive=True)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(
                        minimum=-1, maximum=10000000, value=2025, label="Seed")
        #
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    # value=480,
                    value=720,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    # value=832,
                    value=1280,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate', value=16, interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames', value=81, interactive=True)
        #
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        #
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
                 src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
                 shift_scale, sample_steps, context_scale, guide_scale,
                 infer_seed, output_height, output_width, frame_rate,
                 num_frames):
        output_height, output_width, frame_rate, num_frames = int(
            output_height), int(output_width), int(frame_rate), int(num_frames)
        src_ref_images = [
            x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
            if x is not None
        ]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source(
            [src_video], [src_mask], [src_ref_images],
            num_frames=num_frames,
            image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
            device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            sampling_steps=sample_steps,
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=infer_seed,
            offload_model=True)

        name = '{0:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
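        # `video` is C x T x H x W in [-1, 1]; rescale to uint8 [0, 255] and reorder to T x H x W x C for imageio.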
        video_frames = (
            torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
            255).cpu().numpy().astype(np.uint8)

        try:
            writer = imageio.get_writer(
                video_path,
                fps=frame_rate,
                codec='libx264',
                quality=8,
                macro_block_size=1)
            for frame in video_frames:
                writer.append_data(frame)
            writer.close()
            print(video_path)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}")

        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        else:
            return [video_path]

    def set_callbacks(self, **kwargs):
        self.gen_inputs = [
            self.output_gallery, self.src_video, self.src_mask,
            self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
            self.prompt, self.negative_prompt, self.shift_scale,
            self.sample_steps, self.context_scale, self.guide_scale,
            self.infer_seed, self.output_height, self.output_width,
            self.frame_rate, self.num_frames
        ]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(
            self.generate,
            inputs=self.gen_inputs,
            outputs=self.gen_outputs,
            queue=True)
        self.refresh_button.click(
            lambda x: self.gallery_share_data.get()
            if self.gallery_share else x,
            inputs=[self.output_gallery],
            outputs=[self.output_gallery])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument(
        '--server_port', dest='server_port', help='Port for the Gradio server.', type=int, default=7860)
    parser.add_argument(
        '--server_name', dest='server_name', help='Host/IP for the Gradio server.', default='0.0.0.0')
    parser.add_argument(
        '--root_path', dest='root_path', help='Root path when serving behind a reverse proxy.', default=None)
    parser.add_argument(
        '--save_dir', dest='save_dir', help='Directory in which generated videos are saved.', default='cache')
    parser.add_argument(
        "--mp",
        action="store_true",
        help="Use Multi-GPUs",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vace-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The model name to run.")
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        # default='models/VACE-Wan2.1-1.3B-Preview',
        default='models/Wan2.1-VACE-14B/',
        help="The path to the checkpoint directory.",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks() as demo:
        infer_gr = VACEInference(
            args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(
            server_name=args.server_name,
            server_port=args.server_port,
            root_path=args.root_path,
            allowed_paths=allowed_paths,
            show_error=True,
            debug=True)


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "wan"
version = "2.1.0"
description = "Wan: Open and Advanced Large-Scale Video Generative Models"
authors = [
    { name = "Wan Team", email = "wan.ai@alibabacloud.com" }
]
license = { file = "LICENSE.txt" }
readme = "README.md"
requires-python = ">=3.10,<4.0"
dependencies = [
    "torch>=2.4.0",
    "torchvision>=0.19.0",
    "opencv-python>=4.9.0.80",
    "diffusers>=0.31.0",
    "transformers>=4.49.0",
    "tokenizers>=0.20.3",
    "accelerate>=1.1.1",
    "tqdm",
    "imageio",
    "easydict",
    "ftfy",
    "dashscope",
    "imageio-ffmpeg",
    "flash_attn",
    "gradio>=5.0.0",
    "numpy>=1.23.5,<2"
]

[project.optional-dependencies]
dev = [
    "pytest",
    "black",
    "flake8",
    "isort",
    "mypy",
    "huggingface-hub[cli]"
]

[project.urls]
homepage = "https://wanxai.com"
documentation = "https://github.com/Wan-Video/Wan2.1"
repository = "https://github.com/Wan-Video/Wan2.1"
huggingface = "https://huggingface.co/Wan-AI/"
modelscope = "https://modelscope.cn/organization/Wan-AI"
discord = "https://discord.gg/p5XbdQV7"

[tool.setuptools]
packages = ["wan"]

[tool.setuptools.package-data]
"wan" = ["**/*.py"]

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[tool.mypy]
strict = true




================================================
FILE: requirements.txt
================================================
torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2
mediapy

================================================
FILE: run_example.sh
================================================
#!/usr/bin/env bash
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
set -euo pipefail

usage() {
  cat <<EOF
Usage: $0 -c <ckpt_dir> [-g <num_gpus>] [-p <prompt_file>] [-o <output_dir>]
  -c  Path to your model checkpoint directory
  -g  Number of GPUs to use (defaults to all available GPUs)
  -p  Path to prompt file
  -o  Path to output location
EOF
  exit 1
}
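
# Example invocation (the checkpoint path below is illustrative and not shipped with this repo):
#   ./run_example.sh -c ./Wan2.1-I2V-14B-720P -g 4 -p examples/test.yaml -o samples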

OUTPUT_DIR="samples"
# parse args
CKPT_DIR=""
PROMPT="examples/test.yaml"
NGPUS=""
while [[ $# -gt 0 ]]; do
  case $1 in
    -c|--ckpt_dir)
      CKPT_DIR="$2"; shift 2;;
    -g|--gpus)
      NGPUS="$2"; shift 2;;
    -p|--prompt)
      PROMPT="$2"; shift 2;;
    -o|--output)
      OUTPUT_DIR="$2"; shift 2;;
    -*)
      echo "Unknown option: $1" >&2; usage;;
    *)
      break;;
  esac
done


if [[ -z "$CKPT_DIR" ]]; then
  echo "Error: --ckpt_dir is required" >&2
  usage
fi

# detect GPUs if not provided
if [[ -z "$NGPUS" ]]; then
  if command -v python3 &>/dev/null; then
    NGPUS=$(python3 - <<'PYCODE'
import torch
print(torch.cuda.device_count() or 1)
PYCODE
)
  else
    echo "Warning: python3 not found; defaulting to 1 GPU" >&2
    NGPUS=1
  fi
fi

echo ">>> Using checkpoint: $CKPT_DIR"
echo ">>> Generate case: $PROMPT"
echo ">>> Saved to: $OUTPUT_DIR"
echo ">>> Detected $NGPUS GPU(s)"

mkdir -p "$OUTPUT_DIR/outputs"

if [[ "$NGPUS" -eq 1 ]]; then
  echo ">>> Single‐GPU mode: running generate.py directly"
  python generate.py \
    --ckpt_dir "$CKPT_DIR" \
    --prompt "$PROMPT" \
    --save_file "$OUTPUT_DIR/outputs/%03d.mp4"
else
  echo ">>> Multi‐GPU mode: launching with torchrun"
  torchrun \
    --nproc_per_node="$NGPUS" \
    --master-port=5645 \
    generate.py \
      --ckpt_dir "$CKPT_DIR" \
      --prompt "$PROMPT" \
      --save_file "$OUTPUT_DIR/outputs/%03d.mp4" \
      --ulysses_size "$NGPUS" \
      --base_seed 4567 \
      --dit_fsdp \
      --t5_fsdp
fi

cp "$PROMPT" "$OUTPUT_DIR/" &

# visualize results
python3 ./tools/visualize_trajectory.py --base_dir "$OUTPUT_DIR/"
python3 ./tools/plot_user_inputs.py "$PROMPT" --save_dir "$OUTPUT_DIR/image_with_tracks"


================================================
FILE: tests/README.md
================================================

Put all the required models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P, and VACE-Wan2.1-1.3B-Preview) into a single folder, then run the test script with that folder and the maximum number of GPUs you want to use:

```bash
bash ./test.sh <local model dir> <gpu number>
```
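
For example, assuming the models are stored under a hypothetical `/data/wan_models` directory and up to 8 GPUs are available:

```bash
bash ./test.sh /data/wan_models 8
```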


================================================
FILE: tests/test.sh
================================================
#!/bin/bash


if [ "$#" -eq 2 ]; then
  MODEL_DIR=$(realpath "$1")
  GPUS=$2
else
  echo "Usage: $0 <local model dir> <gpu number>"
  exit 1
fi

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1

PY_FILE=./generate.py


function t2v_1_3B() {
    T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
    python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}

function t2v_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
    python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}



function t2i_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
    python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}


function i2v_14B_480p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}


function i2v_14B_720p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}

function vace_1_3B() {
    VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
    torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR

}


t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
vace_1_3B


================================================
FILE: tools/get_track_from_videos.py
================================================
import torch
from typing import List, Sequence, Any
from PIL import Image
import numpy as np
import cv2
import yaml
import math
import io
import os


QUANT_MULTI = 8
def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI):
    # scale by quant_multi and store as float32 (the array is not packed into uint16 here)
    arr_q = (quant_multi * arr).astype(np.float32)
    bio = io.BytesIO()
    if compressed:
        np.savez_compressed(bio, array=arr_q)
    else:
        np.savez(bio, array=arr_q)
    torch.save(bio.getvalue(), path)
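

# Illustrative sketch only: a minimal counterpart to array_to_npz_bytes above for
# reading the saved .pth track files back. The repository's actual loading code
# lives elsewhere, so the name `load_npz_bytes` is an assumption, not project API.
def load_npz_bytes(path, quant_multi=QUANT_MULTI):
    # torch.save stored the raw npz bytes, so recover them first
    npz_bytes = torch.load(path)
    # unpack the npz payload and undo the quant_multi scaling applied at save time
    arr_q = np.load(io.BytesIO(npz_bytes))['array']
    return arr_q / quant_multi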


def parse_to_list(text: str) -> List[List[int]]:
    """
    Parse a multiline string of comma-separated integers into a list of integer lists.

    Example:
        text = "327, 806, 670, 1164\n49, 587, 346, 1037"
        parse_to_list(text)
        # → [[327, 806, 670, 1164], [49, 587, 346, 1037]]
    """
    lines = text.strip().splitlines()
    result: List[List[int]] = []
    for line in lines:
        # split on comma, strip whitespace, convert to int
        nums = [int(x.strip()) for x in line.split(',') if x.strip()]
        if nums:
            result.append(nums)
    return result


def load_video_to_frames(
    video_path: str,
    preset_fps: float = 24,
    max_short_edge: int = None
) -> List[Image.Image]:
    """
    Load a video file, resample its frame-rate to a single preset value
    (if needed), optionally resize frames so their short edge is at most
    max_short_edge (keeping aspect ratio), and return a list of PIL.Image frames.

    Args:
        video_path (str): Path to the video file.
        preset_fps (float): Desired FPS. If the video's FPS isn't exactly
            this value, the video will be resampled to match it.
        max_short_edge (int, optional): If provided and a frame's short edge
            (min(width,height)) exceeds this, the frame is resized so the
            short edge == max_short_edge, preserving aspect ratio.

    Returns:
        List[PIL.Image.Image]: A list of frames at the preset FPS, each
            resized if needed.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open video file: {video_path}")

    fps_in = cap.get(cv2.CAP_PROP_FPS)
    do_resample = fps_in > 0 and abs(fps_in - preset_fps) > 1e-3

    # read all frames
    raw_frames: List[Image.Image] = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # BGR -> RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(frame_rgb)

        # optional resize by short edge
        if max_short_edge is not None:
            w, h = img.size
            short_edge = min(w, h)
            if short_edge > max_short_edge:
                scale = max_short_edge / short_edge
                new_w = int(round(w * scale))
                new_h = int(round(h * scale))
                img = img.resize((new_w, new_h), resample=Image.LANCZOS)

        raw_frames.append(img)
    cap.release()

    # resample FPS if needed
    if do_resample:
        ratio = fps_in / preset_fps
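        # e.g. a 30 fps input with preset_fps=24 gives ratio=1.25, so the output keeps
        # floor(total_in / 1.25) frames, picking input frames by rounded index.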
        total_in = len(raw_frames)
        total_out = int(math.floor(total_in / ratio))
        resampled: List[Image.Image] = []
        for i in range(total_out):
            idx = min(int(round(i * ratio)), total_in - 1)
            resampled.append(raw_frames[idx])
        return resampled

    return raw_frames


def sample_grid_points(bbox, N):
    """
    Uniformly sample N points inside a bounding box using a grid
    whose Nx×Ny layout follows the box’s width:height ratio.

    Args:
        bbox: tuple (xmin, ymin, xmax, ymax)
        N:     int, number of points to sample

    Returns:
        numpy.ndarray of shape (Nx*Ny, 2) with Nx*Ny >= N, each row is (y, x)
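
    Example:
        With bbox = (0, 0, 200, 100) and N = 50, Nx = ceil(sqrt(50 * 200 / 100)) = 10
        and Ny = ceil(sqrt(50 * 100 / 200)) = 5, giving a 10 x 5 grid of 50 points.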
    """
    xmin, ymin, xmax, ymax = bbox
    width = xmax - xmin
    height = ymax - ymin

    # choose Nx and Ny so that Nx/Ny ≈ width/height and Nx*Ny >= N
    Nx = int(np.ceil(np.sqrt(N * width / height)))
    Ny = int(np.ceil(np.sqrt(N * height / width)))

    # generate evenly spaced coordinates along each axis
    ys = np.linspace(ymin, ymax, Ny)
    xs = np.linspace(xmin, xmax, Nx)

    # form the grid and flatten
    yy, xx = np.meshgrid(ys, xs, indexing='ij')
    coords = np.stack([yy.ravel(), xx.ravel()], axis=1)

    # return the sampled grid (Nx*Ny >= N points)
    return coords



def resize_images_to_size(image_list, size=1024):
    """
    Given a list of PIL Image objects, resize each to a square of
    `size` x `size` pixels (the aspect ratio is not preserved).
    Returns a new list of resized images.
    """
    resized_list = []
    for img in image_list:        
        # Resize using a high-quality resample filter (e.g. LANCZOS).
        # You can also use Image.BILINEAR, Image.BICUBIC, etc.
        resized_img = img.resize((size, size), resample=Image.LANCZOS)
        resized_list.append(resized_img)
    
    return resized_list


def resize_box(box, ratios):
    return [
        int(round(box[0] * ratios[0])),
        int(round(box[1] * ratios[1])),
        int(round(box[2] * ratios[0])),
        int(round(box[3] * ratios[1])),
    ]


class TrackAnyPoint():
    def __init__(self, n_points=60):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.n_points = n_points
        self.model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(self.device)
        self.resolution = 720
        self.boundary_remove = 40

    @torch.no_grad()
    def __call__(self, video_frames: List[Image.Image]):
        ori_w, ori_h = video_frames[0].size
        video_frames = resize_images_to_size(video_frames, size=self.resolution)

        boxes = [[self.boundary_remove, self.boundary_remove, video_frames[0].size[0] - self.boundary_remove, video_frames[0].size[1] - self.boundary_remove]]

        representative_points = [torch.from_numpy(sample_grid_points(box, int(self.n_points / len(boxes)))).to(self.device) for box in boxes]
        representative_points = torch.cat(representative_points, dim=0)
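        # Prepend a zero time index so each query point is tracked starting from the first frame.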
        representative_points = torch.cat([torch.zeros_like(representative_points[..., :1]), representative_points], dim=-1)
        frames_np = [np.array(frame) for frame in video_frames]

        get_trackers = self.inference(np.array(frames_np), ori_w, ori_h, representative_points[None])

        return get_trackers

    @torch.no_grad()
    def inference(self, frames: np.ndarray, w_ori, h_ori, tracks) -> np.ndarray:
        video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(self.device)  # B T C H W
        _, _, _, H, W = video.shape

        tracks = tracks.float()

        # Run inference; the model returns the predicted tracks and per-point visibility flags.
        tracks, visibles = self.model(video, tracks)
        tracks = convert_grid_coordinates(tracks, (W, H), (w_ori, h_ori),)

        return torch.cat([tracks, visibles.unsqueeze(-1)], dim=-1).cpu().numpy()


def convert_grid_coordinates(
    coords: torch.Tensor,
    input_grid_size: Sequence[int],
    output_grid_size: Sequence[int],
    coordinate_format: str = 'xy',
) -> torch.Tensor:
    """
    Convert image coordinates between image grids of different sizes using PyTorch.

    By default, the function assumes that the image corners are aligned.
    It scales the coordinates from the input grid to the output grid by multiplying
    by the size ratio.

    Args:
        coords (torch.Tensor): The coordinates to be converted.
            For 'xy', the tensor should have shape [..., 2].
            For 'tyx', the tensor should have shape [..., 3].
        input_grid_size (Sequence[int]): The size of the current grid.
            For 'xy', it should be [width, height].
            For 'tyx', it should be [num_frames, height, width].
        output_grid_size (Sequence[int]): The size of the target grid.
            For 'xy', it should be [width, height].
            For 'tyx', it should be [num_frames, height, width].
        coordinate_format (str): Either 'xy' (default) or 'tyx'.

    Returns:
        torch.Tensor: The transformed coordinates with the same shape as `coords`.

    Raises:
        ValueError: If grid sizes don't match the expected lengths for the given coordinate format,
                    or if frame counts (for 'tyx') differ.
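
    Example:
        >>> convert_grid_coordinates(torch.tensor([[360.0, 240.0]]), (720, 480), (1280, 720))
        tensor([[640., 360.]])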
    """
    # Convert grid sizes to torch tensors with the same dtype and device as coords.
    if isinstance(input_grid_size, (tuple, list)):
        input_grid_size = torch.tensor(input_grid_size, dtype=coords.dtype, device=coords.device)
    if isinstance(output_grid_size, (tuple, list)):
        output_grid_size = torch.tensor(output_grid_size, dtype=coords.dtype, device=coords.device)
    
    # Validate the grid sizes based on coordinate_format.
    if coordinate_format == 'xy':
        if input_grid_size.numel() != 2 or output_grid_size.numel() != 2:
            raise ValueError("For 'xy' format, grid sizes must have 2 elements.")
    elif coordinate_format == 'tyx':
        if input_grid_size.numel() != 3 or output_grid_size.numel() != 3:
            raise ValueError("For 'tyx' format, grid sizes must have 3 elements.")
        if input_grid_size[0] != output_grid_size[0]:
            raise ValueError("Converting frame count is not supported.")
    else:
        raise ValueError("Recognized coordinate formats are 'xy' and 'tyx'.")
    
    # Compute the transformed coordinates.
    # Broadcasting will apply elementwise division and multiplication.
    transformed_coords = coords * (output_grid_size / input_grid_size)
    
    return transformed_coords


def save_frames_to_mp4(frames, output_path, fps=24, codec='mp4v'):
    """
    Save a list of PIL.Image frames as an MP4 video.

    Args:
        frames (List[PIL.Image.Image]): List of PIL Image frames.
        output_path (str): Path to the output .mp4 file.
        fps (int, optional): Frames per second. Defaults to 24.
        codec (str, optional): FourCC codec code (e.g., 'mp4v', 'H264'). Defaults to 'mp4v'.

    Raises:
        ValueError: If `frames` is empty.
    """
    if not frames:
        raise ValueError("No frames to save.")

    # Ensure all frames are the same size
    width, height = frames[0].size

    # Prepare video writer
    fourcc = cv2.VideoWriter_fourcc(*codec)
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for img in frames:
        # Resize if needed
        if img.size != (width, height):
            img = img.resize((width, height), Image.LANCZOS)

        # Convert PIL Image (RGB) to BGR array for OpenCV
        frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        writer.write(frame)

    writer.release()


def save_yaml(
    data: Any,
    filename: str,
    *,
    default_flow_style: bool = False,
    sort_keys: bool = False
) -> None:
    """
    Save a Python object to a YAML file. 

    If the file already exists, appends the data as a new YAML document
    (with a leading '---' separator). Otherwise creates a fresh file.

    Args:
        data: The Python object (e.g., dict, list) to serialize.
        filename: Path to the output .yaml file.
        default_flow_style: If False (the default), uses block style.
        sort_keys: If True, sorts dictionary keys in the output.
    """
    import os  # ensure os is available even if not imported at module top
    # append if the file already exists so each call adds a new YAML document
    mode = 'a' if os.path.exists(filename) else 'w'
    with open(filename, mode, encoding='utf-8') as f:
        if mode == 'a':
            # separate from prior content and start a new document
            f.write('\n')
        yaml.safe_dump(
            data,
            f,
            default_flow_style=default_flow_style,
            sort_keys=sort_keys,
            allow_unicode=True,
            explicit_start=True
        )
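
# Illustrative example (editor's note, not part of the original file):
#
#   save_yaml([{'track': 'tracks/0.pth', 'text': '', 'image': 'images/0.png'}], 'test.yaml')
#
# writes a block-style document with an explicit '---' start marker:
#
#   ---
#   - track: tracks/0.pth
#     text: ''
#     image: images/0.png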


if __name__ == "__main__":
    import os
    import argparse

    parser = argparse.ArgumentParser(
        description="V2V motion transfer."
    )
    parser.add_argument("--source_folder", help="Input path to video files", type=str)
    parser.add_argument("--save_folder", help="Output path", type=str)
    parser.add_argument("--num_points", help="Number of tracking points", default=40, type=int)
    args = parser.parse_args()

    n_points = args.num_points
    source_video_folder = args.source_folder
    save_loc = args.save_folder

    os.makedirs(os.path.join(save_loc, 'tracks'), exist_ok=True)
    os.makedirs(os.path.join(save_loc, 'videos'), exist_ok=True)
    os.makedirs(os.path.join(save_loc, 'images'), exist_ok=True)

    model_ = TrackAnyPoint(n_points=n_points)

    t_ll = 121

    kk = 0
    out_list = []

    for fl in os.listdir(source_video_folder):
        frames = load_video_to_frames(os.path.join(source_video_folder, fl))

        frames = frames + [frames[-1]]

        f_len = len(frames)

        print('Processing:', fl)

        # Only the first clip of t_ll (121) frames is used from each source video.
        for ttt in range(f_len // t_ll):
            if ttt > 0:
                continue
            images = frames[ttt * t_ll:(1 + ttt) * t_ll]

            save_frames_to_mp4(images, os.path.join(save_loc, 'videos', f'{kk}.mp4'))

            image = np.array(images[0])
            images[0].save(os.path.join(save_loc, 'images', f'{kk}.png'))

            caption = ''
            tracks = model_(images)
            tracks = np.transpose(tracks, (2, 1, 0, 3))
            tracks_bytes = array_to_npz_bytes(tracks, os.path.join(save_loc, 'tracks', f'{kk}.pth'), compressed=True)

            out_list.append(
                {
                    'track': os.path.join(save_loc, 'tracks', f'{kk}.pth'),
                    'text': caption,
                    'image': os.path.join(save_loc, 'images', f'{kk}.png'),
                }
            )
            kk += 1

    save_yaml(out_list, os.path.join(save_loc, 'test.yaml'))


================================================
FILE: tools/plot_user_inputs.py
================================================
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image, ImageDraw
import numpy as np
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
import io
import yaml, argparse, os
import math


def plot_tracks(
    img: Image.Image,
    tracks: np.ndarray,
    line_width: int = 10,
    dot_radius: int = 10,
    arrow_length: int = 25,
    arrow_angle_deg: float = 30.0
) -> Image.Image:
    """
    Plot trajectories on an image, with a dot at the start and an arrow whose center
    aligns with the last visible trajectory point.

    Args:
        img: A PIL Image.
        tracks: Array of shape (N, T, 1, 3) or (N, T, 1, 4); the first two
            channels are (x, y) and the last channel is visibility.
        line_width: Thickness of trajectory lines.
        dot_radius: Radius of the start dot.
        arrow_length: Length of each arrowhead side.
        arrow_angle_deg: Angle between shaft and arrowhead sides (degrees).
    """
    canvas = img.convert("RGB")
    draw = ImageDraw.Draw(canvas)

    N, T, _, _ = tracks.shape
    arrow_angle = math.radians(arrow_angle_deg)

    for i in range(N):
        traj = tracks[i, :, 0, :]
        if traj.shape[-1] == 4:
            traj = np.concatenate([traj[..., :2], traj[..., -1:]], axis=-1)
        # Draw segments
        for t in range(T - 1):
            x1, y1, v1 = traj[t]
            x2, y2, v2 = traj[t + 1]
            if v1 == 0 or v2 == 0:
                continue
            ratio = t / (T - 1)
            color = (int(255 * ratio), int(255 * (1 - ratio)), 30)
            draw.line([(int(x1), int(y1)), (int(x2), int(y2))],
                      fill=color, width=line_width)

        # Visible indices
        visible = [t for t in range(T) if traj[t, 2] == 1]
        if not visible:
            continue

        # Start dot
        t0 = visible[0]
        x0, y0, _ = traj[t0]
        draw.ellipse([
            (int(x0 - dot_radius), int(y0 - dot_radius)),
            (int(x0 + dot_radius), int(y0 + dot_radius))
        ], fill=(0, 255, 30))

        # Arrow at end
        t_last = visible[-1]
        ratio_last = t_last / (T - 1)
        arrow_color = (int(255 * ratio_last), int(255 * (1 - ratio_last)), 30)

        # Direction: average of the last two segments if available
        if len(visible) >= 3:
            t2, t1, tL = visible[-3], visible[-2], visible[-1]
            x2, y2, _ = traj[t2]
            x1, y1, _ = traj[t1]
            xL, yL, _ = traj[tL]
            v1 = (x1 - x2, y1 - y2)
            v2 = (xL - x1, yL - y1)
            dx, dy = (v1[0] + v2[0]) / 2, (v1[1] + v2[1]) / 2
        elif len(visible) >= 2:
            x1, y1, _ = traj[visible[-2]]
            xL, yL, _ = traj[t_last]
            dx, dy = xL - x1, yL - y1
        else:
            # A single visible point gives no direction to orient an arrowhead.
            continue

        dist = math.hypot(dx, dy)
        if dist < 1e-3:
            continue
        ux, uy = dx / dist, dy / dist

        # Arrowhead points
        def rotate(vx, vy, ang):
            return vx * math.cos(ang) - vy * math.sin(ang), vx * math.sin(ang) + vy * math.cos(ang)

        vx1, vy1 = rotate(ux, uy,  arrow_angle)
        vx2, vy2 = rotate(ux, uy, -arrow_angle)
        p1 = (xL - vx1 * arrow_length, yL - vy1 * arrow_length)
        p2 = (xL - vx2 * arrow_length, yL - vy2 * arrow_length)

        # Compute translation to center triangle on (xL, yL)
        cx = (xL + p1[0] + p2[0]) / 3
        cy = (yL + p1[1] + p2[1]) / 3
        dx_c, dy_c = xL - cx, yL - cy

        tip = (xL + dx_c, yL + dy_c)
        p1_c = (p1[0] + dx_c, p1[1] + dy_c)
        p2_c = (p2[0] + dx_c, p2[1] + dy_c)

        draw.polygon([tip, p1_c, p2_c], fill=arrow_color)

    return canvas
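
# Illustrative usage (editor's sketch, not part of the original file): a single
# fully visible diagonal track of 20 points drawn on a blank canvas.
#
#   T = 20
#   xy = np.stack([np.linspace(10, 200, T), np.linspace(10, 120, T)], axis=-1)
#   tracks = np.concatenate([xy, np.ones((T, 1))], axis=-1).reshape(1, T, 1, 3)
#   plot_tracks(Image.new('RGB', (256, 160), 'white'), tracks).save('demo.jpg')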


def unzip_to_array(
    data: bytes, key: Union[str, List[str]] = "array"
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    bytes_io = io.BytesIO(data)

    if isinstance(key, str):
        # Load the NPZ data from the BytesIO object
        with np.load(bytes_io) as data:
            return data[key]
    else:
        get = {}
        with np.load(bytes_io) as data:
            for k in key:
                get[k] = data[k]
        return get


def main():
    parser = argparse.ArgumentParser(description="Plot trajectories on images")
    parser.add_argument("base_file", help="Path to YAML file describing images and tracks")
    parser.add_argument("--save_dir", default='', type=str, help="Path save images")
    args = parser.parse_args()

    # Load YAML list of dicts
    with open(args.base_file, 'r') as f:
        items = yaml.safe_load(f)  # List[Dict]

    for ii, item in enumerate(items):
        image_path = item["image"]
        track_path = item["track"]

        # Load image and tracks
        img = Image.open(image_path)
        raw_tracks = torch.load(track_path)
        tracks = unzip_to_array(raw_tracks) / 8

        # Plot trajectories
        try:
            out_img = plot_tracks(img, tracks)
        except Exception as e:
            print(f"Error plotting tracks for {image_path}: {e}")
            continue
        
        if not args.save_dir:
            # Determine output path
            out_path = image_path.replace("/images/", "/images_track_input/")
        else:
            out_path = os.path.join(args.save_dir, f'{ii:02d}.jpg')
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        # Save output image
        out_img.save(out_path)
        print(f"Saved plotted image to {out_path}")


if __name__ == "__main__":
    main()


================================================
FILE: tools/trajectory_editor/app.py
================================================
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib.pyplot as plt
from flask import Flask, request, jsonify, render_template
import os
import io
import numpy as np
import torch
import yaml
import matplotlib
import argparse
matplotlib.use('Agg')

# Warning: app.py should only be run on localhost; running it on a remote server may pose security issues.
app = Flask(__name__, static_folder='static', template_folder='templates')


# ——— Arguments ———————————————————————————————————
parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, default='videos_example')
args = parser.parse_args()


# ——— Configuration —————————————————————————————
BASE_DIR = args.save_dir
STATIC_BASE = os.path.join('static', BASE_DIR)
IMAGES_DIR = os.path.join(STATIC_BASE, 'images')
OVERLAY_DIR = os.path.join(STATIC_BASE, 'images_tracks')
TRACKS_DIR = os.path.join(BASE_DIR, 'tracks')
YAML_PATH = os.path.join(BASE_DIR, 'test.yaml')
IMAGES_DIR_OUT = os.path.join(BASE_DIR, 'images')

FIXED_LENGTH = 121
COLOR_CYCLE = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
QUANT_MULTI = 8

for d in (IMAGES_DIR, TRACKS_DIR, OVERLAY_DIR, IMAGES_DIR_OUT):
    os.makedirs(d, exist_ok=True)

# ——— Helpers ———————————————————————————————————————


def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI):
    # scale coordinates by quant_multi and store as float32 inside an NPZ payload
    arr_q = (quant_multi * arr).astype(np.float32)
    bio = io.BytesIO()
    if compressed:
        np.savez_compressed(bio, array=arr_q)
    else:
        np.savez(bio, array=arr_q)
    torch.save(bio.getvalue(), path)


def load_existing_tracks(path):
    raw = torch.load(path)
    bio = io.BytesIO(raw)
    with np.load(bio) as npz:
        return npz['array']
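
# Illustrative round trip (editor's note, not part of the original file): track files
# are raw NPZ bytes wrapped by torch.save, with coordinates scaled by QUANT_MULTI.
#
#   arr = np.zeros((1, FIXED_LENGTH, 1, 3), dtype=np.float32)
#   array_to_npz_bytes(arr, '/tmp/demo.pth', compressed=True)
#   restored = load_existing_tracks('/tmp/demo.pth')   # shape (1, 121, 1, 3)
#   # divide by QUANT_MULTI to recover pixel coordinates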

# ——— Routes ———————————————————————————————————————


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/upload_image', methods=['POST'])
def upload_image():
    f = request.files['image']
    from PIL import Image
    img = Image.open(f.stream)
    orig_w, orig_h = img.size

    idx = len(os.listdir(IMAGES_DIR)) + 1
    ext = f.filename.rsplit('.', 1)[-1]
    fname = f"{idx:02d}.{ext}"
    img.save(os.path.join(IMAGES_DIR, fname))
    img.save(os.path.join(IMAGES_DIR_OUT, fname))

    return jsonify({
        'image_url': f"{STATIC_BASE}/images/{fname}",
        'image_id': idx,
        'ext': ext,
        'orig_width': orig_w,
        'orig_height': orig_h
    })


@app.route('/store_tracks', methods=['POST'])
def store_tracks():
    data = request.get_json()
    image_id = data['image_id']
    ext = data['ext']
    free_tracks = data.get('tracks', [])
    circ_trajs = data.get('circle_trajectories', [])

    # Debug lengths
    for i, tr in enumerate(free_tracks, 1):
        print(f"Freehand Track {i}: {len(tr)} points")
    for i, tr in enumerate(circ_trajs, 1):
        print(f"Circle/Static Traj {i}: {len(tr)} points")

    def pad_pts(tr):
        """Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
        pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
        n = pts.shape[0]
        if n < FIXED_LENGTH:
            pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
            pts = np.vstack((pts, pad))
        else:
            pts = pts[:FIXED_LENGTH]
        return pts.reshape(FIXED_LENGTH, 1, 3)

    arrs = []

    # 1) Freehand tracks
    for tr in free_tracks:
        arrs.append(pad_pts(tr))

    # 2) Circle + Static combined
    for tr in circ_trajs:
        arrs.append(pad_pts(tr))

    # Nothing to save?
    if not arrs:
        overlay_file = f"{image_id:02d}.png"
        return jsonify({
            'status': 'ok',
            'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
        })

    new_tracks = np.stack(arrs, axis=0)  # (T_new, FIXED_LENGTH, 1, 3)

    # Load existing .pth and align channel counts before concatenating
    track_path = os.path.join(TRACKS_DIR, f"{image_id:02d}.pth")
    if os.path.exists(track_path):
        # shape (T_old, FIXED_LENGTH, 1, 3) or (..., 4)
        old = load_existing_tracks(track_path)
        if old.ndim == 4 and old.shape[-1] == 4:
            # legacy 4-channel files: pad the new (x, y, visibility) tracks with a zero channel
            pad = np.zeros(new_tracks.shape[:-1] + (1,), dtype=np.float32)
            new_tracks = np.concatenate((new_tracks, pad), axis=-1)
        all_tracks = np.concatenate([old, new_tracks], axis=0)
    else:
        all_tracks = new_tracks

    # Save updated track file
    array_to_npz_bytes(all_tracks, track_path, compressed=True)

    # Build overlay PNG
    img_path = os.path.join(IMAGES_DIR, f"{image_id:02d}.{ext}")
    img = plt.imread(img_path)
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(img)
    for t in all_tracks:
        coords = t[:, 0, :]  # (FIXED_LENGTH, 3) or (FIXED_LENGTH, 4)
        ax.plot(coords[:, 0][coords[:, 2] > 0.5], coords[:, 1]
                [coords[:, 2] > 0.5], marker='o', color=COLOR_CYCLE[0])
    ax.axis('off')
    overlay_file = f"{image_id:02d}.png"
    fig.savefig(os.path.join(OVERLAY_DIR, overlay_file),
                bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    # Update YAML (unchanged)
    entry = {
        "image": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/images/{image_id:02d}.{ext}"),
        "text": None,
        "track": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/tracks/{image_id:02d}.pth")
    }
    if os.path.exists(YAML_PATH):
        with open(YAML_PATH) as yf:
            docs = yaml.safe_load(yf) or []
    else:
        docs = []

    for e in docs:
        if e.get("image", "").endswith(f"{image_id:02d}.{ext}"):
            e.update(entry)
            break
    else:
        docs.append(entry)

    with open(YAML_PATH, 'w') as yf:
        yaml.dump(docs, yf, default_flow_style=False)

    return jsonify({
        'status': 'ok',
        'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}"
    })


def ensure_localhost():
    """
    Verify that the application is running on localhost.
    This inspects the host's IP addresses and exits if any non-loopback
    interface is found. If the hostname cannot be resolved, the check is skipped.
    """
    import sys
    import socket
    try:
        addresses = {info[4][0]
                     for info in socket.getaddrinfo(socket.gethostname(), None)}
    except socket.gaierror:
        # Hostname not resolvable—skip this check
        return

    for addr in addresses:
        if addr not in ("127.0.0.1", "::1"):
            sys.exit(
                "SecurityError: The application must run on localhost (127.0.0.1); "
                "other network interfaces pose security risks."
            )


if __name__ == '__main__':
    ensure_localhost()
    app.run(host="127.0.0.1", port=5000)


================================================
FILE: tools/trajectory_editor/templates/index.html
================================================
<!-- Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>Track Point Editor</title>
  <style>
    .btn-row {
      display: flex;
      align-items: center;
      margin: 8px 0;
    }
    .btn-row > * { margin-right: 12px; }
    body { font-family: sans-serif; margin: 16px; }
    #topControls, #bottomControls { margin-bottom: 12px; }
    button, input, select, label { margin: 4px; }
    #canvas { border:1px solid #ccc; display: block; margin: auto; cursor: crosshair; }
    #trajProgress { width: 200px; height: 16px; margin-left:12px; }
  </style>
</head>
<body>
  <h2>Track Point Editor</h2>

  <!-- Top controls -->
  <div id="topControls" class="btn-row">
    <input type="file" id="fileInput" accept="image/*">
    <button id="storeBtn">Store Tracks</button>
  </div>

  <!-- Main drawing canvas -->
  <canvas id="canvas"></canvas>

  <!-- Track controls -->
  <div id="bottomControls">
    <div class="btn-row">
      <button id="addTrackBtn">Add Freehand Track</button>
      <button id="deleteLastBtn">Delete Last Track</button>
      <progress id="trajProgress" max="121" value="0" style="display:none;"></progress>
    </div>
    <div class="btn-row">
      <button id="placeCircleBtn">Place Circle</button>
      <button id="addCirclePointBtn">Add Circle Point</button>
      <label>Radius:
        <input type="range" id="radiusSlider" min="10" max="800" value="50" style="display:none;">
      </label>
    </div>
    <div class="btn-row">
      <button id="addStaticBtn">Add Static Point</button>
      <label>Static Frames:
        <input type="number" id="staticFramesInput" value="121" min="1" style="width:60px">
      </label>
    </div>
    <div class="btn-row">
      <select id="trackSelect" style="min-width:160px;"></select>
      <div id="colorIndicator"
            style="
              width:16px;
              height:16px;
              border:1px solid #444;
              display:inline-block;
              vertical-align:middle;
              margin-left:8px;
              pointer-events:none;
              visibility:hidden;
            ">
      </div>
      <button id="deleteTrackBtn">Delete Selected</button>
      <button id="editTrackBtn">Edit Track</button>
      <button id="duplicateTrackBtn">Duplicate Track</button>
    </div>
    <!-- Global motion offset -->
    <div class="btn-row">
      <label>Motion X (px/frame):
        <input type="number" id="motionXInput" value="0" style="width:60px">
      </label>
      <label>Motion Y (px/frame):
        <input type="number" id="motionYInput" value="0" style="width:60px">
      </label>
      <button id="applySelectedMotionBtn">Add to Selected</button>
      <button id="applyAllMotionBtn">Add to All</button>
    </div>
  </div>

  
  <script>
  // ——— DOM refs —————————————————————————————————————————
  const canvas            = document.getElementById('canvas'),
        ctx               = canvas.getContext('2d'),
        fileIn            = document.getElementById('fileInput'),
        storeBtn          = document.getElementById('storeBtn'),
        addTrackBtn       = document.getElementById('addTrackBtn'),
        deleteLastBtn     = document.getElementById('deleteLastBtn'),
        placeCircleBtn    = document.getElementById('placeCircleBtn'),
        addCirclePointBtn = document.getElementById('addCirclePointBtn'),
        addStaticBtn      = document.getElementById('addStaticBtn'),
        staticFramesInput = document.getElementById('staticFramesInput'),
        radiusSlider      = document.getElementById('radiusSlider'),
        trackSelect       = document.getElementById('trackSelect'),
        deleteTrackBtn    = document.getElementById('deleteTrackBtn'),
        editTrackBtn      = document.getElementById('editTrackBtn'),
        duplicateTrackBtn = document.getElementById('duplicateTrackBtn'),
        trajProg          = document.getElementById('trajProgress'),
        colorIndicator    = document.getElementById('colorIndicator'),
        motionXInput            = document.getElementById('motionXInput'),
        motionYInput            = document.getElementById('motionYInput'),
        applySelectedMotionBtn  = document.getElementById('applySelectedMotionBtn'),
        applyAllMotionBtn       = document.getElementById('applyAllMotionBtn');

  let img, image_id, ext, origW, origH,
      scaleX=1, scaleY=1;

  // track data
  let free_tracks = [], current_track = [], drawing=false, motionCounter=0;
  let circle=null, static_trajs=[];
  let mode='', selectedTrack=null, editMode=false, editInfo=null, duplicateBuffer=null;
  const COLORS=['red','green','blue','cyan','magenta','yellow','black'],
        FIXED_LENGTH=121,
        // Gaussian falloff for track editing: the drag weight drops to 0.5 five points away
        editSigma = 5/Math.sqrt(2*Math.log(2));

  // ——— Upload & scale image ————————————————————————————
  fileIn.addEventListener('change', async e => {
    const f = e.target.files[0]; if (!f) return;
    const fd = new FormData(); fd.append('image',f);
    const res = await fetch('/upload_image',{method:'POST',body:fd});
    const js = await res.json();
    image_id=js.image_id; ext=js.ext;
    origW=js.orig_width; origH=js.orig_height;
    if(origW>=origH){
      canvas.width=800; canvas.height=Math.round(origH*800/origW);
    } else {
      canvas.height=800; canvas.width=Math.round(origW*800/origH);
    }
    scaleX=origW/canvas.width; scaleY=origH/canvas.height;
    img=new Image(); img.src=js.image_url;
    img.onload=()=>{
      free_tracks=[]; current_track=[];
      circle=null; static_trajs=[];
      mode=selectedTrack=''; editMode=false; editInfo=null; duplicateBuffer=null;
      trajProg.style.display='none';
      radiusSlider.style.display='none';
      trackSelect.innerHTML='';
      redraw();
    };
  });

  // ——— Store tracks + depth —————————————————————————
  storeBtn.onclick = async () => {
    if(!image_id) return alert('Load an image first');
    const fh = free_tracks.map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY}))),
          ct = (circle?.trajectories||[]).map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY}))),
          st = static_trajs.map(tr=>tr.map(p=>({x:p.x*scaleX,y:p.y*scaleY})));
    const payload = {
      image_id, ext,
      tracks: fh,
      circle_trajectories: ct.concat(st)
    };
    const res = await fetch('/store_tracks',{
      method:'POST',
      headers:{'Content-Type':'application/json'},
      body: JSON.stringify(payload)
    });
    const js = await res.json();
    img.src=js.overlay_url;
    img.onload=()=>ctx.drawImage(img,0,0,canvas.width,canvas.height);

    // reset UI
    free_tracks=[]; circle=null; static_trajs=[];
    mode=selectedTrack=''; editMode=false; editInfo=null; duplicateBuffer=null;
    trajProg.style.display='none';
    radiusSlider.style.display='none';
    trackSelect.innerHTML='';
    redraw();
  };

  // ——— Control buttons —————————————————————————————
  addTrackBtn.onclick = ()=>{
    mode='free'; drawing=true; current_track=[]; motionCounter=0;
    trajProg.max=FIXED_LENGTH; trajProg.value=0;
    trajProg.style.display='inline-block';
  };
  deleteLastBtn.onclick = ()=>{
    if(drawing){
      drawing=false; current_track=[]; trajProg.style.display='none';
    } else if(free_tracks.length){
      free_tracks.pop(); updateTrackSelect(); redraw();
    }
    updateColorIndicator();
  };
  placeCircleBtn.onclick    = ()=>{ mode='placeCircle'; drawing=false; };
  addCirclePointBtn.onclick = ()=>{ if(!circle) alert('Place circle first'); else mode='addCirclePt'; };
  addStaticBtn.onclick      = ()=>{ mode='placeStatic'; };
  duplicateTrackBtn.onclick = ()=>{
    if(!selectedTrack) return alert('Select a track first');
    const arr = selectedTrack.type==='free'
              ? free_tracks[selectedTrack.idx]
              : selectedTrack.type==='circle'
                ? circle.trajectories[selectedTrack.idx]
                : static_trajs[selectedTrack.idx];
    duplicateBuffer = arr.map(p=>({x:p.x,y:p.y}));
    mode='duplicate'; canvas.style.cursor='copy';
  };

  radiusSlider.oninput = ()=>{
    if(!circle) return;
    circle.radius = +radiusSlider.value;
    circle.trajectories.forEach((traj,i)=>{
      const θ = circle.angles[i];
      traj.push({
        x: circle.cx + Math.cos(θ)*circle.radius,
        y: circle.cy + Math.sin(θ)*circle.radius
      });
    });
    if(selectedTrack?.type==='circle')
      trajProg.value = circle.trajectories[selectedTrack.idx].length;
    redraw();
  };

  deleteTrackBtn.onclick = ()=>{
    if(!selectedTrack) return;
    const {type,idx} = selectedTrack;
    if(type==='free')    free_tracks.splice(idx,1);
    else if(type==='circle'){
      circle.trajectories.splice(idx,1);
      circle.angles.splice(idx,1);
    } else {
      static_trajs.splice(idx,1);
    }
    selectedTrack=null;
    trajProg.style.display='none';
    updateTrackSelect();
    redraw();
    updateColorIndicator();
  };

  editTrackBtn.onclick = ()=>{
    if(!selectedTrack) return alert('Select a track first');
    editMode=!editMode;
    editTrackBtn.textContent = editMode?'Stop Editing':'Edit Track';
  };

  // ——— Track select & depth init —————————————————————
  function updateTrackSelect(){
    trackSelect.innerHTML='';
    free_tracks.forEach((_,i)=>{
      const o=document.createElement('option');
      o.value=JSON.stringify({type:'free',idx:i});
      o.textContent=`Point ${i+1}`;
      trackSelect.appendChild(o);
    });
    if(circle){
      circle.trajectories.forEach((_,i)=>{
        const o=document.createElement('option');
        o.value=JSON.stringify({type:'circle',idx:i});
        o.textContent=`CirclePt ${i+1}`;
        trackSelect.appendChild(o);
      });
    }
    static_trajs.forEach((_,i)=>{
      const o=document.createElement('option');
      o.value=JSON.stringify({type:'static',idx:i});
      o.textContent=`StaticPt ${i+1}`;
      trackSelect.appendChild(o);
    });
    if(trackSelect.options.length){
      trackSelect.selectedIndex=0;
      trackSelect.onchange();
    }
    updateColorIndicator();
  }

  function applyMotionToTrajectory(traj, dx, dy) {
    traj.forEach((pt, frameIdx) => {
      pt.x += dx * frameIdx;
      pt.y += dy * frameIdx;
    });
  }

  applySelectedMotionBtn.onclick = () => {
  if (!selectedTrack) {
    return alert('Please select a track first');
  }
  const dx = parseFloat(motionXInput.value) || 0;
  const dy = parseFloat(motionYInput.value) || 0;

  // pick the underlying array
  let arr = null;
  if (selectedTrack.type === 'free') {
    arr = free_tracks[selectedTrack.idx];
  } else if (selectedTrack.type === 'circle') {
    arr = circle.trajectories[selectedTrack.idx];
  } else { // 'static'
    arr = static_trajs[selectedTrack.idx];
  }

  applyMotionToTrajectory(arr, dx, dy);
  redraw();
};

// 2) Add motion to every track on the canvas
applyAllMotionBtn.onclick = () => {
  const dx = parseFloat(motionXInput.value) || 0;
  const dy = parseFloat(motionYInput.value) || 0;

  // freehand tracks
  free_tracks.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
  // circle‑based tracks
  if (circle) {
    circle.trajectories.forEach(tr => applyMotionToTrajectory(tr, dx, dy));
  }
  // static points (now will move over frames)
  static_trajs.forEach(tr => applyMotionToTrajectory(tr, dx, dy));

  redraw();
};

  trackSelect.onchange = ()=>{
    if(!trackSelect.value){
      selectedTrack=null;
      trajProg.style.display='none';
      return;
    }
    selectedTrack = JSON.parse(trackSelect.value);

    if(selectedTrack.type==='circle'){
      trajProg.style.display='inline-block';
      trajProg.max=FIXED_LENGTH;
      trajProg.value=circle.trajectories[selectedTrack.idx].length;
    } else if(selectedTrack.type==='free'){
      trajProg.style.display='inline-block';
      trajProg.max=FIXED_LENGTH;
      trajProg.value=free_tracks[selectedTrack.idx].length;
    } else {
      trajProg.style.display='none';
    }
    updateColorIndicator();
  };

  // ——— Canvas drawing ————————————————————————————————
  canvas.addEventListener('mousedown', e=>{
    const r=canvas.getBoundingClientRect(),
          x=e.clientX-r.left, y=e.clientY-r.top;

    // place circle
    if(mode==='placeCircle'){
      circle={cx:x,cy:y,radius:50,angles:[],trajectories:[]};
      radiusSlider.max=Math.min(canvas.width,canvas.height)|0;
      radiusSlider.value=50; radiusSlider.style.display='inline';
      mode=''; updateTrackSelect(); redraw(); return;
    }
    // add circle point
    if(mode==='addCirclePt'){
      const dx=x-circle.cx, dy=y-circle.cy;
      const θ=Math.atan2(dy,dx);
      const px=circle.cx+Math.cos(θ)*circle.radius;
      const py=circle.cy+Math.sin(θ)*circle.radius;
      circle.angles.push(θ);
      circle.trajectories.push([{x:px,y:py}]);
      mode=''; updateTrackSelect(); redraw(); return;
    }
    // add static
    if (mode === 'placeStatic') {
      // how many frames to “hold” the point
      const len = parseInt(staticFramesInput.value, 10) || FIXED_LENGTH;
      // duplicate the click‐point len times
      const traj = Array.from({ length: len }, () => ({ x, y }));
      // push into free_tracks so it's drawn & edited just like any freehand curve
      free_tracks.push(traj);

      // reset state
      mode = '';
      updateTrackSelect();
      redraw();
      return;
    }
    // duplicate
    if(mode==='duplicate' && duplicateBuffer){
      const orig = duplicateBuffer;
      // click defines translation by first point
      const dx = x - orig[0].x, dy = y - orig[0].y;
      const newTr = orig.map(p=>({x:p.x+dx, y:p.y+dy}));
      free_tracks.push(newTr);
      mode=''; duplicateBuffer=null; canvas.style.cursor='crosshair';
      updateTrackSelect(); redraw(); return;
    }
    // editing
    if(editMode && selectedTrack){
      const arr = selectedTrack.type==='free'
                ? free_tracks[selectedTrack.idx]
                : selectedTrack.type==='circle'
                  ? circle.trajectories[selectedTrack.idx]
                  : static_trajs[selectedTrack.idx];
      let best=0,bd=Infinity;
      arr.forEach((p,i)=>{
        const d=(p.x-x)**2+(p.y-y)**2;
        if(d<bd){ bd=d; best=i; }
      });
      editInfo={ trackType:selectedTrack.type,
                 trackIdx:selectedTrack.idx,
                 ptIdx:best,
                 startX:x, startY:y };
      return;
    }
    // freehand start
    if(mode==='free'){
      drawing=true; motionCounter=0;
      current_track=[{x,y}];
      redraw();
    }
  });

  canvas.addEventListener('mousemove', e=>{
    const r=canvas.getBoundingClientRect(),
          x=e.clientX-r.left, y=e.clientY-r.top;
    // edit mode
    if(editMode && editInfo){
      const dx=x-editInfo.startX,
            dy=y-editInfo.startY;
      const {trackType,trackIdx,ptIdx} = editInfo;
      const arr = trackType==='free'
                ? free_tracks[trackIdx]
                : trackType==='circle'
                  ? circle.trajectories[trackIdx]
                  : static_trajs[trackIdx];
      arr.forEach((p,i)=>{
        const d=i-ptIdx;
        const w=Math.exp(-0.5*(d*d)/(editSigma*editSigma));
        p.x+=dx*w; p.y+=dy*w;
      });
      editInfo.startX=x; editInfo.startY=y;
      if(selectedTrack?.type==='circle')
        trajProg.value=circle.trajectories[selectedTrack.idx].length;
      redraw(); return;
    }
    // freehand draw
    if(drawing && (e.buttons&1)){
      motionCounter++;
      if(motionCounter%2===0){
        current_track.push({x,y});
        trajProg.value = Math.min(current_track.length, trajProg.max);
        redraw();
      }
    }
  });

  canvas.addEventListener('mouseup', ()=>{
    if(editMode && editInfo){ editInfo=null; return; }
    if(drawing){
      free_tracks.push(current_track.slice());
      drawing=false; current_track=[];
      updateTrackSelect(); redraw();
    }
  });

  function updateColorIndicator() {
    const idx = trackSelect.selectedIndex;
    if (idx < 0) {
      colorIndicator.style.visibility = 'hidden';
      return;
    }
    // Pick the color by index
    const col = COLORS[idx % COLORS.length];
    colorIndicator.style.backgroundColor = col;
    colorIndicator.style.visibility = 'visible';
  }
  
  // ——— redraw ———
  function redraw(){
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    if (img.complete) ctx.drawImage(img, 0, 0, canvas.width, canvas.height);

    // set a fatter line for all strokes
    ctx.lineWidth = 2;

    // — freehand (and static‑turned‑freehand) tracks —
    free_tracks.forEach((tr, i) => {
      const col = COLORS[i % COLORS.length];
      ctx.strokeStyle = col;
      ctx.fillStyle   = col;

      if (tr.length === 0) return;

      // check if every point equals the first
      const allSame = tr.every(p => p.x === tr[0].x && p.y === tr[0].y);

      if (allSame) {
        // draw a filled circle for a “static” dot
        ctx.beginPath();
        ctx.arc(tr[0].x, tr[0].y, 4, 0, 2 * Math.PI);
        ctx.fill();
      } else {
        // normal polyline
        ctx.beginPath();
        tr.forEach((p, j) =>
          j ? ctx.lineTo(p.x, p.y) : ctx.moveTo(p.x, p.y)
        );
        ctx.stroke();
      }
    });

    if(drawing && current_track.length){
      ctx.strokeStyle='black';
      ctx.beginPath();
      current_track.forEach((p,j)=>
        j? ctx.lineTo(p.x,p.y): ctx.moveTo(p.x,p.y));
      ctx.stroke();
    }

    // — circle trajectories —
    if (circle) {
      // circle outline
      ctx.strokeStyle = 'white';
      ctx.lineWidth   = 1;
      ctx.beginPath();
      ctx.arc(circle.cx, circle.cy, circle.radius, 0, 2 * Math.PI);
      ctx.stroke();

      circle.trajectories.forEach((tr, i) => {
        const col = COLORS[(free_tracks.length + i) % COLORS.length];
        ctx.strokeStyle = col;
        ctx.fillStyle   = col;
        ctx.lineWidth   = 2;

        if (tr.length <= 1) {
          // single‑point circle trajectory → dot
          ctx.beginPath();
          ctx.arc(tr[0].x, tr[0].y, 4, 0, 2 * Math.PI);
          ctx.fill();
        } else {
          // normal circle track
          ctx.beginPath();
          tr.forEach((p, j) =>
            j ? ctx.lineTo(p.x, p.y) : ctx.moveTo(p.x, p.y)
          );
          ctx.stroke();

          // white handle at last point
          const lp = tr[tr.length - 1];
          ctx.fillStyle = 'white';
          ctx.beginPath();
          ctx.arc(lp.x, lp.y, 4, 0, 2 * Math.PI);
          ctx.fill();
        }
      });
    }

    // — static_trajs (if you still use them separately) —
    static_trajs.forEach((tr, i) => {
      const p = tr[0];
      ctx.fillStyle = 'orange';
      ctx.beginPath();
      ctx.arc(p.x, p.y, 5, 0, 2 * Math.PI);
      ctx.fill();
    });
  }
  </script>
</body>
</html>


================================================
FILE: tools/visualize_trajectory.py
================================================
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import mediapy as media
import torch
import os
import tqdm
import argparse
import numpy as np
import yaml
import random
import colorsys
from typing import Dict, List, Tuple, Optional
import io
from typing import Union


def unzip_to_array(
    data: bytes, key: Union[str, List[str]] = "array"
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    bytes_io = io.BytesIO(data)

    if isinstance(key, str):
        # Load the NPZ data from the BytesIO object
        with np.load(bytes_io) as data:
            return data[key]
    else:
        get = {}
        with np.load(bytes_io) as data:
            for k in key:
                get[k] = data[k]
        return get


# Generate random colormaps for visualizing different points.
def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
  """Gets colormap for points."""
  colors = []
  for i in np.arange(0.0, 360.0, 360.0 / num_colors):
    hue = i / 360.0
    lightness = (50 + np.random.rand() * 10) / 100.0
    saturation = (90 + np.random.rand() * 10) / 100.0
    color = colorsys.hls_to_rgb(hue, lightness, saturation)
    colors.append(
        (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
    )
  random.shuffle(colors)
  return colors


def age_to_bgr(ratio: float) -> Tuple[int,int,int]:
    """
    Map ratio∈[0,1] through: 0→blue, 1/3→green, 2/3→yellow, 1→red.
    Returns an (R, G, B) tuple; the frames drawn on here are RGB (mediapy), not BGR.
    """
    if ratio <= 1/3:
        # blue→green
        t = ratio / (1/3)
        b = int(255 * (1 - t))
        g = int(255 * t)
        r = 0
    elif ratio <= 2/3:
        # green→yellow
        t = (ratio - 1/3) / (1/3)
        b = 0
        g = 255
        r = int(255 * t)
    else:
        # yellow→red
        t = (ratio - 2/3) / (1/3)
        b = 0
        g = int(255 * (1 - t))
        r = 255
    return (r, g, b)
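
# Illustrative values (editor's note, not part of the original file): the piecewise
# colormap runs blue -> green -> yellow -> red as ratio goes from 0 to 1.
#
#   age_to_bgr(0.0)      # (0, 0, 255)
#   age_to_bgr(1.0 / 3)  # (0, 255, 0)
#   age_to_bgr(2.0 / 3)  # (255, 255, 0)
#   age_to_bgr(1.0)      # (255, 0, 0)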


def paint_point_track(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    min_radius: int = 1,
    max_radius: int = 6,
    max_retain: int = 50
) -> np.ndarray:
    """
    Draws every past point of each track on each frame, with radius and color
    interpolated by the point's age (old→small to new→large).

    Args:
      frames:      [F, H, W, 3] uint8 RGB
      point_tracks:[N, F, 2] float32  – (x,y) in pixel coords
      visibles:    [N, F] bool        – visibility mask
      min_radius:  radius for the oldest retained point
      max_radius:  radius for the current point (newest)
      max_retain:  number of past frames of each track kept on screen

    Returns:
      video: [F, H, W, 3] uint8 RGB
    """
    num_points, num_frames = point_tracks.shape[:2]
    H, W = frames.shape[1:3]

    video = frames.copy()

    for t in range(num_frames):
        # start from the original frame
        frame = video[t].copy()

        for i in range(num_points):
            # draw every past step τ = 0..t
            for τ in range(t + 1):
                if not visibles[i, τ]:
                    continue

                if t - τ > max_retain:
                    continue

                # sub-pixel offset + clamp
                x, y = point_tracks[i, τ] + 0.5
                xi = int(np.clip(x, 0, W - 1))
                yi = int(np.clip(y, 0, H - 1))

                # age‐ratio in [0,1]
                if num_frames > 1:
                    ratio = 1 - float(t - τ) / max_retain
                else:
                    ratio = 1.0

                # interpolated radius
                radius = int(round(min_radius + (max_radius - min_radius) * ratio))

                # frames are RGB (mediapy), so the (R, G, B) color is used directly
                color_rgb = age_to_bgr(ratio)

                # filled circle
                cv2.circle(frame, (xi, yi), radius, color_rgb, thickness=-1)

        video[t] = frame

    return video
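
# Illustrative shapes (editor's note, not part of the original file), e.g. for a
# 121-frame clip with 40 tracked points:
#
#   frames:       (121, H, W, 3) uint8 RGB
#   point_tracks: (40, 121, 2)   float32 pixel (x, y)
#   visibles:     (40, 121)      bool
#   video = paint_point_track(frames, point_tracks, visibles)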


parser = argparse.ArgumentParser(
    description="Visualize tracks."
)
parser.add_argument(
    "--base_dir",
    type=str,
    default='samples',
)
parser.add_argument(
    "--video_dir",
    type=str,
    default="outputs",
)
parser.add_argument(
    "--track_dir",
    type=str,
    default="tracks",
)
parser.add_argument(
    "--output_appendix",
    type=str,
    default="_vis",
)

args = parser.parse_args()

base_dir = args.base_dir
video_dir = os.path.join(base_dir, args.video_dir)
track_dir = os.path.join(base_dir, args.track_dir)
os.makedirs(video_dir + args.output_appendix, exist_ok=True)

# Descend while the output directory contains exactly one subdirectory (e.g. a single run folder).
subdirs = [t for t in os.listdir(video_dir) if os.path.isdir(os.path.join(video_dir, t))]
print(subdirs)
while len(subdirs) == 1:
    video_dir = os.path.join(video_dir, subdirs[0])
    subdirs = [t for t in os.listdir(video_dir) if os.path.isdir(os.path.join(video_dir, t))]
print("Source:", video_dir)

shift_ = 3

records = yaml.safe_load(open(os.path.join(base_dir, 'test.yaml'), 'r'))

for video_name in tqdm.tqdm(os.listdir(video_dir)):
    if '.mp4' not in video_name:
        continue
    nn = os.path.basename(video_name)
    nn = int(nn.split('.')[0] if '_' not in nn else nn.split('_')[0])

    video = media.read_video(os.path.join(video_dir, video_name))

    short_edge = min(*video.shape[1:3])
    H, W = video.shape[1:3]

    track = torch.load(records[nn]['track'])
    if isinstance(track, bytes):
        track = unzip_to_array(track)
        track = np.repeat(track, 2, axis=1)[:, ::3]
        points = track[:, :, 0, :2].astype(np.float32) / 8
        visibles = track[:, :, 0, 2].astype(np.float32) / 8

        # image_origin = os.path.join(base_dir, 'images', f'{nn:02d}.png')
        image_origin = records[nn]['image']
        image = media.read_image(image_origin)

        H_ori, W_ori, _ = image.shape
        points = points / np.array([W_ori, H_ori]) * np.array([W, H])

    else:
        points = (track[shift_:, :, :2] + track[shift_:, :, 2:4]) / 2 * short_edge + torch.tensor([W / 2, H / 2])
        visibles = track[shift_:, :, -1]

        points = torch.permute(points, (1, 0, 2)).cpu().numpy()
        visibles = torch.permute(visibles, (1, 0)).cpu().numpy()

    video_viz = paint_point_track(video, points, visibles)
    name_ = os.path.basename(video_name).split('.')[0]
    media.write_video(os.path.join(base_dir, args.video_dir + args.output_appendix, f'{name_}_viz.mp4'), video_viz, fps=16)

================================================
FILE: wan/__init__.py
================================================
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanATI


================================================
FILE: wan/configs/__init__.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B

# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'

# the config of flf2v_14B is the same as i2v_14B
flf2v_14B = copy.deepcopy(i2v_14B)
flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt

WAN_CONFIGS = {
    'ati-14B': i2v_14B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '1024*1024': (1024, 1024),
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
}

SUPPORTED_SIZES = {
    'ati-14B': ('480*832', '832*480'),
}
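
# Illustrative usage (editor's note, not part of the original file): downstream
# scripts typically resolve the task configuration and resolution like this.
#
#   cfg = WAN_CONFIGS['ati-14B']        # EasyDict with dim=5120, num_layers=40, ...
#   size = SIZE_CONFIGS['480*832']      # (480, 832)
#   assert '480*832' in SUPPORTED_SIZES['ati-14B']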


================================================
FILE: wan/configs/shared_config.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
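# Default negative prompt (Chinese); roughly: garish colors, overexposed, static, blurry details,
# subtitles, style/artwork/painting artifacts, stillness, overall grey cast, worst/low quality,
# JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn hands/faces, deformed, disfigured,
# malformed limbs, fused fingers, motionless frame, cluttered background, three legs, crowded
# background, walking backwards.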
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'


================================================
FILE: wan/configs/wan_i2v_14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V 14B ------------------------#

i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt

i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)

# transformer
i2v_14B.patch_size = (1, 2, 2)
i2v_14B.dim = 5120
i2v_14B.ffn_dim = 13824
i2v_14B.freq_dim = 256
i2v_14B.num_heads = 40
i2v_14B.num_layers = 40
i2v_14B.window_size = (-1, -1)
i2v_14B.qk_norm = True
i2v_14B.cross_attn_norm = True
i2v_14B.eps = 1e-6


================================================
FILE: wan/configs/wan_t2v_14B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 14B ------------------------#

t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_14B.vae_stride = (4, 8, 8)

# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6


================================================
FILE: wan/configs/wan_t2v_1_3B.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 1.3B ------------------------#

t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_1_3B.vae_stride = (4, 8, 8)

# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6


================================================
FILE: wan/distributed/__init__.py
================================================


================================================
FILE: wan/distributed/fsdp.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
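
# Illustrative usage (editor's sketch, not part of the original file), mirroring how
# the pipelines build a shard function with functools.partial and apply it to a model
# whose transformer blocks live in `model.blocks`:
#
#   shard_fn = partial(shard_model, device_id=device_id)
#   model = shard_fn(model)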


def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()


================================================
FILE: wan/distributed/xdit_context_parallel.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

from ..modules.model import sinusoidal_embedding_1d


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x:          [B, L, N, C].
    grid_sizes: [B, 3].
    freqs:      [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
    # embeddings
    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
    c = [u.flatten(2).transpose(1, 2) for u in c]
    c = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in c
    ])

    # arguments
    new_kwargs = dict(x=x)
    new_kwargs.update(kwargs)

    # Context Parallel
    c = torch.chunk(
        c, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    hints = []
    for block in self.vace_blocks:
        c, c_skip = block(c, **new_kwargs)
        hints.append(c_skip)
    return hints


def usp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    vace_context=None,
    vace_context_scale=1.0,
    clip_fea=None,
    y=None,
):
    """
    x:              A list of videos each with shape [C, T, H, W].
    t:              [B].
    context:        A list of text embeddings each with shape [L, C].
    """
    if self.model_type == 'i2v':
        assert clip_fea is not None and y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if self.model_type != 'vace' and y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    if self.model_type != 'vace' and clip_fea is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    # Context Parallel
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    if self.model_type == 'vace':
        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
        kwargs['hints'] = hints
        kwargs['context_scale'] = vace_context_scale

    for block in self.blocks:
        x = block(x, **kwargs)

    # head
    x = self.head(x, e)

    # Context Parallel
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]


def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpadded q, k, v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x


================================================
FILE: wan/first_last_frame2video.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanFLF2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the first-last-frame-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`,  *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`,  *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 first_frame,
                 last_frame,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=16,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.5,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from the input first and last frames and a text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            first_frame (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
            last_frame (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
                [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
                to match first_frame.
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 16):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 50):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.5):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        first_frame_size = first_frame.size
        last_frame_size = last_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
            self.device)
        last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
            self.device)

        F = frame_num
        first_frame_h, first_frame_w = first_frame.shape[1:]
        aspect_ratio = first_frame_h / first_frame_w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        first_frame_h = lat_h * self.vae_stride[1]
        first_frame_w = lat_w * self.vae_stride[2]
        if first_frame_size != last_frame_size:
            # 1. resize
            last_frame_resize_ratio = max(
                first_frame_size[0] / last_frame_size[0],
                first_frame_size[1] / last_frame_size[1])
            last_frame_size = [
                round(last_frame_size[0] * last_frame_resize_ratio),
                round(last_frame_size[1] * last_frame_resize_ratio),
            ]
            # 2. center crop
            last_frame = TF.center_crop(last_frame, last_frame_size)

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
        msk[:, 1:-1] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ],
                           dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual(
            [first_frame[:, None, :, :], last_frame[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    first_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
                torch.zeros(3, F - 2, first_frame_h, first_frame_w),
                torch.nn.functional.interpolate(
                    last_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
            ],
                         dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latent = noise

            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latent.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latent = temp_x0.squeeze(0)

                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
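
A minimal usage sketch for `WanFLF2V`; the config registry, config key, checkpoint directory, and image files below are assumptions rather than values taken from this repository:

from PIL import Image

from wan.configs import WAN_CONFIGS               # assumed config registry
from wan.first_last_frame2video import WanFLF2V

config = WAN_CONFIGS['i2v-14B']                   # assumption: an FLF2V-compatible config
pipe = WanFLF2V(config, checkpoint_dir='./Wan2.1-I2V-14B-720P')

video = pipe.generate(
    "a time-lapse of clouds rolling over a mountain ridge",
    first_frame=Image.open('first.png'),
    last_frame=Image.open('last.png'),
    frame_num=81,
    sampling_steps=50,
)  # (C, N, H, W) tensor on rank 0, None on other ranks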


================================================
FILE: wan/image2video.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.motion_patch import patch_motion
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanATI:
    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`,  *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`,  *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 img,
                 tracks,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=40,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from an input image, point trajectories, and a text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            img (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
            tracks (torch.Tensor):
                Point trajectories used to condition the generated motion; applied to the
                encoded conditioning latent via `patch_motion`
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
        tracks = tracks.to(self.device)[None]

        F = frame_num
        h, w = img.shape[1:]
        aspect_ratio = h / w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ],
            dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual([img[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                        0, 1),
                torch.zeros(3, F - 1, h, w)
            ],
                dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])

        with torch.no_grad():
            y = patch_motion(tracks.type(y.dtype), y, training=False)

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latent = noise

            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latent.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latent = temp_x0.squeeze(0)

                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
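
A minimal usage sketch for `WanATI`; the config key, checkpoint path, input image, and the exact contents of the track file are assumptions (the `.pth` files under `examples/tracks/` do exist in this repository, but their format is not shown in this file):

import torch
from PIL import Image

from wan.configs import WAN_CONFIGS               # assumed config registry
from wan.image2video import WanATI

config = WAN_CONFIGS['i2v-14B']                   # assumption: ATI runs on the Wan I2V 14B setup
pipe = WanATI(config, checkpoint_dir='./Wan2.1-I2V-14B-720P')

img = Image.open('examples/bear.jpg').convert('RGB')    # hypothetical input frame
tracks = torch.load('examples/tracks/bear.pth')          # assumed to hold point trajectories

video = pipe.generate(
    "a bear walking across a rocky hillside",
    img,
    tracks,
    sampling_steps=40,
    guide_scale=5.0,
)  # (C, N, H, W) tensor on rank 0, None on other ranks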


================================================
FILE: wan/modules/__init__.py
================================================
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vace_model import VaceWanModel
from .vae import WanVAE

__all__ = [
    'WanVAE',
    'WanModel',
    'VaceWanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'flash_attention',
]
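
A one-line sketch of consuming the re-exports above (the import itself loads no model weights):

from wan.modules import WanModel, WanVAE, T5EncoderModel, flash_attention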


================================================
FILE: wan/modules/attention.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

if not FLASH_ATTN_3_AVAILABLE:
    try:
        import flash_attn_hopper as flash_attn_interface
        FLASH_ATTN_3_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B].
    k_lens:         [B].
    dropout_p:      float. Dropout probability.
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
           
================================================
SYMBOL INDEX (360 symbols across 32 files)
================================================

FILE: generate.py
  function _validate_args (line 25) | def _validate_args(args):
  function _parse_args (line 55) | def _parse_args():
  function _init_logging (line 215) | def _init_logging(rank):
  function generate (line 227) | def generate(args):

FILE: gradio/fl2v_14B_singleGPU.py
  function load_model (line 27) | def load_model(value):
  function prompt_enc (line 59) | def prompt_enc(prompt, img_first, img_last, tar_lang):
  function flf2v_generation (line 73) | def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_...
  function gradio_interface (line 113) | def gradio_interface():
  function _parse_args (line 212) | def _parse_args():

FILE: gradio/i2v_14B_singleGPU.py
  function load_model (line 28) | def load_model(value):
  function prompt_enc (line 87) | def prompt_enc(prompt, img, tar_lang):
  function i2v_generation (line 101) | def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
  function gradio_interface (line 149) | def gradio_interface():
  function _parse_args (line 241) | def _parse_args():

FILE: gradio/t2i_14B_singleGPU.py
  function prompt_enc (line 26) | def prompt_enc(prompt, tar_lang):
  function t2i_generation (line 35) | def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
  function gradio_interface (line 64) | def gradio_interface():
  function _parse_args (line 152) | def _parse_args():

FILE: gradio/t2v_1.3B_singleGPU.py
  function prompt_enc (line 26) | def prompt_enc(prompt, tar_lang):
  function t2v_generation (line 35) | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
  function gradio_interface (line 64) | def gradio_interface():
  function _parse_args (line 154) | def _parse_args():

FILE: gradio/t2v_14B_singleGPU.py
  function prompt_enc (line 26) | def prompt_enc(prompt, tar_lang):
  function t2v_generation (line 35) | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
  function gradio_interface (line 64) | def gradio_interface():
  function _parse_args (line 152) | def _parse_args():

FILE: gradio/vace.py
  class FixedSizeQueue (line 22) | class FixedSizeQueue:
    method __init__ (line 24) | def __init__(self, max_size):
    method add (line 28) | def add(self, item):
    method get (line 33) | def get(self):
    method __repr__ (line 36) | def __repr__(self):
  class VACEInference (line 40) | class VACEInference:
    method __init__ (line 42) | def __init__(self,
    method create_ui (line 70) | def create_ui(self, *args, **kwargs):
    method generate (line 211) | def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
    method set_callbacks (line 267) | def set_callbacks(self, **kwargs):

FILE: tools/get_track_from_videos.py
  function array_to_npz_bytes (line 12) | def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MUL...
  function parse_to_list (line 23) | def parse_to_list(text: str) -> List[List[int]]:
  function load_video_to_frames (line 42) | def load_video_to_frames(
  function sample_grid_points (line 108) | def sample_grid_points(bbox, N):
  function resize_images_to_size (line 141) | def resize_images_to_size(image_list, size=1024):
  function resize_box (line 157) | def resize_box(box, ratios):
  class TrackAnyPoint (line 161) | class TrackAnyPoint():
    method __init__ (line 162) | def __init__(self, n_points=60):
    method __call__ (line 170) | def __call__(self, video_frames: List[Image.Image]):
    method inference (line 186) | def inference(self, frames: np.ndarray, w_ori, h_ori, tracks) -> np.nd...
  function convert_grid_coordinates (line 199) | def convert_grid_coordinates(
  function save_frames_to_mp4 (line 256) | def save_frames_to_mp4(frames, output_path, fps=24, codec='mp4v'):
  function save_yaml (line 291) | def save_yaml(

FILE: tools/plot_user_inputs.py
  function plot_tracks (line 24) | def plot_tracks(
  function unzip_to_array (line 125) | def unzip_to_array(
  function main (line 142) | def main():

FILE: tools/trajectory_editor/app.py
  function array_to_npz_bytes (line 55) | def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MUL...
  function load_existing_tracks (line 66) | def load_existing_tracks(path):
  function index (line 76) | def index():
  function upload_image (line 81) | def upload_image():
  function store_tracks (line 103) | def store_tracks():
  function ensure_localhost (line 209) | def ensure_localhost():

FILE: tools/visualize_trajectory.py
  function unzip_to_array (line 30) | def unzip_to_array(
  function get_colors (line 48) | def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
  function age_to_bgr (line 63) | def age_to_bgr(ratio: float) -> Tuple[int,int,int]:
  function paint_point_track (line 89) | def paint_point_track(

FILE: wan/distributed/fsdp.py
  function shard_model (line 12) | def shard_model(
  function free_model (line 37) | def free_model(model):

FILE: wan/distributed/xdit_context_parallel.py
  function pad_freqs (line 14) | def pad_freqs(original_tensor, target_len):
  function rope_apply (line 28) | def rope_apply(x, grid_sizes, freqs):
  function usp_dit_forward_vace (line 68) | def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
  function usp_dit_forward (line 93) | def usp_dit_forward(
  function usp_attn_forward (line 183) | def usp_attn_forward(self,

FILE: wan/first_last_frame2video.py
  class WanFLF2V (line 32) | class WanFLF2V:
    method __init__ (line 34) | def __init__(
    method generate (line 133) | def generate(self,

FILE: wan/image2video.py
  class WanATI (line 35) | class WanATI:
    method __init__ (line 36) | def __init__(
    method generate (line 135) | def generate(self,

FILE: wan/modules/attention.py
  function flash_attention (line 31) | def flash_attention(
  function attention (line 140) | def attention(

FILE: wan/modules/clip.py
  function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
  class QuickGELU (line 41) | class QuickGELU(nn.Module):
    method forward (line 43) | def forward(self, x):
  class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
    method forward (line 49) | def forward(self, x):
  class SelfAttention (line 53) | class SelfAttention(nn.Module):
    method __init__ (line 55) | def __init__(self,
    method forward (line 74) | def forward(self, x):
  class SwiGLU (line 94) | class SwiGLU(nn.Module):
    method __init__ (line 96) | def __init__(self, dim, mid_dim):
    method forward (line 106) | def forward(self, x):
  class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
    method __init__ (line 114) | def __init__(self,
    method forward (line 146) | def forward(self, x):
  class AttentionPool (line 156) | class AttentionPool(nn.Module):
    method __init__ (line 158) | def __init__(self,
    method forward (line 186) | def forward(self, x):
  class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
    method __init__ (line 211) | def __init__(self,
    method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
  class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
    method __init__ (line 305) | def __init__(self, **kwargs):
    method forward (line 315) | def forward(self, ids):
  class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
    method __init__ (line 330) | def __init__(self,
    method forward (line 406) | def forward(self, imgs, txt_ids):
    method param_groups (line 418) | def param_groups(self):
  function _clip (line 434) | def _clip(pretrained=False,
  function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
  class CLIPModel (line 501) | class CLIPModel:
    method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
    method visual (line 527) | def visual(self, videos):

FILE: wan/modules/model.py
  function sinusoidal_embedding_1d (line 18) | def sinusoidal_embedding_1d(dim, position):
  function rope_params (line 32) | def rope_params(max_seq_len, dim, theta=10000):
  function rope_apply (line 43) | def rope_apply(x, grid_sizes, freqs):
  class WanRMSNorm (line 73) | class WanRMSNorm(nn.Module):
    method __init__ (line 75) | def __init__(self, dim, eps=1e-5):
    method forward (line 81) | def forward(self, x):
    method _norm (line 88) | def _norm(self, x):
  class WanLayerNorm (line 92) | class WanLayerNorm(nn.LayerNorm):
    method __init__ (line 94) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
    method forward (line 97) | def forward(self, x):
  class WanSelfAttention (line 105) | class WanSelfAttention(nn.Module):
    method __init__ (line 107) | def __init__(self,
    method forward (line 130) | def forward(self, x, seq_lens, grid_sizes, freqs):
  class WanT2VCrossAttention (line 162) | class WanT2VCrossAttention(WanSelfAttention):
    method forward (line 164) | def forward(self, x, context, context_lens):
  class WanI2VCrossAttention (line 187) | class WanI2VCrossAttention(WanSelfAttention):
    method __init__ (line 189) | def __init__(self,
    method forward (line 202) | def forward(self, x, context, context_lens):
  class WanAttentionBlock (line 238) | class WanAttentionBlock(nn.Module):
    method __init__ (line 240) | def __init__(self,
    method forward (line 278) | def forward(
  class Head (line 320) | class Head(nn.Module):
    method __init__ (line 322) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
    method forward (line 337) | def forward(self, x, e):
  class MLPProj (line 350) | class MLPProj(torch.nn.Module):
    method __init__ (line 352) | def __init__(self, in_dim, out_dim, flf_pos_emb=False):
    method forward (line 363) | def forward(self, image_embeds):
  class WanModel (line 372) | class WanModel(ModelMixin, ConfigMixin):
    method __init__ (line 383) | def __init__(self,
    method forward (line 493) | def forward(
    method unpatchify (line 584) | def unpatchify(self, x, grid_sizes):
    method init_weights (line 609) | def init_weights(self):

FILE: wan/modules/motion_patch.py
  function ind_sel (line 20) | def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
  function merge_final (line 51) | def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assi...
  function patch_motion (line 77) | def patch_motion(

FILE: wan/modules/t5.py
  function fp16_clamp (line 20) | def fp16_clamp(x):
  function init_weights (line 27) | def init_weights(m):
  class GELU (line 46) | class GELU(nn.Module):
    method forward (line 48) | def forward(self, x):
  class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
    method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
    method forward (line 61) | def forward(self, x):
  class T5Attention (line 69) | class T5Attention(nn.Module):
    method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
    method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
  class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
    method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
    method forward (line 136) | def forward(self, x):
  class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
    method __init__ (line 146) | def __init__(self,
    method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
  class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
    method __init__ (line 180) | def __init__(self,
    method forward (line 206) | def forward(self,
  class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
    method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
    method forward (line 233) | def forward(self, lq, lk):
    method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
  class T5Encoder (line 267) | class T5Encoder(nn.Module):
    method __init__ (line 269) | def __init__(self,
    method forward (line 303) | def forward(self, ids, mask=None):
  class T5Decoder (line 315) | class T5Decoder(nn.Module):
    method __init__ (line 317) | def __init__(self,
    method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
  class T5Model (line 372) | class T5Model(nn.Module):
    method __init__ (line 374) | def __init__(self,
    method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
  function _t5 (line 415) | def _t5(name,
  function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
  class T5EncoderModel (line 472) | class T5EncoderModel:
    method __init__ (line 474) | def __init__(
    method __call__ (line 506) | def __call__(self, texts, device):

FILE: wan/modules/tokenizers.py
  function basic_clean (line 12) | def basic_clean(text):
  function whitespace_clean (line 18) | def whitespace_clean(text):
  function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
  class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
    method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
    method __call__ (line 49) | def __call__(self, sequence, **kwargs):
    method _clean (line 75) | def _clean(self, text):

FILE: wan/modules/vace_model.py
  class VaceWanAttentionBlock (line 10) | class VaceWanAttentionBlock(WanAttentionBlock):
    method __init__ (line 12) | def __init__(self,
    method forward (line 33) | def forward(self, c, x, **kwargs):
  class BaseWanAttentionBlock (line 42) | class BaseWanAttentionBlock(WanAttentionBlock):
    method __init__ (line 44) | def __init__(self,
    method forward (line 58) | def forward(self, x, hints, context_scale=1.0, **kwargs):
  class VaceWanModel (line 65) | class VaceWanModel(WanModel):
    method __init__ (line 68) | def __init__(self,
    method forward_vace (line 136) | def forward_vace(self, x, vace_context, seq_len, kwargs):
    method forward (line 155) | def forward(

FILE: wan/modules/vae.py
  class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
    method __init__ (line 22) | def __init__(self, *args, **kwargs):
    method forward (line 28) | def forward(self, x, cache_x=None):
  class RMS_norm (line 39) | class RMS_norm(nn.Module):
    method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
    method forward (line 51) | def forward(self, x):
  class Upsample (line 57) | class Upsample(nn.Upsample):
    method forward (line 59) | def forward(self, x):
  class Resample (line 66) | class Resample(nn.Module):
    method __init__ (line 68) | def __init__(self, dim, mode):
    method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
    method init_weight (line 162) | def init_weight(self, conv):
    method init_weight2 (line 174) | def init_weight2(self, conv):
  class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
    method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
    method __init__ (line 228) | def __init__(self, dim):
    method forward (line 240) | def forward(self, x):
  class Encoder3d (line 265) | class Encoder3d(nn.Module):
    method __init__ (line 267) | def __init__(self,
    method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class Decoder3d (line 369) | class Decoder3d(nn.Module):
    method __init__ (line 371) | def __init__(self,
    method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  function count_conv3d (line 475) | def count_conv3d(model):
  class WanVAE_ (line 483) | class WanVAE_(nn.Module):
    method __init__ (line 485) | def __init__(self,
    method forward (line 510) | def forward(self, x):
    method encode (line 516) | def encode(self, x, scale):
    method decode (line 544) | def decode(self, z, scale):
    method reparameterize (line 570) | def reparameterize(self, mu, log_var):
    method sample (line 575) | def sample(self, imgs, deterministic=False):
    method clear_cache (line 582) | def clear_cache(self):
  function _video_vae (line 592) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
  class WanVAE (line 619) | class WanVAE:
    method __init__ (line 621) | def __init__(self,
    method encode (line 647) | def encode(self, videos):
    method decode (line 657) | def decode(self, zs):

FILE: wan/modules/xlm_roberta.py
  class SelfAttention (line 10) | class SelfAttention(nn.Module):
    method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
    method forward (line 27) | def forward(self, x, mask):
  class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
    method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
    method forward (line 66) | def forward(self, x, mask):
  class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
    method __init__ (line 81) | def __init__(self,
    method forward (line 118) | def forward(self, ids):
  function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,

FILE: wan/utils/fm_solvers.py
  function get_sampling_sigmas (line 24) | def get_sampling_sigmas(sampling_steps, shift):
  function retrieve_timesteps (line 31) | def retrieve_timesteps(
  class FlowDPMSolverMultistepScheduler (line 71) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 131) | def __init__(
    method step_index (line 204) | def step_index(self):
    method begin_index (line 211) | def begin_index(self):
    method set_begin_index (line 218) | def set_begin_index(self, begin_index: int = 0):
    method set_timesteps (line 228) | def set_timesteps(
    method _threshold_sample (line 294) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    method _sigma_to_t (line 332) | def _sigma_to_t(self, sigma):
    method _sigma_to_alpha_sigma_t (line 335) | def _sigma_to_alpha_sigma_t(self, sigma):
    method time_shift (line 339) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    method convert_model_output (line 343) | def convert_model_output(
    method dpm_solver_first_order_update (line 417) | def dpm_solver_first_order_update(
    method multistep_dpm_solver_second_order_update (line 488) | def multistep_dpm_solver_second_order_update(
    method multistep_dpm_solver_third_order_update (line 598) | def multistep_dpm_solver_third_order_update(
    method index_for_timestep (line 681) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 695) | def _init_step_index(self, timestep):
    method step (line 708) | def step(
    method scale_model_input (line 802) | def scale_model_input(self, sample: torch.Tensor, *args,
    method add_noise (line 817) | def add_noise(
    method __len__ (line 858) | def __len__(self):

FILE: wan/utils/fm_solvers_unipc.py
  class FlowUniPCMultistepScheduler (line 22) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 79) | def __init__(
    method step_index (line 137) | def step_index(self):
    method begin_index (line 144) | def begin_index(self):
    method set_begin_index (line 151) | def set_begin_index(self, begin_index: int = 0):
    method set_timesteps (line 162) | def set_timesteps(
    method _threshold_sample (line 232) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    method _sigma_to_t (line 271) | def _sigma_to_t(self, sigma):
    method _sigma_to_alpha_sigma_t (line 274) | def _sigma_to_alpha_sigma_t(self, sigma):
    method time_shift (line 278) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    method convert_model_output (line 281) | def convert_model_output(
    method multistep_uni_p_bh_update (line 352) | def multistep_uni_p_bh_update(
    method multistep_uni_c_bh_update (line 488) | def multistep_uni_c_bh_update(
    method index_for_timestep (line 630) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 645) | def _init_step_index(self, timestep):
    method step (line 657) | def step(self,
    method scale_model_input (line 743) | def scale_model_input(self, sample: torch.Tensor, *args,
    method add_noise (line 760) | def add_noise(
    method __len__ (line 801) | def __len__(self):

FILE: wan/utils/motion.py
  function get_tracks_inference (line 23) | def get_tracks_inference(tracks, height, width, quant_multi: Optional[in...
  function unzip_to_array (line 36) | def unzip_to_array(
  function process_tracks (line 53) | def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], q...

FILE: wan/utils/prompt_extend.py
  class PromptOutput (line 153) | class PromptOutput(object):
    method add_custom_field (line 160) | def add_custom_field(self, key: str, value) -> None:
  class PromptExpander (line 164) | class PromptExpander:
    method __init__ (line 166) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
    method extend_with_img (line 171) | def extend_with_img(self,
    method extend (line 180) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method decide_system_prompt (line 183) | def decide_system_prompt(self, tar_lang="zh", multi_images_input=False):
    method __call__ (line 189) | def __call__(self,
  class DashScopePromptExpander (line 213) | class DashScopePromptExpander(PromptExpander):
    method __init__ (line 215) | def __init__(self,
    method extend (line 252) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method extend_with_img (line 288) | def extend_with_img(self,
  class QwenPromptExpander (line 364) | class QwenPromptExpander(PromptExpander):
    method __init__ (line 373) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
    method extend (line 433) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method extend_with_img (line 464) | def extend_with_img(self,

FILE: wan/utils/qwen_vl_utils.py
  function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
  function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
  function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
  function smart_resize (line 54) | def smart_resize(height: int,
  function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
  function smart_nframes (line 133) | def smart_nframes(
  function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
  function is_decord_available (line 215) | def is_decord_available() -> bool:
  function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
  function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
  function fetch_video (line 274) | def fetch_video(
  function extract_vision_info (line 328) | def extract_vision_info(
  function process_vision_info (line 344) | def process_vision_info(

FILE: wan/utils/utils.py
  function rand_name (line 14) | def rand_name(length=8, suffix=''):
  function cache_video (line 23) | def cache_video(tensor,
  function cache_image (line 64) | def cache_image(tensor,
  function str2bool (line 94) | def str2bool(v):

FILE: wan/utils/vace_processor.py
  class VaceImageProcessor (line 9) | class VaceImageProcessor(object):
    method __init__ (line 11) | def __init__(self, downsample=None, seq_len=None):
    method _pillow_convert (line 15) | def _pillow_convert(self, image, cvt_type='RGB'):
    method _load_image (line 30) | def _load_image(self, img_path):
    method _resize_crop (line 37) | def _resize_crop(self, img, oh, ow, normalize=True):
    method _image_preprocess (line 60) | def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
    method load_image (line 63) | def load_image(self, data_key, **kwargs):
    method load_image_pair (line 66) | def load_image_pair(self, data_key, data_key2, **kwargs):
    method load_image_batch (line 69) | def load_image_batch(self,
  class VaceVideoProcessor (line 91) | class VaceVideoProcessor(object):
    method __init__ (line 93) | def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
    method set_area (line 105) | def set_area(self, area):
    method set_seq_len (line 109) | def set_seq_len(self, seq_len):
    method resize_crop (line 113) | def resize_crop(video: torch.Tensor, oh: int, ow: int):
    method _video_preprocess (line 151) | def _video_preprocess(self, video, oh, ow):
    method _get_frameid_bbox_default (line 154) | def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_...
    method _get_frameid_bbox_adjust_last (line 187) | def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
    method _get_frameid_bbox (line 219) | def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
    method load_video (line 227) | def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
    method load_video_pair (line 231) | def load_video_pair(self,
    method load_video_batch (line 240) | def load_video_batch(self,
  function prepare_source (line 274) | def prepare_source(src_video, src_mask, src_ref_images, num_frames, imag...
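
`VaceImageProcessor` and `VaceVideoProcessor` normalize reference images and source videos to a target area and sequence length before VAE encoding. The sketch below uses `VaceImageProcessor`, whose constructor signature is fully visible in the index; the concrete `downsample`/`seq_len` values, the interpretation of `data_key` as an image path, and the return structure are assumptions (the image path itself appears in examples/test.yaml).

```python
# Hedged sketch: preprocess a single reference image with VaceImageProcessor.
# The constructor keywords and load_image() come from the symbol index; the concrete
# values and the interpretation of data_key as a file path are assumptions.
from wan.utils.vace_processor import VaceImageProcessor

proc = VaceImageProcessor(downsample=(4, 8, 8), seq_len=75600)  # assumed VAE stride / sequence length
out = proc.load_image("examples/images/fish.jpg")               # path taken from examples/test.yaml
print(type(out))
```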

FILE: wan/vace.py
  class WanVace (line 37) | class WanVace(WanT2V):
    method __init__ (line 39) | def __init__(
    method vace_encode_frames (line 139) | def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
    method vace_encode_masks (line 174) | def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
    method vace_latent (line 209) | def vace_latent(self, z, m):
    method prepare_source (line 212) | def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
    method decode_latent (line 280) | def decode_latent(self, zs, ref_images=None, vae=None):
    method generate (line 295) | def generate(self,
  class WanVaceMP (line 478) | class WanVaceMP(WanVace):
    method __init__ (line 480) | def __init__(self,
    method dynamic_load (line 512) | def dynamic_load(self):
    method transfer_data_to_cuda (line 544) | def transfer_data_to_cuda(self, data, device):
    method mp_worker (line 562) | def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
    method generate (line 773) | def generate(self,
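
`WanVace` extends `WanT2V` with the VACE conditioning path (prepare_source builds the conditioning video/mask/reference tensors, vace_encode_frames/vace_encode_masks push them through the VAE, and generate runs sampling), while `WanVaceMP` is the multi-GPU variant. Below is a heavily hedged construction sketch: only the method names and the leading positional parameters of `prepare_source` come from the index; the config registry, config key, checkpoint directory, and every keyword and argument ordering are hypothetical placeholders.

```python
# Heavily hedged sketch of the single-GPU VACE path. Only the method names and the
# leading parameters of prepare_source() come from the symbol index; the config
# registry, config key, checkpoint directory and every keyword are hypothetical.
from wan.vace import WanVace
from wan.configs import WAN_CONFIGS            # assumed registry name

pipe = WanVace(
    config=WAN_CONFIGS["t2v-1.3B"],            # assumed key; WanVace builds on WanT2V
    checkpoint_dir="./Wan2.1-T2V-1.3B",        # directory name taken from .gitignore
    device_id=0,                               # assumed keyword
)
src_video, src_mask, src_ref_images = pipe.prepare_source(
    [None], [None], [["examples/images/fish.jpg"]],   # assumed batch-of-one layout
    num_frames=81,                                     # assumed clip length
)
video = pipe.generate(
    "a tranquil koi pond at dawn",             # prompt adapted from examples/test.yaml
    src_video, src_mask, src_ref_images,       # assumed positional ordering
)
```
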
Condensed preview — 59 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (499K chars).
[
  {
    "path": ".gitignore",
    "chars": 577,
    "preview": ".*\n*.py[cod]\n# *.jpg\n*.jpeg\n# *.png\n*.gif\n*.bmp\n*.mp4\n*.mov\n*.mkv\n*.log\n*.zip\n*.pt\n*.pth\n*.ckpt\n*.safetensors\n*.json\n# *"
  },
  {
    "path": "INSTALL.md",
    "chars": 1150,
    "preview": "# Installation Guide\n\n## Install with pip\n\n```bash\npip install .\npip install .[dev]  # Installe aussi les outils de dev\n"
  },
  {
    "path": "LICENSE.txt",
    "chars": 11552,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "Makefile",
    "chars": 94,
    "preview": ".PHONY: format\n\nformat:\n\tisort generate.py gradio wan\n\tyapf -i -r *.py generate.py gradio wan\n"
  },
  {
    "path": "README.md",
    "chars": 10660,
    "preview": "# ATI: Any Trajectory Instruction for Controllable Video Generation\n\n<div align=\"center\">\n  \n[![arXiv](https://img.shiel"
  },
  {
    "path": "examples/test.yaml",
    "chars": 1308,
    "preview": "- image: examples/images/fish.jpg\n  text: \"A tranquil koi pond edged by mossy stone, with lily pads drifting on the surf"
  },
  {
    "path": "generate.py",
    "chars": 13252,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n# Copyright (c) 2024-2025 Bytedance Ltd. and/or"
  },
  {
    "path": "gradio/fl2v_14B_singleGPU.py",
    "chars": 8277,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport gc\nimport os\nimport os.p"
  },
  {
    "path": "gradio/i2v_14B_singleGPU.py",
    "chars": 9282,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport gc\nimport os\nimport os.p"
  },
  {
    "path": "gradio/t2i_14B_singleGPU.py",
    "chars": 6614,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport os\nimport os.path as osp"
  },
  {
    "path": "gradio/t2v_1.3B_singleGPU.py",
    "chars": 6598,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport os\nimport os.path as osp"
  },
  {
    "path": "gradio/t2v_14B_singleGPU.py",
    "chars": 6598,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport os\nimport os.path as osp"
  },
  {
    "path": "gradio/vace.py",
    "chars": 12925,
    "preview": "# -*- coding: utf-8 -*-\n# Copyright (c) Alibaba, Inc. and its affiliates.\n\nimport argparse\nimport datetime\nimport os\nimp"
  },
  {
    "path": "pyproject.toml",
    "chars": 1339,
    "preview": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"wan\"\nversion ="
  },
  {
    "path": "requirements.txt",
    "chars": 234,
    "preview": "torch>=2.4.0\ntorchvision>=0.19.0\nopencv-python>=4.9.0.80\ndiffusers>=0.31.0\ntransformers>=4.49.0\ntokenizers>=0.20.3\naccel"
  },
  {
    "path": "run_example.sh",
    "chars": 2040,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#!/usr/bin/env bash\nset -euo pipefail\n\nusage() {\n  cat <<"
  },
  {
    "path": "tests/README.md",
    "chars": 216,
    "preview": "\nPut all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify"
  },
  {
    "path": "tests/test.sh",
    "chars": 5476,
    "preview": "#!/bin/bash\n\n\nif [ \"$#\" -eq 2 ]; then\n  MODEL_DIR=$(realpath \"$1\")\n  GPUS=$2\nelse\n  echo \"Usage: $0 <local model dir> <g"
  },
  {
    "path": "tools/get_track_from_videos.py",
    "chars": 13635,
    "preview": "import torch\nfrom typing import List, Sequence, Any\nfrom PIL import Image\nimport numpy as np\nimport cv2\nimport yaml\nimpo"
  },
  {
    "path": "tools/plot_user_inputs.py",
    "chars": 5959,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \""
  },
  {
    "path": "tools/trajectory_editor/app.py",
    "chars": 7360,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \""
  },
  {
    "path": "tools/trajectory_editor/templates/index.html",
    "chars": 19315,
    "preview": "<!-- Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (th"
  },
  {
    "path": "tools/visualize_trajectory.py",
    "chars": 6806,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \""
  },
  {
    "path": "wan/__init__.py",
    "chars": 223,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n# SPDX-License-Identifier: Apache-2.0\n\nfrom . import conf"
  },
  {
    "path": "wan/configs/__init__.py",
    "chars": 965,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport copy\nimport os\n\nos.environ['TOKENIZERS_P"
  },
  {
    "path": "wan/configs/shared_config.py",
    "chars": 649,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n#--"
  },
  {
    "path": "wan/configs/wan_i2v_14B.py",
    "chars": 1035,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
  },
  {
    "path": "wan/configs/wan_t2v_14B.py",
    "chars": 742,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
  },
  {
    "path": "wan/configs/wan_t2v_1_3B.py",
    "chars": 759,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
  },
  {
    "path": "wan/distributed/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "wan/distributed/fsdp.py",
    "chars": 1307,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nfrom functools import partial\n\nimport"
  },
  {
    "path": "wan/distributed/xdit_context_parallel.py",
    "chars": 6839,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
  },
  {
    "path": "wan/first_last_frame2video.py",
    "chars": 14622,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
  },
  {
    "path": "wan/image2video.py",
    "chars": 13480,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n# Copyright (c) 2024-2025 Bytedance Ltd. and/or"
  },
  {
    "path": "wan/modules/__init__.py",
    "chars": 422,
    "preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
  },
  {
    "path": "wan/modules/attention.py",
    "chars": 5650,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n    import flash_attn_interf"
  },
  {
    "path": "wan/modules/clip.py",
    "chars": 16848,
    "preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
  },
  {
    "path": "wan/modules/model.py",
    "chars": 21340,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.amp as a"
  },
  {
    "path": "wan/modules/motion_patch.py",
    "chars": 5625,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \""
  },
  {
    "path": "wan/modules/t5.py",
    "chars": 16910,
    "preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
  },
  {
    "path": "wan/modules/tokenizers.py",
    "chars": 2431,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
  },
  {
    "path": "wan/modules/vace_model.py",
    "chars": 8281,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nimpor"
  },
  {
    "path": "wan/modules/vae.py",
    "chars": 23135,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
  },
  {
    "path": "wan/modules/xlm_roberta.py",
    "chars": 4865,
    "preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
  },
  {
    "path": "wan/utils/__init__.py",
    "chars": 402,
    "preview": "from .fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom ."
  },
  {
    "path": "wan/utils/fm_solvers.py",
    "chars": 40142,
    "preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
  },
  {
    "path": "wan/utils/fm_solvers_unipc.py",
    "chars": 32557,
    "preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
  },
  {
    "path": "wan/utils/motion.py",
    "chars": 2474,
    "preview": "# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \""
  },
  {
    "path": "wan/utils/prompt_extend.py",
    "chars": 39552,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
  },
  {
    "path": "wan/utils/qwen_vl_utils.py",
    "chars": 13054,
    "preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
  },
  {
    "path": "wan/utils/utils.py",
    "chars": 3256,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
  },
  {
    "path": "wan/utils/vace_processor.py",
    "chars": 11914,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport numpy as np\nimport torch\nimport torch.nn"
  },
  {
    "path": "wan/vace.py",
    "chars": 32116,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
  }
]

// ... and 6 more files (download for full content)

About this extraction

This page contains the full source code of the bytedance/ATI GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 59 files (461.8 KB), approximately 117.2k tokens, and a symbol index with 360 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
