[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n.DS_Store\n.vscode*\ntmp_examples*\nnew_checkpoint*\nbatch_test*\nnohup*"
  },
  {
    "path": "INSTALL.md",
    "content": "# Installation Guide\n\n## Install with pip\n\n```bash\npip install .\npip install .[dev]  # Also installs the dev tools\n```\n\n## Install with Poetry\n\nEnsure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.\n\nTo install all dependencies:\n\n```bash\npoetry install\n```\n\n### Handling `flash-attn` Installation Issues\n\nIf `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.\n\n#### No-Build-Isolation Installation (Recommended)\n```bash\npoetry run pip install --upgrade pip setuptools wheel\npoetry run pip install flash-attn --no-build-isolation\npoetry install\n```\n\n#### Install from Git (Alternative)\n```bash\npoetry run pip install git+https://github.com/Dao-AILab/flash-attention.git\n```\n\n---\n\n### Running the Model\n\nOnce the installation is complete, you can run **Wan2.2** using:\n\n```bash\npoetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\"\n```\n\n#### Test\n```bash\nbash tests/test.sh\n```\n\n#### Format\n```bash\nblack .\nisort .\n```\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: format\n\nformat:\n\tisort generate.py wan\n\tyapf -i -r *.py generate.py wan\n"
  },
  {
    "path": "README.md",
    "content": "# Wan2.2\n\n<p align=\"center\">\n    <img src=\"assets/logo.png\" width=\"400\"/>\n<p>\n\n<p align=\"center\">\n    💜 <a href=\"https://wan.video\"><b>Wan</b></a> &nbsp&nbsp ｜ &nbsp&nbsp 🖥️ <a href=\"https://github.com/Wan-Video/Wan2.2\">GitHub</a> &nbsp&nbsp  | &nbsp&nbsp🤗 <a href=\"https://huggingface.co/Wan-AI/\">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href=\"https://modelscope.cn/organization/Wan-AI\">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href=\"https://arxiv.org/abs/2503.20314\">Paper</a> &nbsp&nbsp | &nbsp&nbsp 📑 <a href=\"https://wan.video/welcome?spm=a2ty_o02.30011076.0.0.6c9ee41eCcluqg\">Blog</a> &nbsp&nbsp |  &nbsp&nbsp 💬  <a href=\"https://discord.gg/AKNgpMK4Yj\">Discord</a>&nbsp&nbsp\n    <br>\n    📕 <a href=\"https://alidocs.dingtalk.com/i/nodes/jb9Y4gmKWrx9eo4dCql9LlbYJGXn6lpz\">使用指南(中文)</a>&nbsp&nbsp | &nbsp&nbsp 📘 <a href=\"https://alidocs.dingtalk.com/i/nodes/EpGBa2Lm8aZxe5myC99MelA2WgN7R35y\">User Guide(English)</a>&nbsp&nbsp | &nbsp&nbsp💬 <a href=\"https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg\">WeChat(微信)</a>&nbsp&nbsp\n<br>\n\n-----\n\n[**Wan: Open and Advanced Large-Scale Video Generative Models**](https://arxiv.org/abs/2503.20314) <br>\n\n\nWe are excited to introduce **Wan2.2**, a major upgrade to our foundational video models. With **Wan2.2**, we have focused on incorporating the following innovations:\n\n- 👍 **Effective MoE Architecture**: Wan2.2 introduces a Mixture-of-Experts (MoE) architecture into video diffusion models. By separating the denoising process across timesteps with specialized powerful expert models, this enlarges the overall model capacity while maintaining the same computational cost.\n\n- 👍 **Cinematic-level Aesthetics**: Wan2.2 incorporates meticulously curated aesthetic data, complete with detailed labels for lighting, composition, contrast, color tone, and more. 
This allows for more precise and controllable cinematic style generation, facilitating the creation of videos with customizable aesthetic preferences.\n\n- 👍 **Complex Motion Generation**: Compared to Wan2.1, Wan2.2 is trained on a significantly larger dataset, with 65.6% more images and 83.2% more videos. This expansion notably enhances the model's generalization across multiple dimensions such as motions, semantics, and aesthetics, achieving TOP performance among all open-sourced and closed-sourced models. \n\n- 👍 **Efficient High-Definition Hybrid TI2V**:  Wan2.2 open-sources a 5B model built with our advanced Wan2.2-VAE that achieves a compression ratio of **16×16×4**. This model supports both text-to-video and image-to-video generation at 720P resolution with 24fps and can also run on consumer-grade graphics cards such as the RTX 4090. It is one of the fastest **720P@24fps** models currently available, capable of serving both the industrial and academic sectors simultaneously.\n\n\n## Video Demos\n\n<div align=\"center\">\n  <video src=\"https://github.com/user-attachments/assets/b63bfa58-d5d7-4de6-a1a2-98970b06d9a7\" width=\"70%\" poster=\"\"> </video>\n</div>\n\n## 🔥 Latest News!!\n* Nov 13, 2025: 👋 Wan2.2-Animate-14B has been integrated into Diffusers ([PR](https://github.com/huggingface/diffusers/pull/12526), [Weights](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)). Thanks to all community contributors. Enjoy!\n\n* Sep 19, 2025: 💃 We introduce **[Wan2.2-Animate-14B](https://humanaigc.github.io/wan-animate)**, a unified model for character animation and replacement with holistic movement and expression replication. We released the [model weights](#model-download) and [inference code](#run-wan-animate). 
You can also try it on [wan.video](https://wan.video/), [ModelScope Studio](https://www.modelscope.cn/studios/Wan-AI/Wan2.2-Animate) or [HuggingFace Space](https://huggingface.co/spaces/Wan-AI/Wan2.2-Animate)!\n* Sep 5, 2025: 👋 We added text-to-speech synthesis support with [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for the Speech-to-Video generation task.\n* Aug 26, 2025: 🎵 We introduce **[Wan2.2-S2V-14B](https://humanaigc.github.io/wan-s2v-webpage)**, an audio-driven cinematic video generation model, including [inference code](#run-speech-to-video-generation), [model weights](#model-download), and [technical report](https://humanaigc.github.io/wan-s2v-webpage/content/wan-s2v.pdf)! Now you can try it on [wan.video](https://wan.video/), [ModelScope Gradio](https://www.modelscope.cn/studios/Wan-AI/Wan2.2-S2V) or [HuggingFace Gradio](https://huggingface.co/spaces/Wan-AI/Wan2.2-S2V)!\n* Jul 28, 2025: 👋 We have opened a [HF space](https://huggingface.co/spaces/Wan-AI/Wan-2.2-5B) using the TI2V-5B model. Enjoy!\n* Jul 28, 2025: 👋 Wan2.2 has been integrated into ComfyUI ([CN](https://docs.comfy.org/zh-CN/tutorials/video/wan/wan2_2) | [EN](https://docs.comfy.org/tutorials/video/wan/wan2_2)). Enjoy!\n* Jul 28, 2025: 👋 Wan2.2's T2V, I2V and TI2V have been integrated into Diffusers ([T2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) | [I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | [TI2V-5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)). Feel free to give it a try!\n* Jul 28, 2025: 👋 We've released the inference code and model weights of **Wan2.2**.\n\n\n## Community Works\nIf your research or project builds upon [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) or [**Wan2.2**](https://github.com/Wan-Video/Wan2.2), and you would like more people to see it, please inform us.\n\n- [Prompt Relay](https://github.com/GordonChen19/Prompt-Relay), a plug-and-play, inference-time method for temporal control in video generation. 
Prompt Relay improves video quality and gives users precise control over what happens at each moment in the video. Visit their [webpage](https://gordonchen19.github.io/Prompt-Relay/) for more details.\n- [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model based on **Wan2.1** that achieves minute-scale, high-quality video synthesis at 19.5 FPS on a single H100 GPU (about 10 FPS on a single Ascend NPU), without relying on conventional long video anti-drifting strategies or standard video acceleration techniques. Visit their [webpage](https://pku-yuangroup.github.io/Helios-Page/) for more details.\n- [LightX2V](https://github.com/ModelTC/LightX2V), a lightweight and efficient video generation framework that integrates **Wan2.1** and **Wan2.2**, supporting multiple engineering acceleration techniques for fast inference. [LightX2V-HuggingFace](https://huggingface.co/lightx2v) offers a variety of Wan-based step-distillation models, quantized models, and lightweight VAE models.\n- [HuMo](https://github.com/Phantom-video/HuMo) proposes a unified, human-centric framework based on **Wan** to produce high-quality, fine-grained, and controllable human videos from multimodal inputs, including text, images, and audio. Visit their [webpage](https://phantom-video.github.io/HuMo/) for more details.\n- [FastVideo](https://github.com/hao-ai-lab/FastVideo) includes distilled **Wan** models with sparse attention that significantly speed up inference.\n- [Cache-dit](https://github.com/vipshop/cache-dit) offers Fully Cache Acceleration support for **Wan2.2** MoE with DBCache, TaylorSeer and Cache CFG. Visit their [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) for more details.\n- [Kijai's ComfyUI WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper) is an alternative implementation of **Wan** models for ComfyUI. 
Thanks to its Wan-only focus, it's on the frontline of getting cutting edge optimizations and hot research features, which are often hard to integrate into ComfyUI quickly due to its more rigid structure.\n- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides comprehensive support for **Wan 2.2**, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training.\n\n\n## 📑 Todo List\n- Wan2.2 Text-to-Video\n    - [x] Multi-GPU Inference code of the A14B and 14B models\n    - [x] Checkpoints of the A14B and 14B models\n    - [x] ComfyUI integration\n    - [x] Diffusers integration\n- Wan2.2 Image-to-Video\n    - [x] Multi-GPU Inference code of the A14B model\n    - [x] Checkpoints of the A14B model\n    - [x] ComfyUI integration\n    - [x] Diffusers integration\n- Wan2.2 Text-Image-to-Video\n    - [x] Multi-GPU Inference code of the 5B model\n    - [x] Checkpoints of the 5B model\n    - [x] ComfyUI integration\n    - [x] Diffusers integration\n- Wan2.2-S2V Speech-to-Video\n    - [x] Inference code of Wan2.2-S2V\n    - [x] Checkpoints of Wan2.2-S2V-14B\n    - [x] ComfyUI integration\n    - [x] Diffusers integration\n- Wan2.2-Animate Character Animation and Replacement\n    - [x] Inference code of Wan2.2-Animate\n    - [x] Checkpoints of Wan2.2-Animate\n    - [x] ComfyUI integration\n    - [x] Diffusers integration\n\n## Run Wan2.2\n\n#### Installation\nClone the repo:\n```sh\ngit clone https://github.com/Wan-Video/Wan2.2.git\ncd Wan2.2\n```\n\nInstall dependencies:\n```sh\n# Ensure torch >= 2.4.0\n# If the installation of `flash_attn` fails, try installing the other packages first and install `flash_attn` last\npip install -r requirements.txt\n# If you want to use CosyVoice to synthesize speech for Speech-to-Video Generation, please install requirements_s2v.txt additionally\npip install -r requirements_s2v.txt\n```\n\n\n#### Model Download\n\n| Models              | Download Links             
                                                                                                                 | Description |\n|--------------------|---------------------------------------------------------------------------------------------------------------------------------------------|-------------|\n| T2V-A14B    | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B)    🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)    | Text-to-Video MoE model, supports 480P & 720P |\n| I2V-A14B    | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B)    🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)    | Image-to-Video MoE model, supports 480P & 720P |\n| TI2V-5B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)     🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)     | High-compression VAE, T2V+I2V, supports 720P |\n| S2V-14B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B)     🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)     | Speech-to-Video model, supports 480P & 720P |\n| Animate-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)  | Character animation and replacement |\n\n\n\n> 💡Note: \n> The TI2V-5B model supports 720P video generation at **24 FPS**.\n\n\nDownload models using huggingface-cli:\n``` sh\npip install \"huggingface_hub[cli]\"\nhuggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B\n```\n\nDownload models using modelscope-cli:\n``` sh\npip install modelscope\nmodelscope download Wan-AI/Wan2.2-T2V-A14B --local_dir ./Wan2.2-T2V-A14B\n```\n\n#### Run Text-to-Video Generation\n\nThis repository supports the `Wan2.2-T2V-A14B` Text-to-Video model and can simultaneously support video generation at 480P and 720P resolutions.\n\n\n##### (1) Without Prompt Extension\n\nTo facilitate implementation, we will 
start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extension) step.\n\n- Single-GPU inference\n\n``` sh\npython generate.py  --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --offload_model True --convert_model_dtype --prompt \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\"\n```\n\n> 💡 This command can run on a GPU with at least 80GB VRAM.\n\n> 💡If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`, `--convert_model_dtype` and `--t5_cpu` options to reduce GPU memory usage.\n\n\n- Multi-GPU inference using FSDP + DeepSpeed Ulysses\n\n  We use [PyTorch FSDP](https://docs.pytorch.org/docs/stable/fsdp.html) and [DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509) to accelerate inference.\n\n\n``` sh\ntorchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\"\n```\n\n\n##### (2) Using Prompt Extension\n\nExtending the prompts can effectively enrich the details in the generated videos, further enhancing the video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:\n\n- Use the Dashscope API for extension.\n  - Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).\n  - Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. Users of Alibaba Cloud's international site also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. 
For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).\n  - Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.\n  - You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:\n```sh\nDASH_API_KEY=your_key torchrun --nproc_per_node=8 generate.py  --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage\" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'zh'\n```\n\n- Using a local model for extension.\n\n  - By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.\n  - For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.\n  - For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.\n  - Larger models generally provide better extension results but require more GPU memory.\n  - You can modify the model used for extension with the parameter `--prompt_extend_model` , allowing you to specify either a local model path or a Hugging Face model. 
For example:\n\n``` sh\ntorchrun --nproc_per_node=8 generate.py  --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage\" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'zh'\n```\n\n\n#### Run Image-to-Video Generation\n\nThis repository supports the `Wan2.2-I2V-A14B` Image-to-Video model and can simultaneously support video generation at 480P and 720P resolutions.\n\n\n- Single-GPU inference\n```sh\npython generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --offload_model True --convert_model_dtype --image examples/i2v_input.JPG --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\"\n```\n\n> This command can run on a GPU with at least 80GB VRAM.\n\n> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.\n\n\n- Multi-GPU inference using FSDP + DeepSpeed Ulysses\n\n```sh\ntorchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. 
Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\"\n```\n\n- Image-to-Video Generation without prompt\n\n```sh\nDASH_API_KEY=your_key torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --prompt '' --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --use_prompt_extend --prompt_extend_method 'dashscope'\n```\n\n> 💡The model can generate videos solely from the input image. You can use prompt extension to generate a prompt from the image.\n\n> The prompt extension process is described [here](#2-using-prompt-extension).\n\n#### Run Text-Image-to-Video Generation\n\nThis repository supports the `Wan2.2-TI2V-5B` Text-Image-to-Video model, which generates video at 720P resolution.\n\n\n- Single-GPU Text-to-Video inference\n```sh\npython generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --offload_model True --convert_model_dtype --t5_cpu --prompt \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage\"\n```\n\n> 💡Unlike other tasks, the 720P resolution of the Text-Image-to-Video task is `1280*704` or `704*1280`.\n\n> This command can run on a GPU with at least 24GB VRAM (e.g., RTX 4090 GPU).\n\n> 💡If you are running on a GPU with at least 80GB VRAM, you can remove the `--offload_model True`, `--convert_model_dtype` and `--t5_cpu` options to speed up execution.\n\n\n- Single-GPU Image-to-Video inference\n```sh\npython generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --offload_model True --convert_model_dtype --t5_cpu --image examples/i2v_input.JPG --prompt \"Summer beach vacation style, a white cat wearing sunglasses 
sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\"\n```\n\n> 💡If the image parameter is configured, it is an Image-to-Video generation; otherwise, it defaults to a Text-to-Video generation.\n\n> 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.\n\n\n- Multi-GPU inference using FSDP + DeepSpeed Ulysses\n\n```sh\ntorchrun --nproc_per_node=8 generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --dit_fsdp --t5_fsdp --ulysses_size 8 --image examples/i2v_input.JPG --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\"\n```\n\n> For details on prompt extension, see [here](#2-using-prompt-extention).\n\n#### Run Speech-to-Video Generation\n\nThis repository supports the `Wan2.2-S2V-14B` Speech-to-Video model, which generates video at both 480P and 720P resolutions.\n\n- Single-GPU Speech-to-Video inference\n\n```sh\npython generate.py  --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --offload_model True --convert_model_dtype --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard.\"  --image \"examples/i2v_input.JPG\" --audio \"examples/talk.wav\"\n# Without setting --num_clip, the generated video length will automatically adjust based on the input audio length\n\n# You can use CosyVoice to generate audio with --enable_tts\npython generate.py  --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --offload_model True --convert_model_dtype --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard.\"  --image \"examples/i2v_input.JPG\" --enable_tts --tts_prompt_audio \"examples/zero_shot_prompt.wav\" --tts_prompt_text \"希望你以后能够做的比我还好呦。\" --tts_text \"收到好友从远方寄来的生日礼物，那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐，笑容如花儿般绽放。\"\n```\n\n> 💡 This command can run on a GPU with at least 80GB VRAM.\n\n- Multi-GPU inference using FSDP + DeepSpeed Ulysses\n\n```sh\ntorchrun --nproc_per_node=8 generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard.\" --image \"examples/i2v_input.JPG\" --audio \"examples/talk.wav\"\n```\n\n- Pose + Audio driven generation\n\n```sh\ntorchrun --nproc_per_node=8 generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt \"a person is singing\" --image 
\"examples/pose.png\" --audio \"examples/sing.MP3\" --pose_video \"./examples/pose.mp4\" \n```\n\n> 💡For the Speech-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.\n\n> 💡The model can generate videos from audio input combined with a reference image and an optional text prompt.\n\n> 💡The `--pose_video` parameter enables pose-driven generation, allowing the model to follow specific pose sequences while generating videos synchronized with audio input.\n\n> 💡The `--num_clip` parameter controls the number of video clips generated, which is useful for quick previews with shorter generation time.\n\nPlease visit our project page to see more examples and learn about the scenarios suitable for this model.\n\n#### Run Wan-Animate \n\nWan-Animate takes a video and a character image as input, and generates a video in either \"animation\" or \"replacement\" mode. \n\n1. animation mode: The model generates a video of the character image that mimics the human motion in the input video.\n2. replacement mode: The model replaces the character in the input video with the character image.\n\nPlease visit our [project page](https://humanaigc.github.io/wan-animate) to see more examples and learn about the scenarios suitable for this model.\n\n##### (1) Preprocessing \nThe input video should be preprocessed into several materials before being fed into the inference process.  
Please refer to the following processing flow; more details about preprocessing can be found in the [UserGuider](https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/animate/preprocess/UserGuider.md).\n\n* For animation\n```bash\npython ./wan/modules/animate/preprocess/preprocess_data.py \\\n    --ckpt_path ./Wan2.2-Animate-14B/process_checkpoint \\\n    --video_path ./examples/wan_animate/animate/video.mp4 \\\n    --refer_path ./examples/wan_animate/animate/image.jpeg \\\n    --save_path ./examples/wan_animate/animate/process_results \\\n    --resolution_area 1280 720 \\\n    --retarget_flag \\\n    --use_flux\n```\n* For replacement\n```bash\npython ./wan/modules/animate/preprocess/preprocess_data.py \\\n    --ckpt_path ./Wan2.2-Animate-14B/process_checkpoint \\\n    --video_path ./examples/wan_animate/replace/video.mp4 \\\n    --refer_path ./examples/wan_animate/replace/image.jpeg \\\n    --save_path ./examples/wan_animate/replace/process_results \\\n    --resolution_area 1280 720 \\\n    --iterations 3 \\\n    --k 7 \\\n    --w_len 1 \\\n    --h_len 1 \\\n    --replace_flag\n```\n##### (2) Run in animation mode \n\n* Single-GPU inference \n\n```bash\npython generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/animate/process_results/ --refert_num 1\n```\n\n* Multi-GPU inference using FSDP + DeepSpeed Ulysses\n\n```bash\npython -m torch.distributed.run --nnodes 1 --nproc_per_node 8 generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/animate/process_results/ --refert_num 1 --dit_fsdp --t5_fsdp --ulysses_size 8\n```\n\n* Diffusers Pipeline\n\n```python\nimport torch\nfrom diffusers import WanAnimatePipeline\nfrom diffusers.utils import export_to_video, load_image, load_video\n\ndevice = \"cuda:0\"\ndtype = torch.bfloat16\nmodel_id = \"Wan-AI/Wan2.2-Animate-14B-Diffusers\"\npipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=dtype)\npipe.to(device)\n\nseed = 
42\nprompt = \"People in the video are doing actions.\"\n\n# Animation\nimage = load_image(\"/path/to/animate/reference/image/src_ref.png\")\npose_video = load_video(\"/path/to/animate/pose/video/src_pose.mp4\")\nface_video = load_video(\"/path/to/animate/face/video/src_face.mp4\")\n\nanimate_video = pipe(\n    image=image,\n    pose_video=pose_video,\n    face_video=face_video,\n    prompt=prompt,\n    mode=\"animate\",\n    segment_frame_length=77,  # clip_len in original code\n    prev_segment_conditioning_frames=1,  # refert_num in original code\n    guidance_scale=1.0,\n    num_inference_steps=20,\n    generator=torch.Generator(device=device).manual_seed(seed),\n).frames[0]\nexport_to_video(animate_video, \"diffusers_animate.mp4\", fps=30)\n```\n\n##### (3) Run in replacement mode \n\n* Single-GPU inference \n\n```bash\npython generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/replace/process_results/ --refert_num 1 --replace_flag --use_relighting_lora \n```\n\n* Multi-GPU inference using FSDP + DeepSpeed Ulysses\n\n```bash\npython -m torch.distributed.run --nnodes 1 --nproc_per_node 8 generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/replace/process_results/src_pose.mp4  --refert_num 1 --replace_flag --use_relighting_lora --dit_fsdp --t5_fsdp --ulysses_size 8\n```\n\n* Diffusers Pipeline\n\n```python\n# create pipeline as in the Animation code ☝️\n\n# Replacement\nimage = load_image(\"/path/to/replace/reference/image/src_ref.png\")\npose_video = load_video(\"/path/to/replace/pose/video/src_pose.mp4\")\nface_video = load_video(\"/path/to/replace/face/video/src_face.mp4\")\nbackground_video = load_video(\"/path/to/replace/background/video/src_bg.mp4\")\nmask_video = load_video(\"/path/to/replace/mask/video/src_mask.mp4\")\n\nreplace_video = pipe(\n    image=image,\n    pose_video=pose_video,\n    face_video=face_video,\n    
background_video=background_video,\n    mask_video=mask_video,\n    prompt=prompt,\n    mode=\"replace\",\n    segment_frame_length=77,  # clip_len in original code\n    prev_segment_conditioning_frames=1,  # refert_num in original code\n    guidance_scale=1.0,\n    num_inference_steps=20,\n    generator=torch.Generator(device=device).manual_seed(seed),\n).frames[0]\nexport_to_video(replace_video, \"diffusers_replace.mp4\", fps=30)\n```\n\n> 💡 If you're using **Wan-Animate**, we do not recommend using LoRA models trained on `Wan2.2`, since weight changes during training may lead to unexpected behavior.\n\n## Computational Efficiency on Different GPUs\n\nWe test the computational efficiency of different **Wan2.2** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.\n\n\n<div align=\"center\">\n    <img src=\"assets/comp_effic.png\" alt=\"\" style=\"width: 80%;\" />\n</div>\n\n> The parameter settings for the tests presented in this table are as follows:\n> (1) Multi-GPU: 14B: `--ulysses_size 4/8 --dit_fsdp --t5_fsdp`, 5B: `--ulysses_size 4/8 --offload_model True --convert_model_dtype --t5_cpu`; Single-GPU: 14B: `--offload_model True --convert_model_dtype`, 5B: `--offload_model True --convert_model_dtype --t5_cpu`\n(--convert_model_dtype converts model parameter types to config.param_dtype);\n> (2) The distributed testing utilizes the built-in FSDP and Ulysses implementations, with FlashAttention3 deployed on Hopper architecture GPUs;\n> (3) Tests were run without the `--use_prompt_extend` flag;\n> (4) Reported results are the average of multiple samples taken after the warm-up phase.\n\n\n-------\n\n## Introduction of Wan2.2\n\n**Wan2.2** builds on the foundation of Wan2.1 with notable improvements in generation quality and model capability. 
This upgrade is driven by a series of key technical innovations, mainly including the Mixture-of-Experts (MoE) architecture, upgraded training data, and high-compression video generation.\n\n##### (1) Mixture-of-Experts (MoE) Architecture\n\nWan2.2 introduces Mixture-of-Experts (MoE) architecture into the video generation diffusion model. MoE has been widely validated in large language models as an efficient approach to increase total model parameters while keeping inference cost nearly unchanged. In Wan2.2, the A14B model series adopts a two-expert design tailored to the denoising process of diffusion models: a high-noise expert for the early stages, focusing on overall layout; and a low-noise expert for the later stages, refining video details. Each expert model has about 14B parameters, resulting in a total of 27B parameters but only 14B active parameters per step, keeping inference computation and GPU memory nearly unchanged.\n\n<div align=\"center\">\n    <img src=\"assets/moe_arch.png\" alt=\"\" style=\"width: 90%;\" />\n</div>\n\nThe transition point between the two experts is determined by the signal-to-noise ratio (SNR), a metric that decreases monotonically as the denoising step $t$ increases. At the beginning of the denoising process, $t$ is large and the noise level is high, so the SNR is at its minimum, denoted as ${SNR}_{min}$. In this stage, the high-noise expert is activated. We define a threshold step ${t}_{moe}$ corresponding to half of the ${SNR}_{min}$, and switch to the low-noise expert when $t<{t}_{moe}$.\n\n<div align=\"center\">\n    <img src=\"assets/moe_2.png\" alt=\"\" style=\"width: 90%;\" />\n</div>\n\nTo validate the effectiveness of the MoE architecture, four settings are compared based on their validation loss curves. The baseline **Wan2.1** model does not employ the MoE architecture. 
Among the MoE-based variants, the **Wan2.1 & High-Noise Expert** variant reuses the Wan2.1 model as the low-noise expert and uses Wan2.2's high-noise expert, while the **Wan2.1 & Low-Noise Expert** variant uses Wan2.1 as the high-noise expert and employs Wan2.2's low-noise expert. **Wan2.2 (MoE)** (our final version) achieves the lowest validation loss, indicating that its generated video distribution is closest to the ground truth and that it exhibits superior convergence.\n\n\n##### (2) Efficient High-Definition Hybrid TI2V\nTo enable more efficient deployment, Wan2.2 also explores a high-compression design. In addition to the 27B MoE models, a 5B dense model, i.e., TI2V-5B, is released. It is supported by a high-compression Wan2.2-VAE, which achieves a $T\\times H\\times W$ compression ratio of $4\\times16\\times16$, increasing the overall compression rate to 64 while maintaining high-quality video reconstruction. With an additional patchification layer, the total compression ratio of TI2V-5B reaches $4\\times32\\times32$. Without specific optimization, TI2V-5B can generate a 5-second 720P video in under 9 minutes on a single consumer-grade GPU, ranking among the fastest 720P@24fps video generation models. This model also natively supports both text-to-video and image-to-video tasks within a single unified framework, covering both academic research and practical applications.\n\n\n<div align=\"center\">\n    <img src=\"assets/vae.png\" alt=\"\" style=\"width: 80%;\" />\n</div>\n\n\n\n##### Comparisons to SOTAs\nWe compared Wan2.2 with leading closed-source commercial models on our new Wan-Bench 2.0, evaluating performance across multiple crucial dimensions. 
The results demonstrate that Wan2.2 achieves superior performance compared to these leading models.\n\n\n<div align=\"center\">\n    <img src=\"assets/performance.png\" alt=\"\" style=\"width: 90%;\" />\n</div>\n\n## Citation\nIf you find our work helpful, please cite us.\n\n```\n@article{wan2025,\n      title={Wan: Open and Advanced Large-Scale Video Generative Models}, \n      author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},\n      journal = {arXiv preprint arXiv:2503.20314},\n      year={2025}\n}\n```\n\n## License Agreement\nThe models in this repository are licensed under the Apache 2.0 License. We claim no rights over your generated content, granting you the freedom to use it while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. 
For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).\n\n\n## Acknowledgements\n\nWe would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.\n\n\n\n## Contact Us\nIf you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!\n\n"
  },
  {
    "path": "generate.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport logging\nimport os\nimport sys\nimport warnings\nfrom datetime import datetime\n\nwarnings.filterwarnings('ignore')\n\nimport random\n\nimport torch\nimport torch.distributed as dist\nfrom PIL import Image\n\nimport wan\nfrom wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS\nfrom wan.distributed.util import init_distributed_group\nfrom wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander\nfrom wan.utils.utils import merge_video_audio, save_video, str2bool\n\n\nEXAMPLE_PROMPT = {\n    \"t2v-A14B\": {\n        \"prompt\":\n            \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\",\n    },\n    \"i2v-A14B\": {\n        \"prompt\":\n            \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\",\n        \"image\":\n            \"examples/i2v_input.JPG\",\n    },\n    \"ti2v-5B\": {\n        \"prompt\":\n            \"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.\",\n    },\n    \"animate-14B\": {\n        \"prompt\": \"视频中的人在做动作\",\n        \"video\": \"\",\n        \"pose\": \"\",\n        \"mask\": \"\",\n    },\n    \"s2v-14B\": {\n        \"prompt\":\n            \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. 
Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\",\n        \"image\":\n            \"examples/i2v_input.JPG\",\n        \"audio\":\n            \"examples/talk.wav\",\n        \"tts_prompt_audio\":\n            \"examples/zero_shot_prompt.wav\",\n        \"tts_prompt_text\":\n            \"希望你以后能够做的比我还好呦。\",\n        \"tts_text\":\n            \"收到好友从远方寄来的生日礼物，那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐，笑容如花儿般绽放。\"\n    },\n}\n\n\ndef _validate_args(args):\n    # Basic check\n    assert args.ckpt_dir is not None, \"Please specify the checkpoint directory.\"\n    assert args.task in WAN_CONFIGS, f\"Unsupported task: {args.task}\"\n    assert args.task in EXAMPLE_PROMPT, f\"Unsupported task: {args.task}\"\n\n    if args.prompt is None:\n        args.prompt = EXAMPLE_PROMPT[args.task][\"prompt\"]\n    if args.image is None and \"image\" in EXAMPLE_PROMPT[args.task]:\n        args.image = EXAMPLE_PROMPT[args.task][\"image\"]\n    if args.audio is None and args.enable_tts is False and \"audio\" in EXAMPLE_PROMPT[args.task]:\n        args.audio = EXAMPLE_PROMPT[args.task][\"audio\"]\n    if (args.tts_prompt_audio is None or args.tts_text is None) and args.enable_tts is True and \"audio\" in EXAMPLE_PROMPT[args.task]:\n        args.tts_prompt_audio = EXAMPLE_PROMPT[args.task][\"tts_prompt_audio\"]\n        args.tts_prompt_text = EXAMPLE_PROMPT[args.task][\"tts_prompt_text\"]\n        args.tts_text = EXAMPLE_PROMPT[args.task][\"tts_text\"]\n\n    if args.task == \"i2v-A14B\":\n        assert args.image is not None, \"Please specify the image path for i2v.\"\n\n    cfg = WAN_CONFIGS[args.task]\n\n    if args.sample_steps is None:\n        args.sample_steps = cfg.sample_steps\n\n    if args.sample_shift is None:\n 
       args.sample_shift = cfg.sample_shift\n\n    if args.sample_guide_scale is None:\n        args.sample_guide_scale = cfg.sample_guide_scale\n\n    if args.frame_num is None:\n        args.frame_num = cfg.frame_num\n\n    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(\n        0, sys.maxsize)\n    # Size check (skipped for s2v tasks)\n    if 's2v' not in args.task:\n        assert args.size in SUPPORTED_SIZES[args.task], (\n            f\"Unsupported size {args.size} for task {args.task}, \"\n            f\"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}\")\n\n\ndef _parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Generate an image or video from a text prompt or image using Wan\"\n    )\n    parser.add_argument(\n        \"--task\",\n        type=str,\n        default=\"t2v-A14B\",\n        choices=list(WAN_CONFIGS.keys()),\n        help=\"The task to run.\")\n    parser.add_argument(\n        \"--size\",\n        type=str,\n        default=\"1280*720\",\n        choices=list(SIZE_CONFIGS.keys()),\n        help=\"The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.\"\n    )\n    parser.add_argument(\n        \"--frame_num\",\n        type=int,\n        default=None,\n        help=\"How many frames of video are generated. 
The number should be 4n+1\"\n    )\n    parser.add_argument(\n        \"--ckpt_dir\",\n        type=str,\n        default=None,\n        help=\"The path to the checkpoint directory.\")\n    parser.add_argument(\n        \"--offload_model\",\n        type=str2bool,\n        default=None,\n        help=\"Whether to offload the model to CPU after each model forward, reducing GPU memory usage.\"\n    )\n    parser.add_argument(\n        \"--ulysses_size\",\n        type=int,\n        default=1,\n        help=\"The size of the ulysses parallelism in DiT.\")\n    parser.add_argument(\n        \"--t5_fsdp\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use FSDP for T5.\")\n    parser.add_argument(\n        \"--t5_cpu\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to place T5 model on CPU.\")\n    parser.add_argument(\n        \"--dit_fsdp\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use FSDP for DiT.\")\n    parser.add_argument(\n        \"--save_file\",\n        type=str,\n        default=None,\n        help=\"The file to save the generated video to.\")\n    parser.add_argument(\n        \"--prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to generate the video from.\")\n    parser.add_argument(\n        \"--use_prompt_extend\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use prompt extend.\")\n    parser.add_argument(\n        \"--prompt_extend_method\",\n        type=str,\n        default=\"local_qwen\",\n        choices=[\"dashscope\", \"local_qwen\"],\n        help=\"The prompt extend method to use.\")\n    parser.add_argument(\n        \"--prompt_extend_model\",\n        type=str,\n        default=None,\n        help=\"The prompt extend model to use.\")\n    parser.add_argument(\n        \"--prompt_extend_target_lang\",\n        type=str,\n        default=\"zh\",\n        
choices=[\"zh\", \"en\"],\n        help=\"The target language of prompt extend.\")\n    parser.add_argument(\n        \"--base_seed\",\n        type=int,\n        default=-1,\n        help=\"The seed to use for generating the video.\")\n    parser.add_argument(\n        \"--image\",\n        type=str,\n        default=None,\n        help=\"The image to generate the video from.\")\n    parser.add_argument(\n        \"--sample_solver\",\n        type=str,\n        default='unipc',\n        choices=['unipc', 'dpm++'],\n        help=\"The solver used to sample.\")\n    parser.add_argument(\n        \"--sample_steps\", type=int, default=None, help=\"The sampling steps.\")\n    parser.add_argument(\n        \"--sample_shift\",\n        type=float,\n        default=None,\n        help=\"Sampling shift factor for flow matching schedulers.\")\n    parser.add_argument(\n        \"--sample_guide_scale\",\n        type=float,\n        default=None,\n        help=\"Classifier free guidance scale.\")\n    parser.add_argument(\n        \"--convert_model_dtype\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to convert the dtype of the model parameters.\")\n\n    # animate\n    parser.add_argument(\n        \"--src_root_path\",\n        type=str,\n        default=None,\n        help=\"The path to the preprocessing output. Default: None.\")\n    parser.add_argument(\n        \"--refert_num\",\n        type=int,\n        default=77,\n        help=\"How many frames are used for temporal guidance. 
Recommended to be 1 or 5.\"\n    )\n    parser.add_argument(\n        \"--replace_flag\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use replacement mode.\")\n    parser.add_argument(\n        \"--use_relighting_lora\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use relighting lora.\")\n\n    # The following args only work for s2v\n    parser.add_argument(\n        \"--num_clip\",\n        type=int,\n        default=None,\n        help=\"Number of video clips to generate, the whole video will not exceed the length of audio.\"\n    )\n    parser.add_argument(\n        \"--audio\",\n        type=str,\n        default=None,\n        help=\"Path to the audio file, e.g. wav, mp3\")\n    parser.add_argument(\n        \"--enable_tts\",\n        action=\"store_true\",\n        default=False,\n        help=\"Use CosyVoice to synthesize audio\")\n    parser.add_argument(\n        \"--tts_prompt_audio\",\n        type=str,\n        default=None,\n        help=\"Path to the tts prompt audio file, e.g. wav, mp3. Must be greater than 16kHz, and between 5s and 15s.\")\n    parser.add_argument(\n        \"--tts_prompt_text\",\n        type=str,\n        default=None,\n        help=\"Content of the tts prompt audio. 
If provided, must exactly match tts_prompt_audio\")\n    parser.add_argument(\n        \"--tts_text\",\n        type=str,\n        default=None,\n        help=\"The text to synthesize\")\n    parser.add_argument(\n        \"--pose_video\",\n        type=str,\n        default=None,\n        help=\"Path to a DW-Pose sequence video for pose-driven generation\")\n    parser.add_argument(\n        \"--start_from_ref\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to set the reference image as the starting point for generation\"\n    )\n    parser.add_argument(\n        \"--infer_frames\",\n        type=int,\n        default=80,\n        help=\"Number of frames per clip, 48 or 80 or others (must be multiple of 4) for 14B s2v\"\n    )\n    args = parser.parse_args()\n    _validate_args(args)\n\n    return args\n\n\ndef _init_logging(rank):\n    # Rank 0 logs at INFO level to stdout; other ranks only report errors.\n    if rank == 0:\n        # set format\n        logging.basicConfig(\n            level=logging.INFO,\n            format=\"[%(asctime)s] %(levelname)s: %(message)s\",\n            handlers=[logging.StreamHandler(stream=sys.stdout)])\n    else:\n        logging.basicConfig(level=logging.ERROR)\n\n\ndef generate(args):\n    rank = int(os.getenv(\"RANK\", 0))\n    world_size = int(os.getenv(\"WORLD_SIZE\", 1))\n    local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n    device = local_rank\n    _init_logging(rank)\n\n    if args.offload_model is None:\n        # Offload by default only when running on a single process.\n        args.offload_model = world_size <= 1\n        logging.info(\n            f\"offload_model is not specified, set to {args.offload_model}.\")\n    if world_size > 1:\n        torch.cuda.set_device(local_rank)\n        dist.init_process_group(\n            backend=\"nccl\",\n            init_method=\"env://\",\n            rank=rank,\n            world_size=world_size)\n    else:\n        assert not (\n            args.t5_fsdp or args.dit_fsdp\n        ), \"t5_fsdp and dit_fsdp are not supported in non-distributed environments.\"\n        
assert not (\n            args.ulysses_size > 1\n        ), \"Sequence parallelism is not supported in non-distributed environments.\"\n\n    if args.ulysses_size > 1:\n        assert args.ulysses_size == world_size, \"ulysses_size should be equal to the world size.\"\n        init_distributed_group()\n\n    if args.use_prompt_extend:\n        if args.prompt_extend_method == \"dashscope\":\n            prompt_expander = DashScopePromptExpander(\n                model_name=args.prompt_extend_model,\n                task=args.task,\n                is_vl=args.image is not None)\n        elif args.prompt_extend_method == \"local_qwen\":\n            prompt_expander = QwenPromptExpander(\n                model_name=args.prompt_extend_model,\n                task=args.task,\n                is_vl=args.image is not None,\n                device=rank)\n        else:\n            raise NotImplementedError(\n                f\"Unsupported prompt_extend_method: {args.prompt_extend_method}\")\n\n    cfg = WAN_CONFIGS[args.task]\n    if args.ulysses_size > 1:\n        assert cfg.num_heads % args.ulysses_size == 0, f\"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`.\"\n\n    logging.info(f\"Generation job args: {args}\")\n    logging.info(f\"Generation model config: {cfg}\")\n\n    if dist.is_initialized():\n        # Share rank 0's seed so every rank samples identically.\n        base_seed = [args.base_seed] if rank == 0 else [None]\n        dist.broadcast_object_list(base_seed, src=0)\n        args.base_seed = base_seed[0]\n\n    logging.info(f\"Input prompt: {args.prompt}\")\n    img = None\n    if args.image is not None:\n        img = Image.open(args.image).convert(\"RGB\")\n        logging.info(f\"Input image: {args.image}\")\n\n    # prompt extend\n    if args.use_prompt_extend:\n        logging.info(\"Extending prompt ...\")\n        if rank == 0:\n            prompt_output = prompt_expander(\n                args.prompt,\n                image=img,\n                
tar_lang=args.prompt_extend_target_lang,\n                seed=args.base_seed)\n            if not prompt_output.status:\n                logging.info(\n                    f\"Extending prompt failed: {prompt_output.message}\")\n                logging.info(\"Falling back to original prompt.\")\n                input_prompt = args.prompt\n            else:\n                input_prompt = prompt_output.prompt\n            input_prompt = [input_prompt]\n        else:\n            input_prompt = [None]\n        if dist.is_initialized():\n            dist.broadcast_object_list(input_prompt, src=0)\n        args.prompt = input_prompt[0]\n        logging.info(f\"Extended prompt: {args.prompt}\")\n\n    if \"t2v\" in args.task:\n        logging.info(\"Creating WanT2V pipeline.\")\n        wan_t2v = wan.WanT2V(\n            config=cfg,\n            checkpoint_dir=args.ckpt_dir,\n            device_id=device,\n            rank=rank,\n            t5_fsdp=args.t5_fsdp,\n            dit_fsdp=args.dit_fsdp,\n            use_sp=(args.ulysses_size > 1),\n            t5_cpu=args.t5_cpu,\n            convert_model_dtype=args.convert_model_dtype,\n        )\n\n        logging.info(\"Generating video ...\")\n        video = wan_t2v.generate(\n            args.prompt,\n            size=SIZE_CONFIGS[args.size],\n            frame_num=args.frame_num,\n            shift=args.sample_shift,\n            sample_solver=args.sample_solver,\n            sampling_steps=args.sample_steps,\n            guide_scale=args.sample_guide_scale,\n            seed=args.base_seed,\n            offload_model=args.offload_model)\n    elif \"ti2v\" in args.task:\n        logging.info(\"Creating WanTI2V pipeline.\")\n        wan_ti2v = wan.WanTI2V(\n            config=cfg,\n            checkpoint_dir=args.ckpt_dir,\n            device_id=device,\n            rank=rank,\n            t5_fsdp=args.t5_fsdp,\n            dit_fsdp=args.dit_fsdp,\n            use_sp=(args.ulysses_size > 1),\n            
t5_cpu=args.t5_cpu,\n            convert_model_dtype=args.convert_model_dtype,\n        )\n\n        logging.info(f\"Generating video ...\")\n        video = wan_ti2v.generate(\n            args.prompt,\n            img=img,\n            size=SIZE_CONFIGS[args.size],\n            max_area=MAX_AREA_CONFIGS[args.size],\n            frame_num=args.frame_num,\n            shift=args.sample_shift,\n            sample_solver=args.sample_solver,\n            sampling_steps=args.sample_steps,\n            guide_scale=args.sample_guide_scale,\n            seed=args.base_seed,\n            offload_model=args.offload_model)\n    elif \"animate\" in args.task:\n        logging.info(\"Creating Wan-Animate pipeline.\")\n        wan_animate = wan.WanAnimate(\n            config=cfg,\n            checkpoint_dir=args.ckpt_dir,\n            device_id=device,\n            rank=rank,\n            t5_fsdp=args.t5_fsdp,\n            dit_fsdp=args.dit_fsdp,\n            use_sp=(args.ulysses_size > 1),\n            t5_cpu=args.t5_cpu,\n            convert_model_dtype=args.convert_model_dtype,\n            use_relighting_lora=args.use_relighting_lora\n        )\n\n        logging.info(f\"Generating video ...\")\n        video = wan_animate.generate(\n            src_root_path=args.src_root_path,\n            replace_flag=args.replace_flag,\n            refert_num = args.refert_num,\n            clip_len=args.frame_num,\n            shift=args.sample_shift,\n            sample_solver=args.sample_solver,\n            sampling_steps=args.sample_steps,\n            guide_scale=args.sample_guide_scale,\n            seed=args.base_seed,\n            offload_model=args.offload_model)\n    elif \"s2v\" in args.task:\n        logging.info(\"Creating WanS2V pipeline.\")\n        wan_s2v = wan.WanS2V(\n            config=cfg,\n            checkpoint_dir=args.ckpt_dir,\n            device_id=device,\n            rank=rank,\n            t5_fsdp=args.t5_fsdp,\n            dit_fsdp=args.dit_fsdp,\n       
     use_sp=(args.ulysses_size > 1),\n            t5_cpu=args.t5_cpu,\n            convert_model_dtype=args.convert_model_dtype,\n        )\n        logging.info(f\"Generating video ...\")\n        video = wan_s2v.generate(\n            input_prompt=args.prompt,\n            ref_image_path=args.image,\n            audio_path=args.audio,\n            enable_tts=args.enable_tts,\n            tts_prompt_audio=args.tts_prompt_audio,\n            tts_prompt_text=args.tts_prompt_text,\n            tts_text=args.tts_text,\n            num_repeat=args.num_clip,\n            pose_video=args.pose_video,\n            max_area=MAX_AREA_CONFIGS[args.size],\n            infer_frames=args.infer_frames,\n            shift=args.sample_shift,\n            sample_solver=args.sample_solver,\n            sampling_steps=args.sample_steps,\n            guide_scale=args.sample_guide_scale,\n            seed=args.base_seed,\n            offload_model=args.offload_model,\n            init_first_frame=args.start_from_ref,\n        )\n    else:\n        logging.info(\"Creating WanI2V pipeline.\")\n        wan_i2v = wan.WanI2V(\n            config=cfg,\n            checkpoint_dir=args.ckpt_dir,\n            device_id=device,\n            rank=rank,\n            t5_fsdp=args.t5_fsdp,\n            dit_fsdp=args.dit_fsdp,\n            use_sp=(args.ulysses_size > 1),\n            t5_cpu=args.t5_cpu,\n            convert_model_dtype=args.convert_model_dtype,\n        )\n        logging.info(\"Generating video ...\")\n        video = wan_i2v.generate(\n            args.prompt,\n            img,\n            max_area=MAX_AREA_CONFIGS[args.size],\n            frame_num=args.frame_num,\n            shift=args.sample_shift,\n            sample_solver=args.sample_solver,\n            sampling_steps=args.sample_steps,\n            guide_scale=args.sample_guide_scale,\n            seed=args.base_seed,\n            offload_model=args.offload_model)\n\n    if rank == 0:\n        if args.save_file is None:\n  
          formatted_time = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n            formatted_prompt = args.prompt.replace(\" \", \"_\").replace(\"/\",\n                                                                     \"_\")[:50]\n            suffix = '.mp4'\n            args.save_file = f\"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}\" + suffix\n\n        logging.info(f\"Saving generated video to {args.save_file}\")\n        save_video(\n            tensor=video[None],\n            save_file=args.save_file,\n            fps=cfg.sample_fps,\n            nrow=1,\n            normalize=True,\n            value_range=(-1, 1))\n        if \"s2v\" in args.task:\n            if args.enable_tts is False:\n                merge_video_audio(video_path=args.save_file, audio_path=args.audio)\n            else:\n                merge_video_audio(video_path=args.save_file, audio_path=\"tts.wav\")\n    del video\n\n    torch.cuda.synchronize()\n    if dist.is_initialized():\n        dist.barrier()\n        dist.destroy_process_group()\n\n    logging.info(\"Finished.\")\n\n\nif __name__ == \"__main__\":\n    args = _parse_args()\n    generate(args)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"wan\"\nversion = \"2.2.0\"\ndescription = \"Wan: Open and Advanced Large-Scale Video Generative Models\"\nauthors = [\n    { name = \"Wan Team\", email = \"wan.ai@alibabacloud.com\" }\n]\nlicense = { file = \"LICENSE.txt\" }\nreadme = \"README.md\"\nrequires-python = \">=3.10,<4.0\"\ndependencies = [\n    \"torch>=2.4.0\",\n    \"torchvision>=0.19.0\",\n    \"opencv-python>=4.9.0.80\",\n    \"diffusers>=0.31.0\",\n    \"transformers>=4.49.0\",\n    \"tokenizers>=0.20.3\",\n    \"accelerate>=1.1.1\",\n    \"tqdm\",\n    \"imageio\",\n    \"easydict\",\n    \"ftfy\",\n    \"dashscope\",\n    \"imageio-ffmpeg\",\n    \"flash_attn\",\n    \"numpy>=1.23.5,<2\"\n]\n\n[project.optional-dependencies]\ndev = [\n    \"pytest\",\n    \"black\",\n    \"flake8\",\n    \"isort\",\n    \"mypy\",\n    \"huggingface-hub[cli]\"\n]\n\n[project.urls]\nhomepage = \"https://wanxai.com\"\ndocumentation = \"https://github.com/Wan-Video/Wan2.2\"\nrepository = \"https://github.com/Wan-Video/Wan2.2\"\nhuggingface = \"https://huggingface.co/Wan-AI/\"\nmodelscope = \"https://modelscope.cn/organization/Wan-AI\"\ndiscord = \"https://discord.gg/p5XbdQV7\"\n\n[tool.setuptools]\npackages = [\"wan\"]\n\n[tool.setuptools.package-data]\n\"wan\" = [\"**/*.py\"]\n\n[tool.black]\nline-length = 88\n\n[tool.isort]\nprofile = \"black\"\n\n[tool.mypy]\nstrict = true\n\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch>=2.4.0\ntorchvision>=0.19.0\ntorchaudio\nopencv-python>=4.9.0.80\ndiffusers>=0.31.0\ntransformers>=4.49.0,<=4.51.3\ntokenizers>=0.20.3\naccelerate>=1.1.1\ntqdm\nimageio[ffmpeg]\neasydict\nftfy\ndashscope\nimageio-ffmpeg\nflash_attn\nnumpy>=1.23.5,<2"
  },
  {
    "path": "requirements_animate.txt",
    "content": "decord\npeft\nonnxruntime\npandas\nmatplotlib\n-e  git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2 \nloguru\nsentencepiece"
  },
  {
    "path": "requirements_s2v.txt",
    "content": "openai-whisper\nHyperPyYAML\nonnxruntime\ninflect\nwetext\nomegaconf\nconformer\nhydra-core\nlightning\nrich\ngdown\nmatplotlib\nwget\npyarrow\npyworld\nlibrosa\ndecord\nmodelscope\nGitPython"
  },
  {
    "path": "tests/README.md",
    "content": "\nPut all your models (Wan2.2-T2V-A14B, Wan2.2-I2V-A14B, Wan2.2-TI2V-5B) in a folder and specify the max GPU number you want to use.\n\n```bash\nbash ./tests/test.sh <local model dir> <gpu number>\n```\n"
  },
  {
    "path": "tests/test.sh",
    "content": "#!/bin/bash\nset -x\n\nunset NCCL_DEBUG\n\nif [ \"$#\" -eq 2 ]; then\n  MODEL_DIR=$(realpath \"$1\")\n  GPUS=$2\nelse\n  echo \"Usage: $0 <local model dir> <gpu number>\"\n  exit 1\nfi\n\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nREPO_ROOT=\"$(dirname \"$SCRIPT_DIR\")\"\ncd \"$REPO_ROOT\" || exit 1\n\nPY_FILE=./generate.py\n\n\nfunction t2v_A14B() {\n    CKPT_DIR=\"$MODEL_DIR/Wan2.2-T2V-A14B\"\n\n    # # 1-GPU Test\n    # echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B 1-GPU Test: \"\n    # python $PY_FILE --task t2v-A14B --size 480*832 --ckpt_dir $CKPT_DIR\n\n    # Multiple GPU Test\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU Test: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 1280*720 --dit_fsdp --t5_fsdp --ulysses_size $GPUS\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_A14B Multiple GPU, prompt extend local_qwen: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-A14B --ckpt_dir $CKPT_DIR --size 480*832 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model \"Qwen/Qwen2.5-3B-Instruct\" --prompt_extend_target_lang \"en\"\n}\n\n\nfunction i2v_A14B() {\n    CKPT_DIR=\"$MODEL_DIR/Wan2.2-I2V-A14B\"\n\n    # echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: \"\n    # python $PY_FILE --task i2v-A14B --size 832*480 --ckpt_dir $CKPT_DIR\n\n    # Multiple GPU Test\n    echo -e 
\"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model \"Qwen/Qwen2.5-VL-3B-Instruct\" --prompt_extend_target_lang \"en\"\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 1280*720 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model \"Qwen/Qwen2.5-VL-3B-Instruct\" --prompt_extend_target_lang \"en\"\n\n    if [ -n \"${DASH_API_KEY+x}\" ]; then\n        echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: \"\n        torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-A14B --ckpt_dir $CKPT_DIR --size 480*832 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method \"dashscope\"\n    else\n        echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test.\"\n    fi\n}\n\nfunction ti2v_5B() {\n    CKPT_DIR=\"$MODEL_DIR/Wan2.2-TI2V-5B\"\n\n    # # 1-GPU Test\n    # echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v 1-GPU Test: \"\n    # python $PY_FILE --task ti2v-5B --size 1280*704 --ckpt_dir $CKPT_DIR\n\n    # Multiple GPU Test\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B t2v Multiple GPU Test: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 1280*704 --dit_fsdp --t5_fsdp --ulysses_size $GPUS\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 
ti2v_5B t2v Multiple GPU, prompt extend local_qwen: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 704*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model \"Qwen/Qwen2.5-3B-Instruct\" --prompt_extend_target_lang \"en\"\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B i2v Multiple GPU Test: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 704*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\" --image \"examples/i2v_input.JPG\"\n\n    echo -e \"\\n\\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> ti2v_5B i2v Multiple GPU, prompt extend local_qwen: \"\n    torchrun --nproc_per_node=$GPUS $PY_FILE --task ti2v-5B --ckpt_dir $CKPT_DIR --size 1280*704 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model \"Qwen/Qwen2.5-3B-Instruct\" --prompt_extend_target_lang 'en' --prompt \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\" --image \"examples/i2v_input.JPG\"\n\n}\n\nt2v_A14B\ni2v_A14B\nti2v_5B\n"
  },
  {
    "path": "wan/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom . import configs, distributed, modules\nfrom .image2video import WanI2V\nfrom .speech2video import WanS2V\nfrom .text2video import WanT2V\nfrom .textimage2video import WanTI2V\nfrom .animate import WanAnimate"
  },
  {
    "path": "wan/animate.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\nimport math\nimport os\nimport cv2\nimport types\nfrom copy import deepcopy\nfrom functools import partial\nfrom einops import rearrange\nimport numpy as np\nimport torch\n\nimport torch.distributed as dist\nfrom peft import set_peft_model_state_dict\nfrom decord import VideoReader\nfrom tqdm import tqdm\nimport torch.nn.functional as F\nfrom .distributed.fsdp import shard_model\nfrom .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward\nfrom .distributed.util import get_world_size\n\nfrom .modules.animate import WanAnimateModel\nfrom .modules.animate import CLIPModel\nfrom .modules.t5 import T5EncoderModel\nfrom .modules.vae2_1 import Wan2_1_VAE\nfrom .modules.animate.animate_utils import TensorList, get_loraconfig\nfrom .utils.fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler\n\n\n\nclass WanAnimate:\n\n    def __init__(\n        self,\n        config,\n        checkpoint_dir,\n        device_id=0,\n        rank=0,\n        t5_fsdp=False,\n        dit_fsdp=False,\n        use_sp=False,\n        t5_cpu=False,\n        init_on_cpu=True,\n        convert_model_dtype=False,\n        use_relighting_lora=False\n    ):\n        r\"\"\"\n        Initializes the generation model components.\n\n        Args:\n            config (EasyDict):\n                Object containing model parameters initialized from config.py\n            checkpoint_dir (`str`):\n                Path to directory containing model checkpoints\n            device_id (`int`,  *optional*, defaults to 0):\n                Id of target GPU device\n            rank (`int`,  *optional*, defaults to 0):\n                Process rank for distributed training\n            t5_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for T5 
model\n            dit_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for DiT model\n            use_sp (`bool`, *optional*, defaults to False):\n                Enable the sequence-parallel distribution strategy.\n            t5_cpu (`bool`, *optional*, defaults to False):\n                Whether to place T5 model on CPU. Only works without t5_fsdp.\n            init_on_cpu (`bool`, *optional*, defaults to True):\n                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.\n            convert_model_dtype (`bool`, *optional*, defaults to False):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n            use_relighting_lora (`bool`, *optional*, defaults to False):\n                Whether to use the relighting LoRA for character replacement.\n        \"\"\"\n        self.device = torch.device(f\"cuda:{device_id}\")\n        self.config = config\n        self.rank = rank\n        self.t5_cpu = t5_cpu\n        self.init_on_cpu = init_on_cpu\n\n        self.num_train_timesteps = config.num_train_timesteps\n        self.param_dtype = config.param_dtype\n\n        if t5_fsdp or dit_fsdp or use_sp:\n            self.init_on_cpu = False\n\n        shard_fn = partial(shard_model, device_id=device_id)\n        self.text_encoder = T5EncoderModel(\n            text_len=config.text_len,\n            dtype=config.t5_dtype,\n            device=torch.device('cpu'),\n            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),\n            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),\n            shard_fn=shard_fn if t5_fsdp else None,\n        )\n\n        self.clip = CLIPModel(\n            dtype=torch.float16,\n            device=self.device,\n            checkpoint_path=os.path.join(checkpoint_dir,\n                                         config.clip_checkpoint),\n            
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))\n\n        self.vae = Wan2_1_VAE(\n            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),\n            device=self.device)\n\n        logging.info(f\"Creating WanAnimate from {checkpoint_dir}\")\n\n        if not dit_fsdp:\n            self.noise_model = WanAnimateModel.from_pretrained(\n                checkpoint_dir,\n                torch_dtype=self.param_dtype,\n                device_map=self.device)\n        else:\n            self.noise_model = WanAnimateModel.from_pretrained(\n                checkpoint_dir, torch_dtype=self.param_dtype)\n\n        self.noise_model = self._configure_model(\n            model=self.noise_model,\n            use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype,\n            use_lora=use_relighting_lora,\n            checkpoint_dir=checkpoint_dir,\n            config=config\n            )\n\n        if use_sp:\n            self.sp_size = get_world_size()\n        else:\n            self.sp_size = 1\n\n        self.sample_neg_prompt = config.sample_neg_prompt\n        self.sample_prompt = config.prompt\n\n\n    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,\n                         convert_model_dtype, use_lora, checkpoint_dir, config):\n        \"\"\"\n        Configures a model object. 
This includes setting evaluation mode,\n        applying the distributed parallel strategy, optionally loading the\n        relighting LoRA, and handling device placement.\n\n        Args:\n            model (torch.nn.Module):\n                The model instance to configure.\n            use_sp (`bool`):\n                Enable the sequence-parallel distribution strategy.\n            dit_fsdp (`bool`):\n                Enable FSDP sharding for DiT model.\n            shard_fn (callable):\n                The function to apply FSDP sharding.\n            convert_model_dtype (`bool`):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n            use_lora (`bool`):\n                Load the relighting LoRA adapter into the model.\n            checkpoint_dir (`str`):\n                Path to directory containing model checkpoints, including the LoRA weights.\n            config (EasyDict):\n                Model config; `config.lora_checkpoint` names the LoRA weights file.\n\n        Returns:\n            torch.nn.Module:\n                The configured model.\n        \"\"\"\n        model.eval().requires_grad_(False)\n\n        if use_sp:\n            for block in model.blocks:\n                block.self_attn.forward = types.MethodType(\n                    sp_attn_forward, block.self_attn)\n\n            model.use_context_parallel = True\n\n        if dist.is_initialized():\n            dist.barrier()\n\n        if use_lora:\n            logging.info(\"Loading Relighting Lora. 
\")\n            lora_config = get_loraconfig(\n                transformer=model,\n                rank=128,\n                alpha=128\n            )\n            model.add_adapter(lora_config)\n            lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint)\n            peft_state_dict = torch.load(lora_path)[\"state_dict\"]\n            set_peft_model_state_dict(model, peft_state_dict)\n\n        if dit_fsdp:\n            model = shard_fn(model, use_lora=use_lora)\n        else:\n            if convert_model_dtype:\n                model.to(self.param_dtype)\n            if not self.init_on_cpu:\n                model.to(self.device)\n\n        return model\n\n    def inputs_padding(self, array, target_len):\n        idx = 0\n        flip = False\n        target_array = []\n        while len(target_array) < target_len:\n            target_array.append(deepcopy(array[idx]))\n            if flip:\n                idx -= 1\n            else:\n                idx += 1\n            if idx == 0 or idx == len(array) - 1:\n                flip = not flip\n        return target_array[:target_len]\n\n    def get_valid_len(self, real_len, clip_len=81, overlap=1):\n        real_clip_len = clip_len - overlap\n        last_clip_num = (real_len - overlap) % real_clip_len\n        if last_clip_num == 0:\n            extra = 0\n        else:\n            extra = real_clip_len - last_clip_num\n        target_len = real_len + extra\n        return target_len\n\n\n    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=\"cuda\"):\n        if mask_pixel_values is None:\n            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)\n        else:\n            msk = mask_pixel_values.clone()\n        msk[:, :mask_len] = 1\n        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)\n        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)\n        msk = msk.transpose(1, 
2)[0]\n        return msk\n\n    def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):\n        ori_height = img_ori.shape[0]\n        ori_width = img_ori.shape[1]\n        channel = img_ori.shape[2]\n\n        img_pad = np.zeros((height, width, channel))\n        if channel == 1:\n            img_pad[:, :, 0] = padding_color[0]\n        else:\n            img_pad[:, :, 0] = padding_color[0]\n            img_pad[:, :, 1] = padding_color[1]\n            img_pad[:, :, 2] = padding_color[2]\n\n        if (ori_height / ori_width) > (height / width):\n            new_width = int(height / ori_height * ori_width)\n            img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)\n            padding = int((width - new_width) / 2)\n            if len(img.shape) == 2:\n                img = img[:, :, np.newaxis]  \n            img_pad[:, padding: padding + new_width, :] = img\n        else:\n            new_height = int(width / ori_width * ori_height)\n            img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)\n            padding = int((height - new_height) / 2)\n            if len(img.shape) == 2:\n                img = img[:, :, np.newaxis]  \n            img_pad[padding: padding + new_height, :, :] = img\n\n        img_pad = np.uint8(img_pad)\n\n        return img_pad\n\n    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):\n        pose_video_reader = VideoReader(src_pose_path)\n        pose_len = len(pose_video_reader)\n        pose_idxs = list(range(pose_len))\n        cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()\n\n        face_video_reader = VideoReader(src_face_path)\n        face_len = len(face_video_reader)\n        face_idxs = list(range(face_len))\n        face_images = face_video_reader.get_batch(face_idxs).asnumpy()\n        height, width = cond_images[0].shape[:2]\n        refer_images = 
cv2.imread(src_ref_path)[..., ::-1]\n        refer_images = self.padding_resize(refer_images, height=height, width=width)\n        return cond_images, face_images, refer_images\n\n    def prepare_source_for_replace(self, src_bg_path, src_mask_path):\n        bg_video_reader = VideoReader(src_bg_path)\n        bg_len = len(bg_video_reader)\n        bg_idxs = list(range(bg_len))\n        bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()\n\n        mask_video_reader = VideoReader(src_mask_path)\n        mask_len = len(mask_video_reader)\n        mask_idxs = list(range(mask_len))\n        mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()\n        mask_images = mask_images[:, :, :, 0] / 255\n        return bg_images, mask_images\n\n    def generate(\n        self,\n        src_root_path,\n        replace_flag=False,\n        clip_len=77,\n        refert_num=1,\n        shift=5.0,\n        sample_solver='dpm++',\n        sampling_steps=20,\n        guide_scale=1,\n        input_prompt=\"\",\n        n_prompt=\"\",\n        seed=-1,\n        offload_model=True,\n    ):\n        r\"\"\"\n        Generates video frames from a reference image and driving videos using a diffusion process.\n\n        Args:\n            src_root_path (`str`):\n                Directory containing the preprocessed inputs: `src_pose.mp4`, `src_face.mp4`\n                and `src_ref.png`, plus `src_bg.mp4` and `src_mask.mp4` when `replace_flag` is True.\n            replace_flag (`bool`, *optional*, defaults to False):\n                Whether to use character replacement mode.\n            clip_len (`int`, *optional*, defaults to 77):\n                Number of frames to generate per clip. Should be of the form 4n+1.\n            refert_num (`int`, *optional*, defaults to 1):\n                Number of frames used for temporal guidance. Must be 1 or 5.\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. 
\n            sample_solver (`str`, *optional*, defaults to 'dpm++'):\n                Solver used to sample the video.\n            sampling_steps (`int`, *optional*, defaults to 20):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation.\n            guide_scale (`float` or tuple[`float`], *optional*, defaults to 1.0):\n                Classifier-free guidance scale. It is only used for expression control.\n                In most cases it is not necessary, and generation is faster without it.\n                Consider using it when expression adjustments are needed.\n            input_prompt (`str`, *optional*, defaults to \"\"):\n                Text prompt for content generation. If not given, use `config.prompt`.\n                Custom prompts work but are not recommended.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`.\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. If -1, use random seed\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM.\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. 
Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames\n                - H: Frame height \n                - W: Frame width \n        \"\"\"\n        assert refert_num == 1 or refert_num == 5, \"refert_num should be 1 or 5.\"\n\n        seed_g = torch.Generator(device=self.device)\n        seed_g.manual_seed(seed)\n\n        if n_prompt == \"\":\n            n_prompt = self.sample_neg_prompt\n\n        if input_prompt == \"\":\n            input_prompt = self.sample_prompt\n\n        src_pose_path = os.path.join(src_root_path, \"src_pose.mp4\")\n        src_face_path = os.path.join(src_root_path, \"src_face.mp4\")\n        src_ref_path = os.path.join(src_root_path, \"src_ref.png\")\n\n        cond_images, face_images, refer_images = self.prepare_source(src_pose_path=src_pose_path, src_face_path=src_face_path, src_ref_path=src_ref_path)\n        \n        if not self.t5_cpu:\n            self.text_encoder.model.to(self.device)\n            context = self.text_encoder([input_prompt], self.device)\n            context_null = self.text_encoder([n_prompt], self.device)\n            if offload_model:\n                self.text_encoder.model.cpu()\n        else:\n            context = self.text_encoder([input_prompt], torch.device('cpu'))\n            context_null = self.text_encoder([n_prompt], torch.device('cpu'))\n            context = [t.to(self.device) for t in context]\n            context_null = [t.to(self.device) for t in context_null]\n\n        real_frame_len = len(cond_images)\n        target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)\n        logging.info('real frames: {} target frames: {}'.format(real_frame_len, target_len))\n        cond_images = self.inputs_padding(cond_images, target_len)\n        face_images = self.inputs_padding(face_images, target_len)\n        \n        if replace_flag:\n            src_bg_path = os.path.join(src_root_path, \"src_bg.mp4\")\n   
         src_mask_path = os.path.join(src_root_path, \"src_mask.mp4\")\n            bg_images, mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)\n            bg_images = self.inputs_padding(bg_images, target_len)\n            mask_images = self.inputs_padding(mask_images, target_len)\n\n        height, width = refer_images.shape[:2]\n        start = 0\n        end = clip_len\n        all_out_frames = []\n        while True:\n            if start + refert_num >= len(cond_images):\n                break\n\n            if start == 0:\n                mask_reft_len = 0\n            else:\n                mask_reft_len = refert_num\n\n            batch = {\n                        \"conditioning_pixel_values\": torch.zeros(1, 3, clip_len, height, width),\n                        \"bg_pixel_values\": torch.zeros(1, 3, clip_len, height, width),\n                        \"mask_pixel_values\": torch.zeros(1, 1, clip_len, height, width),\n                        \"face_pixel_values\": torch.zeros(1, 3, clip_len, 512, 512),\n                        \"refer_pixel_values\": torch.zeros(1, 3, height, width),\n                        \"refer_t_pixel_values\": torch.zeros(refert_num, 3, height, width)\n                    }   \n\n            batch[\"conditioning_pixel_values\"] = rearrange(\n                torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1),\n                \"t h w c -> 1 c t h w\",\n            )\n            batch[\"face_pixel_values\"] = rearrange(\n                torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1),\n                \"t h w c -> 1 c t h w\",\n            )\n\n            batch[\"refer_pixel_values\"] = rearrange(\n                torch.tensor(refer_images / 127.5 - 1), \"h w c -> 1 c h w\"\n            )\n\n            if start > 0:\n                batch[\"refer_t_pixel_values\"] = rearrange(\n                    out_frames[0, :, -refert_num:].clone().detach(),\n                    \"c t h w -> t c h 
w\",\n                )\n\n            batch[\"refer_t_pixel_values\"] = rearrange(batch[\"refer_t_pixel_values\"],\n                                            \"t c h w -> 1 c t h w\",\n                                            )\n\n            if replace_flag:\n                batch[\"bg_pixel_values\"] = rearrange(\n                    torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1),\n                    \"t h w c -> 1 c t h w\",\n                )\n\n                batch[\"mask_pixel_values\"] = rearrange(\n                    torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]),\n                    \"t h w c -> 1 t c h w\",\n                )\n                \n\n            for key, value in batch.items():\n                if isinstance(value, torch.Tensor):\n                    batch[key] = value.to(device=self.device, dtype=torch.bfloat16)\n\n            ref_pixel_values = batch[\"refer_pixel_values\"]\n            refer_t_pixel_values = batch[\"refer_t_pixel_values\"]\n            conditioning_pixel_values = batch[\"conditioning_pixel_values\"]\n            face_pixel_values = batch[\"face_pixel_values\"]\n\n            B, _, H, W = ref_pixel_values.shape\n            T = clip_len\n            lat_h = H // 8\n            lat_w = W // 8\n            lat_t = T // 4 + 1\n            target_shape = [lat_t + 1, lat_h, lat_w]\n            noise = [\n                torch.randn(\n                    16,\n                    target_shape[0],\n                    target_shape[1],\n                    target_shape[2],\n                    dtype=torch.float32,\n                    device=self.device,\n                    generator=seed_g,\n                )\n            ]\n        \n            max_seq_len = int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size\n            if max_seq_len % self.sp_size != 0:\n                raise ValueError(f\"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}\")\n\n       
     with (\n                torch.autocast(device_type=str(self.device), dtype=torch.bfloat16, enabled=True),\n                torch.no_grad()\n            ):\n                if sample_solver == 'unipc':\n                    sample_scheduler = FlowUniPCMultistepScheduler(\n                        num_train_timesteps=self.num_train_timesteps,\n                        shift=1,\n                        use_dynamic_shifting=False)\n                    sample_scheduler.set_timesteps(\n                        sampling_steps, device=self.device, shift=shift)\n                    timesteps = sample_scheduler.timesteps\n                elif sample_solver == 'dpm++':\n                    sample_scheduler = FlowDPMSolverMultistepScheduler(\n                        num_train_timesteps=self.num_train_timesteps,\n                        shift=1,\n                        use_dynamic_shifting=False)\n                    sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)\n                    timesteps, _ = retrieve_timesteps(\n                        sample_scheduler,\n                        device=self.device,\n                        sigmas=sampling_sigmas)\n                else:\n                    raise NotImplementedError(\"Unsupported solver.\")\n\n                latents = noise\n\n                pose_latents_no_ref =  self.vae.encode(conditioning_pixel_values.to(torch.bfloat16))\n                pose_latents_no_ref = torch.stack(pose_latents_no_ref)\n                pose_latents = torch.cat([pose_latents_no_ref], dim=2)\n\n                ref_pixel_values = rearrange(ref_pixel_values, \"t c h w -> 1 c t h w\")\n                ref_latents =  self.vae.encode(ref_pixel_values.to(torch.bfloat16))\n                ref_latents = torch.stack(ref_latents)\n\n                mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device)\n                y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=self.device)\n\n             
   img = ref_pixel_values[0, :, 0]\n                clip_context = self.clip.visual([img[:, None, :, :]]).to(dtype=torch.bfloat16, device=self.device)\n\n                if mask_reft_len > 0:\n                    if replace_flag:\n                        bg_pixel_values = batch[\"bg_pixel_values\"]\n                        y_reft = self.vae.encode(\n                            [\n                                torch.concat([refer_t_pixel_values[0, :, :mask_reft_len], bg_pixel_values[0, :, mask_reft_len:]], dim=1).to(self.device)\n                            ]\n                        )[0]\n                        mask_pixel_values = 1 - batch[\"mask_pixel_values\"]\n                        mask_pixel_values = rearrange(mask_pixel_values, \"b t c h w -> (b t) c h w\")\n                        mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')\n                        mask_pixel_values = rearrange(mask_pixel_values, \"(b t) c h w -> b t c h w\", b=1)[:,:,0]\n                        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)\n                    else:\n                        y_reft = self.vae.encode(\n                            [\n                                torch.concat(\n                                    [\n                                        torch.nn.functional.interpolate(refer_t_pixel_values[0, :, :mask_reft_len].cpu(),\n                                                                        size=(H, W), mode=\"bicubic\"),\n                                        torch.zeros(3, T - mask_reft_len, H, W),\n                                    ],\n                                    dim=1,\n                                ).to(self.device)\n                            ]\n                        )[0]\n                        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)\n                else:\n          
          if replace_flag:\n                        bg_pixel_values = batch[\"bg_pixel_values\"]\n                        mask_pixel_values = 1 - batch[\"mask_pixel_values\"]\n                        mask_pixel_values = rearrange(mask_pixel_values, \"b t c h w -> (b t) c h w\")\n                        mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')\n                        mask_pixel_values = rearrange(mask_pixel_values, \"(b t) c h w -> b t c h w\", b=1)[:,:,0]\n                        y_reft = self.vae.encode(\n                            [\n                                torch.concat(\n                                    [\n                                        bg_pixel_values[0],\n                                    ],\n                                    dim=1,\n                                ).to(self.device)\n                            ]\n                        )[0]\n                        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)\n                    else:\n                        y_reft = self.vae.encode(\n                            [\n                                torch.concat(\n                                    [\n                                        torch.zeros(3, T - mask_reft_len, H, W),\n                                    ],\n                                    dim=1,\n                                ).to(self.device)\n                            ]\n                        )[0]\n                        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)\n\n                y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=self.device)\n                y = torch.concat([y_ref, y_reft], dim=1)\n\n                arg_c = {\n                    \"context\": context, \n                    \"seq_len\": max_seq_len,\n                    \"clip_fea\": 
clip_context.to(dtype=torch.bfloat16, device=self.device),\n                    \"y\": [y],\n                    \"pose_latents\": pose_latents,\n                    \"face_pixel_values\": face_pixel_values,\n                }\n\n                if guide_scale > 1:\n                    face_pixel_values_uncond = face_pixel_values * 0 - 1\n                    arg_null = {\n                        \"context\": context_null,\n                        \"seq_len\": max_seq_len,\n                        \"clip_fea\": clip_context.to(dtype=torch.bfloat16, device=self.device),\n                        \"y\": [y],\n                        \"pose_latents\": pose_latents,\n                        \"face_pixel_values\": face_pixel_values_uncond,\n                    }\n\n                for i, t in enumerate(tqdm(timesteps)):\n                    latent_model_input = latents\n                    timestep = [t]\n\n                    timestep = torch.stack(timestep)\n\n                    noise_pred_cond = TensorList(\n                         self.noise_model(TensorList(latent_model_input), t=timestep, **arg_c)\n                    )\n\n                    if guide_scale > 1:\n                        noise_pred_uncond = TensorList(\n                             self.noise_model(\n                                TensorList(latent_model_input), t=timestep, **arg_null\n                            )\n                        )\n                        noise_pred = noise_pred_uncond + guide_scale * (\n                            noise_pred_cond - noise_pred_uncond\n                        )\n                    else:\n                        noise_pred = noise_pred_cond\n\n                    temp_x0 = sample_scheduler.step(\n                        noise_pred[0].unsqueeze(0),\n                        t,\n                        latents[0].unsqueeze(0),\n                        return_dict=False,\n                        generator=seed_g,\n                    )[0]\n                   
 latents[0] = temp_x0.squeeze(0)\n\n                    x0 = latents\n\n                x0 = [x.to(dtype=torch.float32) for x in x0]\n                out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]]))\n                \n                if start != 0:\n                    out_frames = out_frames[:, :, refert_num:]\n\n                all_out_frames.append(out_frames.cpu())\n\n                start += clip_len - refert_num\n                end += clip_len - refert_num\n\n        videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]\n        return videos[0] if self.rank == 0 else None\n"
  },
  {
    "path": "wan/configs/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport copy\nimport os\n\nos.environ['TOKENIZERS_PARALLELISM'] = 'false'\n\nfrom .wan_i2v_A14B import i2v_A14B\nfrom .wan_s2v_14B import s2v_14B\nfrom .wan_t2v_A14B import t2v_A14B\nfrom .wan_ti2v_5B import ti2v_5B\nfrom .wan_animate_14B import animate_14B\n\nWAN_CONFIGS = {\n    't2v-A14B': t2v_A14B,\n    'i2v-A14B': i2v_A14B,\n    'ti2v-5B': ti2v_5B,\n    'animate-14B': animate_14B,\n    's2v-14B': s2v_14B,\n}\n\nSIZE_CONFIGS = {\n    '720*1280': (720, 1280),\n    '1280*720': (1280, 720),\n    '480*832': (480, 832),\n    '832*480': (832, 480),\n    '704*1280': (704, 1280),\n    '1280*704': (1280, 704),\n    '1024*704': (1024, 704),\n    '704*1024': (704, 1024),\n}\n\nMAX_AREA_CONFIGS = {\n    '720*1280': 720 * 1280,\n    '1280*720': 1280 * 720,\n    '480*832': 480 * 832,\n    '832*480': 832 * 480,\n    '704*1280': 704 * 1280,\n    '1280*704': 1280 * 704,\n    '1024*704': 1024 * 704,\n    '704*1024': 704 * 1024,\n}\n\nSUPPORTED_SIZES = {\n    't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),\n    'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),\n    'ti2v-5B': ('704*1280', '1280*704'),\n    's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704',\n                '704*1024', '704*1280', '1280*704'),\n    'animate-14B': ('720*1280', '1280*720')\n}\n"
  },
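  {
    "path": "wan/configs/shared_config.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n#------------------------ Wan shared config ------------------------#\nwan_shared_cfg = EasyDict()\n\n# t5\nwan_shared_cfg.t5_model = 'umt5_xxl'\nwan_shared_cfg.t5_dtype = torch.bfloat16\nwan_shared_cfg.text_len = 512\n\n# transformer\nwan_shared_cfg.param_dtype = torch.bfloat16\n\n# inference\nwan_shared_cfg.num_train_timesteps = 1000\nwan_shared_cfg.sample_fps = 16\n# Default negative prompt. The string itself is left in Chinese; English gloss:\n# garish colors, overexposed, static, blurry details, subtitles, style, artwork,\n# painting, picture, still, overall gray, worst quality, low quality, JPEG\n# compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands,\n# poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers,\n# motionless frame, cluttered background, three legs, many people in the\n# background, walking backwards.\nwan_shared_cfg.sample_neg_prompt = '色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走'\nwan_shared_cfg.frame_num = 81\n"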
  },
  {
    "path": "wan/configs/wan_animate_14B.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_config import wan_shared_cfg\n\n#------------------------ Wan animate 14B ------------------------#\nanimate_14B = EasyDict(__name__='Config: Wan animate 14B')\nanimate_14B.update(wan_shared_cfg)\n\nanimate_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'\nanimate_14B.t5_tokenizer = 'google/umt5-xxl'\n\nanimate_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'\nanimate_14B.clip_tokenizer = 'xlm-roberta-large'\nanimate_14B.lora_checkpoint = 'relighting_lora.ckpt'\n# vae\nanimate_14B.vae_checkpoint = 'Wan2.1_VAE.pth'\nanimate_14B.vae_stride = (4, 8, 8)\n\n# transformer\nanimate_14B.patch_size = (1, 2, 2)\nanimate_14B.dim = 5120\nanimate_14B.ffn_dim = 13824\nanimate_14B.freq_dim = 256\nanimate_14B.num_heads = 40\nanimate_14B.num_layers = 40\nanimate_14B.window_size = (-1, -1)\nanimate_14B.qk_norm = True\nanimate_14B.cross_attn_norm = True\nanimate_14B.eps = 1e-6\nanimate_14B.use_face_encoder = True\nanimate_14B.motion_encoder_dim = 512\n\n# inference\nanimate_14B.sample_shift = 5.0\nanimate_14B.sample_steps = 20\nanimate_14B.sample_guide_scale = 1.0\nanimate_14B.frame_num = 77\nanimate_14B.sample_fps = 30\n# Default prompt, left in Chinese. English gloss: the person in the video is performing an action.\nanimate_14B.prompt = '视频中的人在做动作'\n"
  },
  {
    "path": "wan/configs/wan_i2v_A14B.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfrom .shared_config import wan_shared_cfg\n\n#------------------------ Wan I2V A14B ------------------------#\n\ni2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')\ni2v_A14B.update(wan_shared_cfg)\n\ni2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'\ni2v_A14B.t5_tokenizer = 'google/umt5-xxl'\n\n# vae\ni2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'\ni2v_A14B.vae_stride = (4, 8, 8)\n\n# transformer\ni2v_A14B.patch_size = (1, 2, 2)\ni2v_A14B.dim = 5120\ni2v_A14B.ffn_dim = 13824\ni2v_A14B.freq_dim = 256\ni2v_A14B.num_heads = 40\ni2v_A14B.num_layers = 40\ni2v_A14B.window_size = (-1, -1)\ni2v_A14B.qk_norm = True\ni2v_A14B.cross_attn_norm = True\ni2v_A14B.eps = 1e-6\ni2v_A14B.low_noise_checkpoint = 'low_noise_model'\ni2v_A14B.high_noise_checkpoint = 'high_noise_model'\n\n# inference\ni2v_A14B.sample_shift = 5.0\ni2v_A14B.sample_steps = 40\ni2v_A14B.boundary = 0.900\ni2v_A14B.sample_guide_scale = (3.5, 3.5)  # low noise, high noise\n"
  },
  {
    "path": "wan/configs/wan_s2v_14B.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_config import wan_shared_cfg\n\n#------------------------ Wan S2V 14B ------------------------#\n\ns2v_14B = EasyDict(__name__='Config: Wan S2V 14B')\ns2v_14B.update(wan_shared_cfg)\n\n# t5\ns2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'\ns2v_14B.t5_tokenizer = 'google/umt5-xxl'\n\n# vae\ns2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'\ns2v_14B.vae_stride = (4, 8, 8)\n\n# wav2vec\ns2v_14B.wav2vec = \"wav2vec2-large-xlsr-53-english\"\n\ns2v_14B.num_heads = 40\n# transformer\ns2v_14B.transformer = EasyDict(\n    __name__=\"Config: Transformer config for WanModel_S2V\")\ns2v_14B.transformer.patch_size = (1, 2, 2)\ns2v_14B.transformer.dim = 5120\ns2v_14B.transformer.ffn_dim = 13824\ns2v_14B.transformer.freq_dim = 256\ns2v_14B.transformer.num_heads = 40\ns2v_14B.transformer.num_layers = 40\ns2v_14B.transformer.window_size = (-1, -1)\ns2v_14B.transformer.qk_norm = True\ns2v_14B.transformer.cross_attn_norm = True\ns2v_14B.transformer.eps = 1e-6\ns2v_14B.transformer.enable_adain = True\ns2v_14B.transformer.adain_mode = \"attn_norm\"\ns2v_14B.transformer.audio_inject_layers = [\n    0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39\n]\ns2v_14B.transformer.zero_init = True\ns2v_14B.transformer.zero_timestep = True\ns2v_14B.transformer.enable_motioner = False\ns2v_14B.transformer.add_last_motion = True\ns2v_14B.transformer.trainable_token = False\ns2v_14B.transformer.enable_tsm = False\ns2v_14B.transformer.enable_framepack = True\ns2v_14B.transformer.framepack_drop_mode = 'padd'\ns2v_14B.transformer.audio_dim = 1024\n\ns2v_14B.transformer.motion_frames = 73\ns2v_14B.transformer.cond_dim = 16\n\n# inference\ns2v_14B.sample_neg_prompt = \"画面模糊，最差质量，画面模糊，细节模糊不清，情绪激动剧烈，手快速抖动，字幕，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走\"\ns2v_14B.drop_first_motion = True\ns2v_14B.sample_shift = 3\ns2v_14B.sample_steps = 
40\ns2v_14B.sample_guide_scale = 4.5\n"
  },
  {
    "path": "wan/configs/wan_t2v_A14B.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_config import wan_shared_cfg\n\n#------------------------ Wan T2V A14B ------------------------#\n\nt2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')\nt2v_A14B.update(wan_shared_cfg)\n\n# t5\nt2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'\nt2v_A14B.t5_tokenizer = 'google/umt5-xxl'\n\n# vae\nt2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'\nt2v_A14B.vae_stride = (4, 8, 8)\n\n# transformer\nt2v_A14B.patch_size = (1, 2, 2)\nt2v_A14B.dim = 5120\nt2v_A14B.ffn_dim = 13824\nt2v_A14B.freq_dim = 256\nt2v_A14B.num_heads = 40\nt2v_A14B.num_layers = 40\nt2v_A14B.window_size = (-1, -1)\nt2v_A14B.qk_norm = True\nt2v_A14B.cross_attn_norm = True\nt2v_A14B.eps = 1e-6\nt2v_A14B.low_noise_checkpoint = 'low_noise_model'\nt2v_A14B.high_noise_checkpoint = 'high_noise_model'\n\n# inference\nt2v_A14B.sample_shift = 12.0\nt2v_A14B.sample_steps = 40\nt2v_A14B.boundary = 0.875\nt2v_A14B.sample_guide_scale = (3.0, 4.0)  # low noise, high noise\n"
  },
  {
    "path": "wan/configs/wan_ti2v_5B.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_config import wan_shared_cfg\n\n#------------------------ Wan TI2V 5B ------------------------#\n\nti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')\nti2v_5B.update(wan_shared_cfg)\n\n# t5\nti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'\nti2v_5B.t5_tokenizer = 'google/umt5-xxl'\n\n# vae\nti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'\nti2v_5B.vae_stride = (4, 16, 16)\n\n# transformer\nti2v_5B.patch_size = (1, 2, 2)\nti2v_5B.dim = 3072\nti2v_5B.ffn_dim = 14336\nti2v_5B.freq_dim = 256\nti2v_5B.num_heads = 24\nti2v_5B.num_layers = 30\nti2v_5B.window_size = (-1, -1)\nti2v_5B.qk_norm = True\nti2v_5B.cross_attn_norm = True\nti2v_5B.eps = 1e-6\n\n# inference\nti2v_5B.sample_fps = 24\nti2v_5B.sample_shift = 5.0\nti2v_5B.sample_steps = 50\nti2v_5B.sample_guide_scale = 5.0\nti2v_5B.frame_num = 121\n"
  },
  {
    "path": "wan/distributed/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n"
  },
  {
    "path": "wan/distributed/fsdp.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nfrom functools import partial\n\nimport torch\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp import MixedPrecision, ShardingStrategy\nfrom torch.distributed.fsdp.wrap import lambda_auto_wrap_policy\nfrom torch.distributed.utils import _free_storage\n\n\ndef shard_model(\n    model,\n    device_id,\n    param_dtype=torch.bfloat16,\n    reduce_dtype=torch.float32,\n    buffer_dtype=torch.float32,\n    process_group=None,\n    sharding_strategy=ShardingStrategy.FULL_SHARD,\n    sync_module_states=True,\n    use_lora=False\n):\n    model = FSDP(\n        module=model,\n        process_group=process_group,\n        sharding_strategy=sharding_strategy,\n        auto_wrap_policy=partial(\n            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),\n        mixed_precision=MixedPrecision(\n            param_dtype=param_dtype,\n            reduce_dtype=reduce_dtype,\n            buffer_dtype=buffer_dtype),\n        device_id=device_id,\n        sync_module_states=sync_module_states,\n        use_orig_params=bool(use_lora))\n    return model\n\n\ndef free_model(model):\n    for m in model.modules():\n        if isinstance(m, FSDP):\n            _free_storage(m._handle.flat_param.data)\n    del model\n    gc.collect()\n    torch.cuda.empty_cache()\n"
  },
  {
    "path": "wan/distributed/sequence_parallel.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\n\nfrom ..modules.model import sinusoidal_embedding_1d\nfrom .ulysses import distributed_attention\nfrom .util import gather_forward, get_rank, get_world_size\n\n\ndef pad_freqs(original_tensor, target_len):\n    seq_len, s1, s2 = original_tensor.shape\n    pad_size = target_len - seq_len\n    padding_tensor = torch.ones(\n        pad_size,\n        s1,\n        s2,\n        dtype=original_tensor.dtype,\n        device=original_tensor.device)\n    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)\n    return padded_tensor\n\n\n@torch.amp.autocast('cuda', enabled=False)\ndef rope_apply(x, grid_sizes, freqs):\n    \"\"\"\n    x:          [B, L, N, C].\n    grid_sizes: [B, 3].\n    freqs:      [M, C // 2].\n    \"\"\"\n    s, n, c = x.size(1), x.size(2), x.size(3) // 2\n    # split freqs\n    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)\n\n    # loop over samples\n    output = []\n    for i, (f, h, w) in enumerate(grid_sizes.tolist()):\n        seq_len = f * h * w\n\n        # precompute multipliers\n        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(\n            s, n, -1, 2))\n        freqs_i = torch.cat([\n            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n        ],\n                            dim=-1).reshape(seq_len, 1, -1)\n\n        # apply rotary embedding\n        sp_size = get_world_size()\n        sp_rank = get_rank()\n        freqs_i = pad_freqs(freqs_i, s * sp_size)\n        s_per_rank = s\n        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *\n                                                       s_per_rank), :, :]\n        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)\n        x_i = torch.cat([x_i, 
x[i, s:]])\n\n        # append to collection\n        output.append(x_i)\n    return torch.stack(output).float()\n\n\ndef sp_dit_forward(\n    self,\n    x,\n    t,\n    context,\n    seq_len,\n    y=None,\n):\n    \"\"\"\n    x:              A list of videos each with shape [C, T, H, W].\n    t:              [B].\n    context:        A list of text embeddings each with shape [L, C].\n    \"\"\"\n    if self.model_type == 'i2v':\n        assert y is not None\n    # params\n    device = self.patch_embedding.weight.device\n    if self.freqs.device != device:\n        self.freqs = self.freqs.to(device)\n\n    if y is not None:\n        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]\n\n    # embeddings\n    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]\n    grid_sizes = torch.stack(\n        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])\n    x = [u.flatten(2).transpose(1, 2) for u in x]\n    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)\n    assert seq_lens.max() <= seq_len\n    x = torch.cat([\n        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)\n        for u in x\n    ])\n\n    # time embeddings\n    if t.dim() == 1:\n        t = t.expand(t.size(0), seq_len)\n    with torch.amp.autocast('cuda', dtype=torch.float32):\n        bt = t.size(0)\n        t = t.flatten()\n        e = self.time_embedding(\n            sinusoidal_embedding_1d(self.freq_dim,\n                                    t).unflatten(0, (bt, seq_len)).float())\n        e0 = self.time_projection(e).unflatten(2, (6, self.dim))\n        assert e.dtype == torch.float32 and e0.dtype == torch.float32\n\n    # context\n    context_lens = None\n    context = self.text_embedding(\n        torch.stack([\n            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])\n            for u in context\n        ]))\n\n    # Context Parallel\n    x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]\n    e = torch.chunk(e, 
get_world_size(), dim=1)[get_rank()]\n    e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]\n\n    # arguments\n    kwargs = dict(\n        e=e0,\n        seq_lens=seq_lens,\n        grid_sizes=grid_sizes,\n        freqs=self.freqs,\n        context=context,\n        context_lens=context_lens)\n\n    for block in self.blocks:\n        x = block(x, **kwargs)\n\n    # head\n    x = self.head(x, e)\n\n    # Context Parallel\n    x = gather_forward(x, dim=1)\n\n    # unpatchify\n    x = self.unpatchify(x, grid_sizes)\n    return [u.float() for u in x]\n\n\ndef sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):\n    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n    half_dtypes = (torch.float16, torch.bfloat16)\n\n    def half(x):\n        return x if x.dtype in half_dtypes else x.to(dtype)\n\n    # query, key, value function\n    def qkv_fn(x):\n        q = self.norm_q(self.q(x)).view(b, s, n, d)\n        k = self.norm_k(self.k(x)).view(b, s, n, d)\n        v = self.v(x).view(b, s, n, d)\n        return q, k, v\n\n    q, k, v = qkv_fn(x)\n    q = rope_apply(q, grid_sizes, freqs)\n    k = rope_apply(k, grid_sizes, freqs)\n\n    x = distributed_attention(\n        half(q),\n        half(k),\n        half(v),\n        seq_lens,\n        window_size=self.window_size,\n    )\n\n    # output\n    x = x.flatten(2)\n    x = self.o(x)\n    return x\n"
  },
  {
    "path": "wan/distributed/ulysses.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.distributed as dist\n\nfrom ..modules.attention import flash_attention\nfrom .util import all_to_all\n\n\ndef distributed_attention(\n        q,\n        k,\n        v,\n        seq_lens,\n        window_size=(-1, -1),\n):\n    \"\"\"\n    Performs distributed attention based on the DeepSpeed Ulysses attention mechanism.\n    Please refer to https://arxiv.org/pdf/2309.14509\n\n    Args:\n        q:           [B, Lq // p, Nq, C1].\n        k:           [B, Lk // p, Nk, C1].\n        v:           [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.\n        seq_lens:    [B], length of each sequence in the batch.\n        window_size: (left, right). If not (-1, -1), apply sliding window local attention.\n    \"\"\"\n    if not dist.is_initialized():\n        raise ValueError(\"distributed group should be initialized.\")\n    b = q.shape[0]\n\n    # gather the full q/k/v sequence on each rank, scattering the heads\n    q = all_to_all(q, scatter_dim=2, gather_dim=1)\n    k = all_to_all(k, scatter_dim=2, gather_dim=1)\n    v = all_to_all(v, scatter_dim=2, gather_dim=1)\n\n    # apply attention\n    x = flash_attention(\n        q,\n        k,\n        v,\n        k_lens=seq_lens,\n        window_size=window_size,\n    )\n\n    # scatter the output sequence back across ranks, gathering the heads\n    x = all_to_all(x, scatter_dim=1, gather_dim=2)\n    return x\n"
  },
  {
    "path": "wan/distributed/util.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.distributed as dist\n\n\ndef init_distributed_group():\n    \"\"\"Initialize the sequence parallel group.\n    \"\"\"\n    if not dist.is_initialized():\n        dist.init_process_group(backend='nccl')\n\n\ndef get_rank():\n    return dist.get_rank()\n\n\ndef get_world_size():\n    return dist.get_world_size()\n\n\ndef all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):\n    \"\"\"\n    `scatter` along one dimension and `gather` along another.\n    \"\"\"\n    world_size = get_world_size()\n    if world_size > 1:\n        inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]\n        outputs = [torch.empty_like(u) for u in inputs]\n        dist.all_to_all(outputs, inputs, group=group, **kwargs)\n        x = torch.cat(outputs, dim=gather_dim).contiguous()\n    return x\n\n\ndef all_gather(tensor):\n    world_size = dist.get_world_size()\n    if world_size == 1:\n        return [tensor]\n    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]\n    torch.distributed.all_gather(tensor_list, tensor)\n    return tensor_list\n\n\ndef gather_forward(input, dim):\n    # skip if world_size == 1\n    world_size = dist.get_world_size()\n    if world_size == 1:\n        return input\n\n    # gather sequence\n    output = all_gather(input)\n    return torch.cat(output, dim=dim).contiguous()\n"
  },
  {
    "path": "wan/image2video.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport types\nfrom contextlib import contextmanager\nfrom functools import partial\n\nimport numpy as np\nimport torch\nimport torch.cuda.amp as amp\nimport torch.distributed as dist\nimport torchvision.transforms.functional as TF\nfrom tqdm import tqdm\n\nfrom .distributed.fsdp import shard_model\nfrom .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward\nfrom .distributed.util import get_world_size\nfrom .modules.model import WanModel\nfrom .modules.t5 import T5EncoderModel\nfrom .modules.vae2_1 import Wan2_1_VAE\nfrom .utils.fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler\n\n\nclass WanI2V:\n\n    def __init__(\n        self,\n        config,\n        checkpoint_dir,\n        device_id=0,\n        rank=0,\n        t5_fsdp=False,\n        dit_fsdp=False,\n        use_sp=False,\n        t5_cpu=False,\n        init_on_cpu=True,\n        convert_model_dtype=False,\n    ):\n        r\"\"\"\n        Initializes the image-to-video generation model components.\n\n        Args:\n            config (EasyDict):\n                Object containing model parameters initialized from config.py\n            checkpoint_dir (`str`):\n                Path to directory containing model checkpoints\n            device_id (`int`,  *optional*, defaults to 0):\n                Id of target GPU device\n            rank (`int`,  *optional*, defaults to 0):\n                Process rank for distributed training\n            t5_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for T5 model\n            dit_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for DiT model\n            use_sp (`bool`, *optional*, 
defaults to False):\n                Enable distribution strategy of sequence parallel.\n            t5_cpu (`bool`, *optional*, defaults to False):\n                Whether to place T5 model on CPU. Only works without t5_fsdp.\n            init_on_cpu (`bool`, *optional*, defaults to True):\n                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.\n            convert_model_dtype (`bool`, *optional*, defaults to False):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n        \"\"\"\n        self.device = torch.device(f\"cuda:{device_id}\")\n        self.config = config\n        self.rank = rank\n        self.t5_cpu = t5_cpu\n        self.init_on_cpu = init_on_cpu\n\n        self.num_train_timesteps = config.num_train_timesteps\n        self.boundary = config.boundary\n        self.param_dtype = config.param_dtype\n\n        if t5_fsdp or dit_fsdp or use_sp:\n            self.init_on_cpu = False\n\n        shard_fn = partial(shard_model, device_id=device_id)\n        self.text_encoder = T5EncoderModel(\n            text_len=config.text_len,\n            dtype=config.t5_dtype,\n            device=torch.device('cpu'),\n            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),\n            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),\n            shard_fn=shard_fn if t5_fsdp else None,\n        )\n\n        self.vae_stride = config.vae_stride\n        self.patch_size = config.patch_size\n        self.vae = Wan2_1_VAE(\n            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),\n            device=self.device)\n\n        logging.info(f\"Creating WanModel from {checkpoint_dir}\")\n        self.low_noise_model = WanModel.from_pretrained(\n            checkpoint_dir, subfolder=config.low_noise_checkpoint)\n        self.low_noise_model = self._configure_model(\n            model=self.low_noise_model,\n         
   use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype)\n\n        self.high_noise_model = WanModel.from_pretrained(\n            checkpoint_dir, subfolder=config.high_noise_checkpoint)\n        self.high_noise_model = self._configure_model(\n            model=self.high_noise_model,\n            use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype)\n        if use_sp:\n            self.sp_size = get_world_size()\n        else:\n            self.sp_size = 1\n\n        self.sample_neg_prompt = config.sample_neg_prompt\n\n    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,\n                         convert_model_dtype):\n        \"\"\"\n        Configures a model object. This includes setting evaluation modes,\n        applying distributed parallel strategy, and handling device placement.\n\n        Args:\n            model (torch.nn.Module):\n                The model instance to configure.\n            use_sp (`bool`):\n                Enable distribution strategy of sequence parallel.\n            dit_fsdp (`bool`):\n                Enable FSDP sharding for DiT model.\n            shard_fn (callable):\n                The function to apply FSDP sharding.\n            convert_model_dtype (`bool`):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n\n        Returns:\n            torch.nn.Module:\n                The configured model.\n        \"\"\"\n        model.eval().requires_grad_(False)\n\n        if use_sp:\n            for block in model.blocks:\n                block.self_attn.forward = types.MethodType(\n                    sp_attn_forward, block.self_attn)\n            model.forward = types.MethodType(sp_dit_forward, model)\n\n        if dist.is_initialized():\n            dist.barrier()\n\n        if dit_fsdp:\n      
      model = shard_fn(model)\n        else:\n            if convert_model_dtype:\n                model.to(self.param_dtype)\n            if not self.init_on_cpu:\n                model.to(self.device)\n\n        return model\n\n    def _prepare_model_for_timestep(self, t, boundary, offload_model):\n        r\"\"\"\n        Prepares and returns the required model for the current timestep.\n\n        Args:\n            t (torch.Tensor):\n                current timestep.\n            boundary (`int`):\n                The timestep threshold. If `t` is at or above this value,\n                the `high_noise_model` is considered as the required model.\n            offload_model (`bool`):\n                A flag intended to control the offloading behavior.\n\n        Returns:\n            torch.nn.Module:\n                The active model on the target device for the current timestep.\n        \"\"\"\n        if t.item() >= boundary:\n            required_model_name = 'high_noise_model'\n            offload_model_name = 'low_noise_model'\n        else:\n            required_model_name = 'low_noise_model'\n            offload_model_name = 'high_noise_model'\n        if offload_model or self.init_on_cpu:\n            if next(getattr(\n                    self,\n                    offload_model_name).parameters()).device.type == 'cuda':\n                getattr(self, offload_model_name).to('cpu')\n            if next(getattr(\n                    self,\n                    required_model_name).parameters()).device.type == 'cpu':\n                getattr(self, required_model_name).to(self.device)\n        return getattr(self, required_model_name)\n\n    def generate(self,\n                 input_prompt,\n                 img,\n                 max_area=720 * 1280,\n                 frame_num=81,\n                 shift=5.0,\n                 sample_solver='unipc',\n                 sampling_steps=40,\n                 guide_scale=5.0,\n                 n_prompt=\"\",\n  
               seed=-1,\n                 offload_model=True):\n        r\"\"\"\n        Generates video frames from an input image and a text prompt using a diffusion process.\n\n        Args:\n            input_prompt (`str`):\n                Text prompt for content generation.\n            img (PIL.Image.Image):\n                Input image. It is converted internally to a tensor of shape [3, H, W].\n            max_area (`int`, *optional*, defaults to 720*1280):\n                Maximum pixel area for latent space calculation. Controls video resolution scaling.\n            frame_num (`int`, *optional*, defaults to 81):\n                Number of frames to generate. The number should be of the form 4n+1.\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. Affects temporal dynamics.\n                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.\n            sample_solver (`str`, *optional*, defaults to 'unipc'):\n                Solver used to sample the video. Supported values are 'unipc' and 'dpm++'.\n            sampling_steps (`int`, *optional*, defaults to 40):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation.\n            guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):\n                Classifier-free guidance scale. Controls prompt adherence vs. creativity.\n                If tuple, the first guide_scale will be used for low noise model and\n                the second guide_scale will be used for high noise model.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`.\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. 
If -1, use a random seed.\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM.\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames (`frame_num`, 81 by default)\n                - H: Frame height (from max_area)\n                - W: Frame width (from max_area)\n        \"\"\"\n        # preprocess\n        guide_scale = (guide_scale, guide_scale) if isinstance(\n            guide_scale, float) else guide_scale\n        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)\n\n        F = frame_num\n        h, w = img.shape[1:]\n        aspect_ratio = h / w\n        lat_h = round(\n            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //\n            self.patch_size[1] * self.patch_size[1])\n        lat_w = round(\n            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //\n            self.patch_size[2] * self.patch_size[2])\n        h = lat_h * self.vae_stride[1]\n        w = lat_w * self.vae_stride[2]\n\n        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (\n            self.patch_size[1] * self.patch_size[2])\n        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size\n\n        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)\n        seed_g = torch.Generator(device=self.device)\n        seed_g.manual_seed(seed)\n        noise = torch.randn(\n            16,\n            (F - 1) // self.vae_stride[0] + 1,\n            lat_h,\n            lat_w,\n            dtype=torch.float32,\n            generator=seed_g,\n            device=self.device)\n\n        msk = torch.ones(1, F, lat_h, lat_w, device=self.device)\n        msk[:, 1:] = 0\n        msk = torch.concat([\n            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]\n        ],\n 
                          dim=1)\n        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)\n        msk = msk.transpose(1, 2)[0]\n\n        if n_prompt == \"\":\n            n_prompt = self.sample_neg_prompt\n\n        # preprocess\n        if not self.t5_cpu:\n            self.text_encoder.model.to(self.device)\n            context = self.text_encoder([input_prompt], self.device)\n            context_null = self.text_encoder([n_prompt], self.device)\n            if offload_model:\n                self.text_encoder.model.cpu()\n        else:\n            context = self.text_encoder([input_prompt], torch.device('cpu'))\n            context_null = self.text_encoder([n_prompt], torch.device('cpu'))\n            context = [t.to(self.device) for t in context]\n            context_null = [t.to(self.device) for t in context_null]\n\n        y = self.vae.encode([\n            torch.concat([\n                torch.nn.functional.interpolate(\n                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(\n                        0, 1),\n                torch.zeros(3, F - 1, h, w)\n            ],\n                         dim=1).to(self.device)\n        ])[0]\n        y = torch.concat([msk, y])\n\n        @contextmanager\n        def noop_no_sync():\n            yield\n\n        no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',\n                                    noop_no_sync)\n        no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',\n                                     noop_no_sync)\n\n        # evaluation mode\n        with (\n                torch.amp.autocast('cuda', dtype=self.param_dtype),\n                torch.no_grad(),\n                no_sync_low_noise(),\n                no_sync_high_noise(),\n        ):\n            boundary = self.boundary * self.num_train_timesteps\n\n            if sample_solver == 'unipc':\n                sample_scheduler = FlowUniPCMultistepScheduler(\n                    
num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sample_scheduler.set_timesteps(\n                    sampling_steps, device=self.device, shift=shift)\n                timesteps = sample_scheduler.timesteps\n            elif sample_solver == 'dpm++':\n                sample_scheduler = FlowDPMSolverMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)\n                timesteps, _ = retrieve_timesteps(\n                    sample_scheduler,\n                    device=self.device,\n                    sigmas=sampling_sigmas)\n            else:\n                raise NotImplementedError(\"Unsupported solver.\")\n\n            # sample videos\n            latent = noise\n\n            arg_c = {\n                'context': [context[0]],\n                'seq_len': max_seq_len,\n                'y': [y],\n            }\n\n            arg_null = {\n                'context': context_null,\n                'seq_len': max_seq_len,\n                'y': [y],\n            }\n\n            if offload_model:\n                torch.cuda.empty_cache()\n\n            for _, t in enumerate(tqdm(timesteps)):\n                latent_model_input = [latent.to(self.device)]\n                timestep = [t]\n\n                timestep = torch.stack(timestep).to(self.device)\n\n                model = self._prepare_model_for_timestep(\n                    t, boundary, offload_model)\n                sample_guide_scale = guide_scale[1] if t.item(\n                ) >= boundary else guide_scale[0]\n\n                noise_pred_cond = model(\n                    latent_model_input, t=timestep, **arg_c)[0]\n                if offload_model:\n                    torch.cuda.empty_cache()\n                
noise_pred_uncond = model(\n                    latent_model_input, t=timestep, **arg_null)[0]\n                if offload_model:\n                    torch.cuda.empty_cache()\n                noise_pred = noise_pred_uncond + sample_guide_scale * (\n                    noise_pred_cond - noise_pred_uncond)\n\n                temp_x0 = sample_scheduler.step(\n                    noise_pred.unsqueeze(0),\n                    t,\n                    latent.unsqueeze(0),\n                    return_dict=False,\n                    generator=seed_g)[0]\n                latent = temp_x0.squeeze(0)\n\n                x0 = [latent]\n                del latent_model_input, timestep\n\n            if offload_model:\n                self.low_noise_model.cpu()\n                self.high_noise_model.cpu()\n                torch.cuda.empty_cache()\n\n            if self.rank == 0:\n                videos = self.vae.decode(x0)\n\n        del noise, latent, x0\n        del sample_scheduler\n        if offload_model:\n            gc.collect()\n            torch.cuda.synchronize()\n        if dist.is_initialized():\n            dist.barrier()\n\n        return videos[0] if self.rank == 0 else None\n"
  },
  {
    "path": "wan/modules/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model\nfrom .tokenizers import HuggingfaceTokenizer\nfrom .vae2_1 import Wan2_1_VAE\nfrom .vae2_2 import Wan2_2_VAE\n\n__all__ = [\n    'Wan2_1_VAE',\n    'Wan2_2_VAE',\n    'WanModel',\n    'T5Model',\n    'T5Encoder',\n    'T5Decoder',\n    'T5EncoderModel',\n    'HuggingfaceTokenizer',\n    'flash_attention',\n]\n"
  },
  {
    "path": "wan/modules/animate/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .model_animate import WanAnimateModel\nfrom .clip import CLIPModel\n__all__ = ['WanAnimateModel', 'CLIPModel']"
  },
  {
    "path": "wan/modules/animate/animate_utils.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport numbers\nfrom peft import LoraConfig\n\n\ndef get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights=\"gaussian\"):\n    target_modules = []\n    for name, module in transformer.named_modules():\n        if \"blocks\" in name and \"face\" not in name and \"modulation\" not in name and isinstance(module, torch.nn.Linear):\n            target_modules.append(name)\n\n    transformer_lora_config = LoraConfig(\n        r=rank,\n        lora_alpha=alpha,\n        init_lora_weights=init_lora_weights,\n        target_modules=target_modules,\n    )\n    return transformer_lora_config\n\n\n\nclass TensorList(object):\n\n    def __init__(self, tensors):\n        \"\"\"\n        tensors: a list of torch.Tensor objects. No need to have uniform shape.\n        \"\"\"\n        assert isinstance(tensors, (list, tuple))\n        assert all(isinstance(u, torch.Tensor) for u in tensors)\n        assert len(set([u.ndim for u in tensors])) == 1\n        assert len(set([u.dtype for u in tensors])) == 1\n        assert len(set([u.device for u in tensors])) == 1\n        self.tensors = tensors\n    \n    def to(self, *args, **kwargs):\n        return TensorList([u.to(*args, **kwargs) for u in self.tensors])\n    \n    def size(self, dim):\n        assert dim == 0, 'only support get the 0th size'\n        return len(self.tensors)\n    \n    def pow(self, *args, **kwargs):\n        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])\n    \n    def squeeze(self, dim):\n        assert dim != 0\n        if dim > 0:\n            dim -= 1\n        return TensorList([u.squeeze(dim) for u in self.tensors])\n    \n    def type(self, *args, **kwargs):\n        return TensorList([u.type(*args, **kwargs) for u in self.tensors])\n    \n    def type_as(self, other):\n        assert isinstance(other, (torch.Tensor, TensorList))\n        if isinstance(other, 
torch.Tensor):\n            return TensorList([u.type_as(other) for u in self.tensors])\n        else:\n            return TensorList([u.type(other.dtype) for u in self.tensors])\n    \n    @property\n    def dtype(self):\n        return self.tensors[0].dtype\n    \n    @property\n    def device(self):\n        return self.tensors[0].device\n    \n    @property\n    def ndim(self):\n        return 1 + self.tensors[0].ndim\n    \n    def __getitem__(self, index):\n        return self.tensors[index]\n    \n    def __len__(self):\n        return len(self.tensors)\n    \n    def __add__(self, other):\n        return self._apply(other, lambda u, v: u + v)\n    \n    def __radd__(self, other):\n        return self._apply(other, lambda u, v: v + u)\n    \n    def __sub__(self, other):\n        return self._apply(other, lambda u, v: u - v)\n    \n    def __rsub__(self, other):\n        return self._apply(other, lambda u, v: v - u)\n    \n    def __mul__(self, other):\n        return self._apply(other, lambda u, v: u * v)\n    \n    def __rmul__(self, other):\n        return self._apply(other, lambda u, v: v * u)\n    \n    def __floordiv__(self, other):\n        return self._apply(other, lambda u, v: u // v)\n    \n    def __truediv__(self, other):\n        return self._apply(other, lambda u, v: u / v)\n    \n    def __rfloordiv__(self, other):\n        return self._apply(other, lambda u, v: v // u)\n    \n    def __rtruediv__(self, other):\n        return self._apply(other, lambda u, v: v / u)\n    \n    def __pow__(self, other):\n        return self._apply(other, lambda u, v: u ** v)\n    \n    def __rpow__(self, other):\n        return self._apply(other, lambda u, v: v ** u)\n    \n    def __neg__(self):\n        return TensorList([-u for u in self.tensors])\n    \n    def __iter__(self):\n        for tensor in self.tensors:\n            yield tensor\n    \n    def __repr__(self):\n        return 'TensorList: \\n' + repr(self.tensors)\n\n    def _apply(self, other, 
op):\n        # Element-wise case: `other` is a sequence (or multi-element tensor)\n        # with one entry per tensor in this list.\n        if isinstance(other, (list, tuple, TensorList)) or (\n            isinstance(other, torch.Tensor) and (\n                other.numel() > 1 or other.ndim > 1\n            )\n        ):\n            assert len(other) == len(self.tensors)\n            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])\n        # Broadcast case: `other` is a Python number or a single-element tensor.\n        elif isinstance(other, numbers.Number) or (\n            isinstance(other, torch.Tensor) and (\n                other.numel() == 1 and other.ndim <= 1\n            )\n        ):\n            return TensorList([op(u, other) for u in self.tensors])\n        else:\n            raise TypeError(\n                f'unsupported operand type(s): \"TensorList\" and \"{type(other)}\"'\n            )"
  },
  {
    "path": "wan/modules/animate/clip.py",
    "content": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as T\n\nfrom ..attention import flash_attention\nfrom ..tokenizers import HuggingfaceTokenizer\nfrom .xlm_roberta import XLMRoberta\n\n__all__ = [\n    'XLMRobertaCLIP',\n    'clip_xlm_roberta_vit_h_14',\n    'CLIPModel',\n]\n\n\ndef pos_interpolate(pos, seq_len):\n    if pos.size(1) == seq_len:\n        return pos\n    else:\n        src_grid = int(math.sqrt(pos.size(1)))\n        tar_grid = int(math.sqrt(seq_len))\n        n = pos.size(1) - src_grid * src_grid\n        return torch.cat([\n            pos[:, :n],\n            F.interpolate(\n                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(\n                    0, 3, 1, 2),\n                size=(tar_grid, tar_grid),\n                mode='bicubic',\n                align_corners=False).flatten(2).transpose(1, 2)\n        ],\n                         dim=1)\n\n\nclass QuickGELU(nn.Module):\n\n    def forward(self, x):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass LayerNorm(nn.LayerNorm):\n\n    def forward(self, x):\n        return super().forward(x.float()).type_as(x)\n\n\nclass SelfAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 num_heads,\n                 causal=False,\n                 attn_dropout=0.0,\n                 proj_dropout=0.0):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.causal = causal\n        self.attn_dropout = attn_dropout\n        self.proj_dropout = proj_dropout\n\n        # layers\n        self.to_qkv = nn.Linear(dim, dim * 3)\n        self.proj = 
nn.Linear(dim, dim)\n\n    def forward(self, x):\n        \"\"\"\n        x:   [B, L, C].\n        \"\"\"\n        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)\n\n        # compute attention\n        p = self.attn_dropout if self.training else 0.0\n        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)\n        x = x.reshape(b, s, c)\n\n        # output\n        x = self.proj(x)\n        x = F.dropout(x, self.proj_dropout, self.training)\n        return x\n\n\nclass SwiGLU(nn.Module):\n\n    def __init__(self, dim, mid_dim):\n        super().__init__()\n        self.dim = dim\n        self.mid_dim = mid_dim\n\n        # layers\n        self.fc1 = nn.Linear(dim, mid_dim)\n        self.fc2 = nn.Linear(dim, mid_dim)\n        self.fc3 = nn.Linear(mid_dim, dim)\n\n    def forward(self, x):\n        x = F.silu(self.fc1(x)) * self.fc2(x)\n        x = self.fc3(x)\n        return x\n\n\nclass AttentionBlock(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 mlp_ratio,\n                 num_heads,\n                 post_norm=False,\n                 causal=False,\n                 activation='quick_gelu',\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 norm_eps=1e-5):\n        assert activation in ['quick_gelu', 'gelu', 'swi_glu']\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.num_heads = num_heads\n        self.post_norm = post_norm\n        self.causal = causal\n        self.norm_eps = norm_eps\n\n        # layers\n        self.norm1 = LayerNorm(dim, eps=norm_eps)\n        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,\n                                  proj_dropout)\n        self.norm2 = LayerNorm(dim, eps=norm_eps)\n        if activation == 'swi_glu':\n            self.mlp = SwiGLU(dim, 
int(dim * mlp_ratio))\n        else:\n            self.mlp = nn.Sequential(\n                nn.Linear(dim, int(dim * mlp_ratio)),\n                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),\n                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))\n\n    def forward(self, x):\n        if self.post_norm:\n            x = x + self.norm1(self.attn(x))\n            x = x + self.norm2(self.mlp(x))\n        else:\n            x = x + self.attn(self.norm1(x))\n            x = x + self.mlp(self.norm2(x))\n        return x\n\n\nclass AttentionPool(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 mlp_ratio,\n                 num_heads,\n                 activation='gelu',\n                 proj_dropout=0.0,\n                 norm_eps=1e-5):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.proj_dropout = proj_dropout\n        self.norm_eps = norm_eps\n\n        # layers\n        gain = 1.0 / math.sqrt(dim)\n        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))\n        self.to_q = nn.Linear(dim, dim)\n        self.to_kv = nn.Linear(dim, dim * 2)\n        self.proj = nn.Linear(dim, dim)\n        self.norm = LayerNorm(dim, eps=norm_eps)\n        self.mlp = nn.Sequential(\n            nn.Linear(dim, int(dim * mlp_ratio)),\n            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),\n            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))\n\n    def forward(self, x):\n        \"\"\"\n        x:  [B, L, C].\n        \"\"\"\n        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)\n        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)\n\n        # compute attention\n    
    x = flash_attention(q, k, v, version=2)\n        x = x.reshape(b, 1, c)\n\n        # output\n        x = self.proj(x)\n        x = F.dropout(x, self.proj_dropout, self.training)\n\n        # mlp\n        x = x + self.mlp(self.norm(x))\n        return x[:, 0]\n\n\nclass VisionTransformer(nn.Module):\n\n    def __init__(self,\n                 image_size=224,\n                 patch_size=16,\n                 dim=768,\n                 mlp_ratio=4,\n                 out_dim=512,\n                 num_heads=12,\n                 num_layers=12,\n                 pool_type='token',\n                 pre_norm=True,\n                 post_norm=False,\n                 activation='quick_gelu',\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 embedding_dropout=0.0,\n                 norm_eps=1e-5):\n        if image_size % patch_size != 0:\n            print(\n                '[WARNING] image_size is not divisible by patch_size',\n                flush=True)\n        assert pool_type in ('token', 'token_fc', 'attn_pool')\n        out_dim = out_dim or dim\n        super().__init__()\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.num_patches = (image_size // patch_size)**2\n        self.dim = dim\n        self.mlp_ratio = mlp_ratio\n        self.out_dim = out_dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.pool_type = pool_type\n        self.post_norm = post_norm\n        self.norm_eps = norm_eps\n\n        # embeddings\n        gain = 1.0 / math.sqrt(dim)\n        self.patch_embedding = nn.Conv2d(\n            3,\n            dim,\n            kernel_size=patch_size,\n            stride=patch_size,\n            bias=not pre_norm)\n        if pool_type in ('token', 'token_fc'):\n            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))\n        self.pos_embedding = nn.Parameter(gain * torch.randn(\n            1, 
self.num_patches +\n            (1 if pool_type in ('token', 'token_fc') else 0), dim))\n        self.dropout = nn.Dropout(embedding_dropout)\n\n        # transformer\n        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None\n        self.transformer = nn.Sequential(*[\n            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,\n                           activation, attn_dropout, proj_dropout, norm_eps)\n            for _ in range(num_layers)\n        ])\n        self.post_norm = LayerNorm(dim, eps=norm_eps)\n\n        # head\n        if pool_type == 'token':\n            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))\n        elif pool_type == 'token_fc':\n            self.head = nn.Linear(dim, out_dim)\n        elif pool_type == 'attn_pool':\n            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,\n                                      proj_dropout, norm_eps)\n\n    def forward(self, x, interpolation=False, use_31_block=False):\n        b = x.size(0)\n\n        # embeddings\n        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)\n        if self.pool_type in ('token', 'token_fc'):\n            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)\n        if interpolation:\n            e = pos_interpolate(self.pos_embedding, x.size(1))\n        else:\n            e = self.pos_embedding\n        x = self.dropout(x + e)\n        if self.pre_norm is not None:\n            x = self.pre_norm(x)\n\n        # transformer\n        if use_31_block:\n            x = self.transformer[:-1](x)\n            return x\n        else:\n            x = self.transformer(x)\n            return x\n\n\nclass XLMRobertaWithHead(XLMRoberta):\n\n    def __init__(self, **kwargs):\n        self.out_dim = kwargs.pop('out_dim')\n        super().__init__(**kwargs)\n\n        # head\n        mid_dim = (self.dim + self.out_dim) // 2\n        self.head = nn.Sequential(\n            nn.Linear(self.dim, mid_dim, 
bias=False), nn.GELU(),\n            nn.Linear(mid_dim, self.out_dim, bias=False))\n\n    def forward(self, ids):\n        # xlm-roberta\n        x = super().forward(ids)\n\n        # average pooling\n        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)\n        x = (x * mask).sum(dim=1) / mask.sum(dim=1)\n\n        # head\n        x = self.head(x)\n        return x\n\n\nclass XLMRobertaCLIP(nn.Module):\n\n    def __init__(self,\n                 embed_dim=1024,\n                 image_size=224,\n                 patch_size=14,\n                 vision_dim=1280,\n                 vision_mlp_ratio=4,\n                 vision_heads=16,\n                 vision_layers=32,\n                 vision_pool='token',\n                 vision_pre_norm=True,\n                 vision_post_norm=False,\n                 activation='gelu',\n                 vocab_size=250002,\n                 max_text_len=514,\n                 type_size=1,\n                 pad_id=1,\n                 text_dim=1024,\n                 text_heads=16,\n                 text_layers=24,\n                 text_post_norm=True,\n                 text_dropout=0.1,\n                 attn_dropout=0.0,\n                 proj_dropout=0.0,\n                 embedding_dropout=0.0,\n                 norm_eps=1e-5):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.vision_dim = vision_dim\n        self.vision_mlp_ratio = vision_mlp_ratio\n        self.vision_heads = vision_heads\n        self.vision_layers = vision_layers\n        self.vision_pre_norm = vision_pre_norm\n        self.vision_post_norm = vision_post_norm\n        self.activation = activation\n        self.vocab_size = vocab_size\n        self.max_text_len = max_text_len\n        self.type_size = type_size\n        self.pad_id = pad_id\n        self.text_dim = text_dim\n        self.text_heads = text_heads\n        self.text_layers = 
text_layers\n        self.text_post_norm = text_post_norm\n        self.norm_eps = norm_eps\n\n        # models\n        self.visual = VisionTransformer(\n            image_size=image_size,\n            patch_size=patch_size,\n            dim=vision_dim,\n            mlp_ratio=vision_mlp_ratio,\n            out_dim=embed_dim,\n            num_heads=vision_heads,\n            num_layers=vision_layers,\n            pool_type=vision_pool,\n            pre_norm=vision_pre_norm,\n            post_norm=vision_post_norm,\n            activation=activation,\n            attn_dropout=attn_dropout,\n            proj_dropout=proj_dropout,\n            embedding_dropout=embedding_dropout,\n            norm_eps=norm_eps)\n        self.textual = XLMRobertaWithHead(\n            vocab_size=vocab_size,\n            max_seq_len=max_text_len,\n            type_size=type_size,\n            pad_id=pad_id,\n            dim=text_dim,\n            out_dim=embed_dim,\n            num_heads=text_heads,\n            num_layers=text_layers,\n            post_norm=text_post_norm,\n            dropout=text_dropout)\n        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))\n\n    def forward(self, imgs, txt_ids):\n        \"\"\"\n        imgs:       [B, 3, H, W] of torch.float32.\n        - mean:     [0.48145466, 0.4578275, 0.40821073]\n        - std:      [0.26862954, 0.26130258, 0.27577711]\n        txt_ids:    [B, L] of torch.long.\n                    Encoded by data.CLIPTokenizer.\n        \"\"\"\n        xi = self.visual(imgs)\n        xt = self.textual(txt_ids)\n        return xi, xt\n\n    def param_groups(self):\n        groups = [{\n            'params': [\n                p for n, p in self.named_parameters()\n                if 'norm' in n or n.endswith('bias')\n            ],\n            'weight_decay': 0.0\n        }, {\n            'params': [\n                p for n, p in self.named_parameters()\n                if not ('norm' in n or n.endswith('bias'))\n    
        ]\n        }]\n        return groups\n\n\ndef _clip(pretrained=False,\n          pretrained_name=None,\n          model_cls=XLMRobertaCLIP,\n          return_transforms=False,\n          return_tokenizer=False,\n          tokenizer_padding='eos',\n          dtype=torch.float32,\n          device='cpu',\n          **kwargs):\n    # init a model on device\n    with torch.device(device):\n        model = model_cls(**kwargs)\n\n    # set device\n    model = model.to(dtype=dtype, device=device)\n    output = (model,)\n\n    # init transforms\n    if return_transforms:\n        # mean and std\n        if 'siglip' in pretrained_name.lower():\n            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]\n        else:\n            mean = [0.48145466, 0.4578275, 0.40821073]\n            std = [0.26862954, 0.26130258, 0.27577711]\n\n        # transforms\n        transforms = T.Compose([\n            T.Resize((model.image_size, model.image_size),\n                     interpolation=T.InterpolationMode.BICUBIC),\n            T.ToTensor(),\n            T.Normalize(mean=mean, std=std)\n        ])\n        output += (transforms,)\n    return output[0] if len(output) == 1 else output\n\n\ndef clip_xlm_roberta_vit_h_14(\n        pretrained=False,\n        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',\n        **kwargs):\n    cfg = dict(\n        embed_dim=1024,\n        image_size=224,\n        patch_size=14,\n        vision_dim=1280,\n        vision_mlp_ratio=4,\n        vision_heads=16,\n        vision_layers=32,\n        vision_pool='token',\n        activation='gelu',\n        vocab_size=250002,\n        max_text_len=514,\n        type_size=1,\n        pad_id=1,\n        text_dim=1024,\n        text_heads=16,\n        text_layers=24,\n        text_post_norm=True,\n        text_dropout=0.1,\n        attn_dropout=0.0,\n        proj_dropout=0.0,\n        embedding_dropout=0.0)\n    cfg.update(**kwargs)\n    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, 
**cfg)\n\n\nclass CLIPModel:\n\n    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):\n        self.dtype = dtype\n        self.device = device\n        self.checkpoint_path = checkpoint_path\n        self.tokenizer_path = tokenizer_path\n\n        # init model\n        self.model, self.transforms = clip_xlm_roberta_vit_h_14(\n            pretrained=False,\n            return_transforms=True,\n            return_tokenizer=False,\n            dtype=dtype,\n            device=device)\n        self.model = self.model.eval().requires_grad_(False)\n        logging.info(f'loading {checkpoint_path}')\n        self.model.load_state_dict(\n            torch.load(checkpoint_path, map_location='cpu'))\n\n        # init tokenizer\n        self.tokenizer = HuggingfaceTokenizer(\n            name=tokenizer_path,\n            seq_len=self.model.max_text_len - 2,\n            clean='whitespace')\n\n    def visual(self, videos):\n        # preprocess\n        size = (self.model.image_size,) * 2\n        videos = torch.cat([\n            F.interpolate(\n                u.transpose(0, 1),\n                size=size,\n                mode='bicubic',\n                align_corners=False) for u in videos\n        ])\n        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))\n\n        # forward\n        with torch.cuda.amp.autocast(dtype=self.dtype):\n            out = self.model.visual(videos, use_31_block=True)\n            return out"
  },
  {
    "path": "wan/modules/animate/face_blocks.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom torch import nn\nimport torch\nfrom typing import Tuple, Optional\nfrom einops import rearrange\nimport torch.nn.functional as F\nimport math\nfrom ...distributed.util import gather_forward, get_rank, get_world_size\n\n\ntry:\n    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func\nexcept ImportError:\n    flash_attn_qkvpacked_func = None\n    flash_attn_func = None\n\nMEMORY_LAYOUT = {\n    \"flash\": (\n        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),\n        lambda x: x,\n    ),\n    \"torch\": (\n        lambda x: x.transpose(1, 2),\n        lambda x: x.transpose(1, 2),\n    ),\n    \"vanilla\": (\n        lambda x: x.transpose(1, 2),\n        lambda x: x.transpose(1, 2),\n    ),\n}\n\n\ndef attention(\n    q,\n    k,\n    v,\n    mode=\"flash\",\n    drop_rate=0,\n    attn_mask=None,\n    causal=False,\n    max_seqlen_q=None,\n    batch_size=1,\n):\n    \"\"\"\n    Perform QKV self attention.\n\n    Args:\n        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.\n        k (torch.Tensor): Key tensor with shape [b, s1, a, d]\n        v (torch.Tensor): Value tensor with shape [b, s1, a, d]\n        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.\n        drop_rate (float): Dropout rate in attention map. (default: 0)\n        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).\n            (default: None)\n        causal (bool): Whether to use causal attention. (default: False)\n        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,\n            used to index into q.\n        cu_seqlens_kv (torch.Tensor): dtype torch.int32. 
The cumulative sequence lengths of the sequences in the batch,\n            used to index into kv.\n        max_seqlen_q (int): The maximum sequence length in the batch of q.\n        batch_size (int): Batch size used to reshape the flash-attention output to [b, s, a, d]. (default: 1)\n\n    Returns:\n        torch.Tensor: Output tensor after self attention with shape [b, s, ad]\n    \"\"\"\n    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]\n\n    if mode == \"torch\":\n        if attn_mask is not None and attn_mask.dtype != torch.bool:\n            attn_mask = attn_mask.to(q.dtype)\n        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)\n\n    elif mode == \"flash\":\n        # Fail with a clear message when the optional flash_attn import was unavailable.\n        if flash_attn_func is None:\n            raise ImportError(\"mode='flash' requires the flash_attn package\")\n        x = flash_attn_func(\n            q,\n            k,\n            v,\n        )\n        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]\n    elif mode == \"vanilla\":\n        scale_factor = 1 / math.sqrt(q.size(-1))\n\n        b, a, s, _ = q.shape\n        s1 = k.size(2)\n        # attn_bias is created in q.dtype, so no later cast is needed.\n        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)\n        if causal:\n            # Only applied to self attention\n            assert attn_mask is None, \"Causal mask and attn_mask cannot be used together\"\n            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)\n            attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n\n        if attn_mask is not None:\n            if attn_mask.dtype == torch.bool:\n                attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n            else:\n                attn_bias += attn_mask\n\n        attn = (q @ k.transpose(-2, -1)) * scale_factor\n        attn += attn_bias\n        attn = attn.softmax(dim=-1)\n        attn = torch.dropout(attn, p=drop_rate, train=True)\n        x = attn @ v\n    else:\n        raise NotImplementedError(f\"Unsupported attention 
mode: {mode}\")\n\n    x = post_attn_layout(x)\n    b, s, a, d = x.shape\n    out = x.reshape(b, s, -1)\n    return out\n\n\nclass CausalConv1d(nn.Module):\n\n    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode=\"replicate\", **kwargs):\n        super().__init__()\n\n        self.pad_mode = pad_mode\n        padding = (kernel_size - 1, 0)  # T\n        self.time_causal_padding = padding\n\n        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)\n\n    def forward(self, x):\n        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)\n        return self.conv(x)\n\n\n\nclass FaceEncoder(nn.Module):\n    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n        super().__init__()\n\n        self.num_heads = num_heads\n        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)\n        # Each head carries 1024 channels after conv1_local, so all three norms act on 1024 features.\n        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n        self.act = nn.SiLU()\n        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)\n        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)\n\n        self.out_proj = nn.Linear(1024, hidden_dim)\n\n        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))\n\n    def forward(self, x):\n        \n        x = rearrange(x, \"b t c -> b c t\")\n        b, c, t = x.shape\n\n        x = self.conv1_local(x)\n        x = rearrange(x, \"b (n c) t -> (b n) t c\", n=self.num_heads)\n        \n        x = self.norm1(x)\n        x = self.act(x)\n        x = rearrange(x, \"b t c 
-> b c t\")\n        x = self.conv2(x)\n        x = rearrange(x, \"b c t -> b t c\")\n        x = self.norm2(x)\n        x = self.act(x)\n        x = rearrange(x, \"b t c -> b c t\")\n        x = self.conv3(x)\n        x = rearrange(x, \"b c t -> b t c\")\n        x = self.norm3(x)\n        x = self.act(x)\n        x = self.out_proj(x)\n        x = rearrange(x, \"(b n) t c -> b t n c\", b=b)\n        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)\n        x = torch.cat([x, padding], dim=-2)\n        x_local = x.clone()\n\n        return x_local\n\n\n\nclass RMSNorm(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        elementwise_affine=True,\n        eps: float = 1e-6,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        Initialize the RMSNorm normalization layer.\n\n        Args:\n            dim (int): The dimension of the input tensor.\n            eps (float, optional): A small value added to the denominator for numerical stability. 
Default is 1e-6.\n\n        Attributes:\n            eps (float): A small value added to the denominator for numerical stability.\n            weight (nn.Parameter): Learnable scaling parameter.\n\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        if elementwise_affine:\n            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))\n\n    def _norm(self, x):\n        \"\"\"\n        Apply the RMSNorm normalization to the input tensor.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The normalized tensor.\n\n        \"\"\"\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass through the RMSNorm layer.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The output tensor after applying RMSNorm.\n\n        \"\"\"\n        output = self._norm(x.float()).type_as(x)\n        if hasattr(self, \"weight\"):\n            output = output * self.weight\n        return output\n\n\ndef get_norm_layer(norm_layer):\n    \"\"\"\n    Get the normalization layer.\n\n    Args:\n        norm_layer (str): The type of normalization layer.\n\n    Returns:\n        norm_layer (nn.Module): The normalization layer.\n    \"\"\"\n    if norm_layer == \"layer\":\n        return nn.LayerNorm\n    elif norm_layer == \"rms\":\n        return RMSNorm\n    else:\n        raise NotImplementedError(f\"Norm layer {norm_layer} is not implemented\")\n\n\nclass FaceAdapter(nn.Module):\n    def __init__(\n        self,\n        hidden_dim: int,\n        heads_num: int,\n        qk_norm: bool = True,\n        qk_norm_type: str = \"rms\",\n        num_adapter_layers: int = 1,\n        dtype=None,\n        device=None,\n    ):\n\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n   
     super().__init__()\n        self.hidden_size = hidden_dim\n        self.heads_num = heads_num\n        self.fuser_blocks = nn.ModuleList(\n            [\n                FaceBlock(\n                    self.hidden_size,\n                    self.heads_num,\n                    qk_norm=qk_norm,\n                    qk_norm_type=qk_norm_type,\n                    **factory_kwargs,\n                )\n                for _ in range(num_adapter_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        motion_embed: torch.Tensor,\n        idx: int,\n        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,\n        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)\n\n\n\nclass FaceBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        heads_num: int,\n        qk_norm: bool = True,\n        qk_norm_type: str = \"rms\",\n        qk_scale: float = None,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n\n        self.deterministic = False\n        self.hidden_size = hidden_size\n        self.heads_num = heads_num\n        head_dim = hidden_size // heads_num\n        self.scale = qk_scale or head_dim**-0.5\n       \n        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)\n        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)\n\n        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)\n\n        qk_norm_layer = get_norm_layer(qk_norm_type)\n        self.q_norm = (\n            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()\n        )\n        self.k_norm = (\n            qk_norm_layer(head_dim, 
elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()\n        )\n\n        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        motion_vec: torch.Tensor,\n        motion_mask: Optional[torch.Tensor] = None,\n        use_context_parallel=False,\n    ) -> torch.Tensor:\n        \n        B, T, N, C = motion_vec.shape\n        T_comp = T\n\n        x_motion = self.pre_norm_motion(motion_vec)\n        x_feat = self.pre_norm_feat(x)\n\n        kv = self.linear1_kv(x_motion)\n        q = self.linear1_q(x_feat)\n\n        k, v = rearrange(kv, \"B L N (K H D) -> K B L N H D\", K=2, H=self.heads_num)\n        q = rearrange(q, \"B S (H D) -> B S H D\", H=self.heads_num)\n\n        # Apply QK-Norm if needed.\n        q = self.q_norm(q).to(v)\n        k = self.k_norm(k).to(v)\n\n        k = rearrange(k, \"B L N H D -> (B L) N H D\")  \n        v = rearrange(v, \"B L N H D -> (B L) N H D\") \n\n        if use_context_parallel:\n            q = gather_forward(q, dim=1)\n\n        q = rearrange(q, \"B (L S) H D -> (B L) S H D\", L=T_comp)  \n        # Compute attention.\n        attn = attention(\n            q,\n            k,\n            v,\n            max_seqlen_q=q.shape[1],\n            batch_size=q.shape[0],\n        )\n\n        attn = rearrange(attn, \"(B L) S C -> B (L S) C\", L=T_comp)\n        if use_context_parallel:\n            attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]\n\n        output = self.linear2(attn)\n\n        if motion_mask is not None:\n            output = output * rearrange(motion_mask, \"B T H W -> B (T H W)\").unsqueeze(-1)\n\n        return output"
  },
  {
    "path": "wan/modules/animate/model_animate.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nimport types\nfrom copy import deepcopy\nfrom einops import rearrange\nfrom typing import List\nimport numpy as np\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn as nn\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.loaders import PeftAdapterMixin\n\nfrom ...distributed.sequence_parallel import (\n    distributed_attention,\n    gather_forward,\n    get_rank,\n    get_world_size,\n)\n\n\nfrom ..model import (\n    Head,\n    WanAttentionBlock,\n    WanLayerNorm,\n    WanRMSNorm,\n    WanModel,\n    WanSelfAttention,\n    flash_attention,\n    rope_params,\n    sinusoidal_embedding_1d,\n    rope_apply\n)\n\nfrom .face_blocks import FaceEncoder, FaceAdapter\nfrom .motion_encoder import Generator\n\nclass HeadAnimate(Head):\n\n    def forward(self, x, e):\n        \"\"\"\n        Args:\n            x(Tensor): Shape [B, L1, C]\n            e(Tensor): Shape [B, C]\n        \"\"\"\n        assert e.dtype == torch.float32\n        with amp.autocast(dtype=torch.float32):\n            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)\n            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))\n        return x\n\n\nclass WanAnimateSelfAttention(WanSelfAttention):\n\n    def forward(self, x, seq_lens, grid_sizes, freqs):\n        \"\"\"\n        Args:\n            x(Tensor): Shape [B, L, C]\n            seq_lens(Tensor): Shape [B]\n            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)\n            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]\n        \"\"\"\n        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n\n        # query, key, value function\n        def qkv_fn(x):\n            q = self.norm_q(self.q(x)).view(b, s, n, d)\n            k = 
self.norm_k(self.k(x)).view(b, s, n, d)\n            v = self.v(x).view(b, s, n, d)\n            return q, k, v\n\n        q, k, v = qkv_fn(x)\n\n        x = flash_attention(\n            q=rope_apply(q, grid_sizes, freqs),\n            k=rope_apply(k, grid_sizes, freqs),\n            v=v,\n            k_lens=seq_lens,\n            window_size=self.window_size)\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\nclass WanAnimateCrossAttention(WanSelfAttention):\n    def __init__(\n        self,\n        dim,\n        num_heads,\n        window_size=(-1, -1),\n        qk_norm=True,\n        eps=1e-6,\n        use_img_emb=True\n    ):\n        super().__init__(\n            dim,\n            num_heads,\n            window_size,\n            qk_norm,\n            eps\n        )\n        self.use_img_emb = use_img_emb\n\n        if use_img_emb:\n            self.k_img = nn.Linear(dim, dim)\n            self.v_img = nn.Linear(dim, dim)\n            self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n\n    def forward(self, x, context, context_lens):\n        \"\"\"\n        x:              [B, L1, C].\n        context:        [B, L2, C].\n        context_lens:   [B].\n        \"\"\"\n        if self.use_img_emb:\n            context_img = context[:, :257]\n            context = context[:, 257:]\n        else:\n            context = context\n\n        b, n, d = x.size(0), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.norm_q(self.q(x)).view(b, -1, n, d)\n        k = self.norm_k(self.k(context)).view(b, -1, n, d)\n        v = self.v(context).view(b, -1, n, d)\n\n        if self.use_img_emb:\n            k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)\n            v_img = self.v_img(context_img).view(b, -1, n, d)\n            img_x = flash_attention(q, k_img, v_img, k_lens=None)\n        # compute attention\n        x = flash_attention(q, k, v, 
k_lens=context_lens)\n\n        # output\n        x = x.flatten(2)\n\n        if self.use_img_emb:\n            img_x = img_x.flatten(2)\n            x = x + img_x\n\n        x = self.o(x)\n        return x\n\n\nclass WanAnimateAttentionBlock(nn.Module):\n    def __init__(self,\n                 dim,\n                 ffn_dim,\n                 num_heads,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 cross_attn_norm=True,\n                 eps=1e-6,\n                 use_img_emb=True):\n\n        super().__init__()\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        self.eps = eps\n\n        # layers\n        self.norm1 = WanLayerNorm(dim, eps)\n        self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)\n        \n        self.norm3 = WanLayerNorm(\n            dim, eps, elementwise_affine=True\n        ) if cross_attn_norm else nn.Identity()\n            \n        self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)\n        self.norm2 = WanLayerNorm(dim, eps)\n        self.ffn = nn.Sequential(\n            nn.Linear(dim, ffn_dim),\n            nn.GELU(approximate='tanh'),\n            nn.Linear(ffn_dim, dim)\n        )\n\n        # modulation\n        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)\n\n    def forward(\n        self,\n        x,\n        e,\n        seq_lens,\n        grid_sizes,\n        freqs,\n        context,\n        context_lens,\n    ):\n        \"\"\"\n        Args:\n            x(Tensor): Shape [B, L, C]\n            e(Tensor): Shape [B, L1, 6, C]\n            seq_lens(Tensor): Shape [B], length of each sequence in batch\n            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)\n            
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]\n        \"\"\"\n        assert e.dtype == torch.float32\n        with amp.autocast(dtype=torch.float32):\n            e = (self.modulation + e).chunk(6, dim=1)\n        assert e[0].dtype == torch.float32\n\n        # self-attention\n        y = self.self_attn(\n            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs\n        )\n        with amp.autocast(dtype=torch.float32):\n            x = x + y * e[2]\n\n        # cross-attention & ffn function\n        def cross_attn_ffn(x, context, context_lens, e):\n            x = x + self.cross_attn(self.norm3(x), context, context_lens)\n            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])\n            with amp.autocast(dtype=torch.float32):\n                x = x + y * e[5]\n            return x\n\n        x = cross_attn_ffn(x, context, context_lens, e)\n        return x\n\n\nclass MLPProj(torch.nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n\n        self.proj = torch.nn.Sequential(\n            torch.nn.LayerNorm(in_dim),\n            torch.nn.Linear(in_dim, in_dim),\n            torch.nn.GELU(),\n            torch.nn.Linear(in_dim, out_dim),\n            torch.nn.LayerNorm(out_dim),\n        )\n\n    def forward(self, image_embeds):\n        clip_extra_context_tokens = self.proj(image_embeds)\n        return clip_extra_context_tokens\n\nclass WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):\n    _no_split_modules = ['WanAttentionBlock']\n\n    @register_to_config\n    def __init__(self,\n                 patch_size=(1, 2, 2),\n                 text_len=512,\n                 in_dim=36,\n                 dim=5120,\n                 ffn_dim=13824,\n                 freq_dim=256,\n                 text_dim=4096,\n                 out_dim=16,\n                 num_heads=40,\n                 num_layers=40,\n                 window_size=(-1, -1),\n                 
qk_norm=True,\n                 cross_attn_norm=True,\n                 eps=1e-6,\n                 motion_encoder_dim=512,\n                 use_context_parallel=False,\n                 use_img_emb=True):\n\n        super().__init__()\n        self.patch_size = patch_size\n        self.text_len = text_len\n        self.in_dim = in_dim\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.freq_dim = freq_dim\n        self.text_dim = text_dim\n        self.out_dim = out_dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        self.eps = eps\n        self.motion_encoder_dim = motion_encoder_dim\n        self.use_context_parallel = use_context_parallel\n        self.use_img_emb = use_img_emb\n\n        # embeddings\n        self.patch_embedding = nn.Conv3d(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size)\n\n        self.pose_patch_embedding = nn.Conv3d(\n            16, dim, kernel_size=patch_size, stride=patch_size\n        )\n\n        self.text_embedding = nn.Sequential(\n            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),\n            nn.Linear(dim, dim))\n\n        self.time_embedding = nn.Sequential(\n            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))\n\n        # blocks\n        self.blocks = nn.ModuleList([\n            WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,\n                              cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)\n        ])\n\n        # head\n        self.head = HeadAnimate(dim, out_dim, patch_size, eps)\n\n        # buffers (don't use register_buffer otherwise dtype will be changed in to())\n        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0\n        d = dim // num_heads\n 
       self.freqs = torch.cat([\n            rope_params(1024, d - 4 * (d // 6)),\n            rope_params(1024, 2 * (d // 6)),\n            rope_params(1024, 2 * (d // 6))\n        ], dim=1)\n\n        self.img_emb = MLPProj(1280, dim)\n        \n        # initialize weights\n        self.init_weights()\n\n        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)\n        self.face_adapter = FaceAdapter(\n            heads_num=self.num_heads,\n            hidden_dim=self.dim,\n            num_adapter_layers=self.num_layers // 5,\n        )\n\n        self.face_encoder = FaceEncoder(\n            in_dim=motion_encoder_dim,\n            hidden_dim=self.dim,\n            num_heads=4,\n        )\n\n    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):\n        pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]\n        for x_, pose_latents_ in zip(x, pose_latents):\n            x_[:, :, 1:] += pose_latents_\n        \n        b,c,T,h,w = face_pixel_values.shape\n        face_pixel_values = rearrange(face_pixel_values, \"b c t h w -> (b t) c h w\")\n\n        encode_bs = 8\n        face_pixel_values_tmp = []\n        for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):\n            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))\n\n        motion_vec = torch.cat(face_pixel_values_tmp)\n        \n        motion_vec = rearrange(motion_vec, \"(b t) c -> b t c\", t=T)\n        motion_vec = self.face_encoder(motion_vec)\n\n        B, L, H, C = motion_vec.shape\n        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)\n        motion_vec = torch.cat([pad_face, motion_vec], dim=1)\n        return x, motion_vec\n\n\n    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):\n        if block_idx % 5 == 0:\n            adapter_args = [x, motion_vec, motion_masks, 
self.use_context_parallel]\n            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)\n            x = residual_out + x\n        return x\n\n\n    def forward(\n        self,\n        x,\n        t,\n        clip_fea,\n        context,\n        seq_len,\n        y=None,\n        pose_latents=None, \n        face_pixel_values=None\n    ):\n        # params\n        device = self.patch_embedding.weight.device\n        if self.freqs.device != device:\n            self.freqs = self.freqs.to(device)\n\n        if y is not None:\n            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]\n\n        # embeddings\n        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]\n        x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)\n\n        grid_sizes = torch.stack(\n            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])\n        x = [u.flatten(2).transpose(1, 2) for u in x]\n        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)\n        assert seq_lens.max() <= seq_len\n        x = torch.cat([\n            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],\n                      dim=1) for u in x\n        ])\n\n        # time embeddings\n        with amp.autocast(dtype=torch.float32):\n            e = self.time_embedding(\n                sinusoidal_embedding_1d(self.freq_dim, t).float()\n            )\n            e0 = self.time_projection(e).unflatten(1, (6, self.dim))\n            assert e.dtype == torch.float32 and e0.dtype == torch.float32\n\n        # context\n        context_lens = None\n        context = self.text_embedding(\n            torch.stack([\n                torch.cat(\n                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])\n                for u in context\n            ]))\n\n        if self.use_img_emb:\n            context_clip = self.img_emb(clip_fea) # bs x 257 x dim\n            context = 
torch.concat([context_clip, context], dim=1)\n\n        # arguments\n        kwargs = dict(\n            e=e0,\n            seq_lens=seq_lens,\n            grid_sizes=grid_sizes,\n            freqs=self.freqs,\n            context=context,\n            context_lens=context_lens)\n\n        if self.use_context_parallel:\n            x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]\n\n        for idx, block in enumerate(self.blocks):\n            x = block(x, **kwargs)\n            x = self.after_transformer_block(idx, x, motion_vec)\n\n        # head\n        x = self.head(x, e)\n\n        if self.use_context_parallel:\n            x = gather_forward(x, dim=1)\n\n        # unpatchify\n        x = self.unpatchify(x, grid_sizes)\n        return [u.float() for u in x]\n\n\n    def unpatchify(self, x, grid_sizes):\n        r\"\"\"\n        Reconstruct video tensors from patch embeddings.\n\n        Args:\n            x (List[Tensor]):\n                List of patchified features, each with shape [L, C_out * prod(patch_size)]\n            grid_sizes (Tensor):\n                Original spatial-temporal grid dimensions before patching,\n                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)\n\n        Returns:\n            List[Tensor]:\n                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]\n        \"\"\"\n\n        c = self.out_dim\n        out = []\n        for u, v in zip(x, grid_sizes.tolist()):\n            u = u[:math.prod(v)].view(*v, *self.patch_size, c)\n            u = torch.einsum('fhwpqrc->cfphqwr', u)\n            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])\n            out.append(u)\n        return out\n\n    def init_weights(self):\n        r\"\"\"\n        Initialize model parameters using Xavier initialization.\n        \"\"\"\n\n        # basic init\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                
nn.init.xavier_uniform_(m.weight)\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n\n        # init embeddings\n        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))\n        for m in self.text_embedding.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=.02)\n        for m in self.time_embedding.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=.02)\n\n        # init output layer\n        nn.init.zeros_(self.head.head.weight)\n"
  },
  {
    "path": "wan/modules/animate/motion_encoder.py",
    "content": "# Modified from ``https://github.com/wyhsirius/LIA``\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\n\ndef custom_qr(input_tensor):\n    original_dtype = input_tensor.dtype\n    if original_dtype == torch.bfloat16:\n        q, r = torch.linalg.qr(input_tensor.to(torch.float32))\n        return q.to(original_dtype), r.to(original_dtype)\n    return torch.linalg.qr(input_tensor)\n\ndef fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):\n\treturn F.leaky_relu(input + bias, negative_slope) * scale\n\n\ndef upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):\n\t_, minor, in_h, in_w = input.shape\n\tkernel_h, kernel_w = kernel.shape\n\n\tout = input.view(-1, minor, in_h, 1, in_w, 1)\n\tout = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])\n\tout = out.view(-1, minor, in_h * up_y, in_w * up_x)\n\n\tout = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])\n\tout = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),\n\t\t  max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]\n\n\tout = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])\n\tw = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\n\tout = F.conv2d(out, w)\n\tout = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\n\t\t\t\t\t  in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )\n\treturn out[:, :, ::down_y, ::down_x]\n\n\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\n\treturn upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])\n\n\ndef make_kernel(k):\n\tk = torch.tensor(k, dtype=torch.float32)\n\tif k.ndim == 1:\n\t\tk = k[None, :] * k[:, None]\n\tk /= k.sum()\n\treturn k\n\n\nclass FusedLeakyReLU(nn.Module):\n\tdef __init__(self, channel, negative_slope=0.2, scale=2 ** 
0.5):\n\t\tsuper().__init__()\n\t\tself.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))\n\t\tself.negative_slope = negative_slope\n\t\tself.scale = scale\n\n\tdef forward(self, input):\n\t\tout = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)\n\t\treturn out\n\n\nclass Blur(nn.Module):\n\tdef __init__(self, kernel, pad, upsample_factor=1):\n\t\tsuper().__init__()\n\n\t\tkernel = make_kernel(kernel)\n\n\t\tif upsample_factor > 1:\n\t\t\tkernel = kernel * (upsample_factor ** 2)\n\n\t\tself.register_buffer('kernel', kernel)\n\n\t\tself.pad = pad\n\n\tdef forward(self, input):\n\t\treturn upfirdn2d(input, self.kernel, pad=self.pad)\n\n\nclass ScaledLeakyReLU(nn.Module):\n\tdef __init__(self, negative_slope=0.2):\n\t\tsuper().__init__()\n\n\t\tself.negative_slope = negative_slope\n\n\tdef forward(self, input):\n\t\treturn F.leaky_relu(input, negative_slope=self.negative_slope)\n\n\nclass EqualConv2d(nn.Module):\n\tdef __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):\n\t\tsuper().__init__()\n\n\t\tself.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))\n\t\tself.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)\n\n\t\tself.stride = stride\n\t\tself.padding = padding\n\n\t\tif bias:\n\t\t\tself.bias = nn.Parameter(torch.zeros(out_channel))\n\t\telse:\n\t\t\tself.bias = None\n\n\tdef forward(self, input):\n\n\t\treturn F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)\n\n\tdef __repr__(self):\n\t\treturn (\n\t\t\tf'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'\n\t\t\tf' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'\n\t\t)\n\n\nclass EqualLinear(nn.Module):\n\tdef __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):\n\t\tsuper().__init__()\n\n\t\tself.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))\n\n\t\tif 
bias:\n\t\t\tself.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))\n\t\telse:\n\t\t\tself.bias = None\n\n\t\tself.activation = activation\n\n\t\tself.scale = (1 / math.sqrt(in_dim)) * lr_mul\n\t\tself.lr_mul = lr_mul\n\n\tdef forward(self, input):\n\n\t\tif self.activation:\n\t\t\tout = F.linear(input, self.weight * self.scale)\n\t\t\tout = fused_leaky_relu(out, self.bias * self.lr_mul)\n\t\telse:\n\t\t\tout = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)\n\n\t\treturn out\n\n\tdef __repr__(self):\n\t\treturn (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')\n\n\nclass ConvLayer(nn.Sequential):\n\tdef __init__(\n\t\t\tself,\n\t\t\tin_channel,\n\t\t\tout_channel,\n\t\t\tkernel_size,\n\t\t\tdownsample=False,\n\t\t\tblur_kernel=[1, 3, 3, 1],\n\t\t\tbias=True,\n\t\t\tactivate=True,\n\t):\n\t\tlayers = []\n\n\t\tif downsample:\n\t\t\tfactor = 2\n\t\t\tp = (len(blur_kernel) - factor) + (kernel_size - 1)\n\t\t\tpad0 = (p + 1) // 2\n\t\t\tpad1 = p // 2\n\n\t\t\tlayers.append(Blur(blur_kernel, pad=(pad0, pad1)))\n\n\t\t\tstride = 2\n\t\t\tself.padding = 0\n\n\t\telse:\n\t\t\tstride = 1\n\t\t\tself.padding = kernel_size // 2\n\n\t\tlayers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,\n\t\t\t\t\t\t\t\t  bias=bias and not activate))\n\n\t\tif activate:\n\t\t\tif bias:\n\t\t\t\tlayers.append(FusedLeakyReLU(out_channel))\n\t\t\telse:\n\t\t\t\tlayers.append(ScaledLeakyReLU(0.2))\n\n\t\tsuper().__init__(*layers)\n\n\nclass ResBlock(nn.Module):\n\tdef __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):\n\t\tsuper().__init__()\n\n\t\tself.conv1 = ConvLayer(in_channel, in_channel, 3)\n\t\tself.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)\n\n\t\tself.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)\n\n\tdef forward(self, input):\n\t\tout = self.conv1(input)\n\t\tout = self.conv2(out)\n\n\t\tskip = 
self.skip(input)\n\t\tout = (out + skip) / math.sqrt(2)\n\n\t\treturn out\n\n\nclass EncoderApp(nn.Module):\n\tdef __init__(self, size, w_dim=512):\n\t\tsuper(EncoderApp, self).__init__()\n\n\t\tchannels = {\n\t\t\t4: 512,\n\t\t\t8: 512,\n\t\t\t16: 512,\n\t\t\t32: 512,\n\t\t\t64: 256,\n\t\t\t128: 128,\n\t\t\t256: 64,\n\t\t\t512: 32,\n\t\t\t1024: 16\n\t\t}\n\n\t\tself.w_dim = w_dim\n\t\tlog_size = int(math.log(size, 2))\n\n\t\tself.convs = nn.ModuleList()\n\t\tself.convs.append(ConvLayer(3, channels[size], 1))\n\n\t\tin_channel = channels[size]\n\t\tfor i in range(log_size, 2, -1):\n\t\t\tout_channel = channels[2 ** (i - 1)]\n\t\t\tself.convs.append(ResBlock(in_channel, out_channel))\n\t\t\tin_channel = out_channel\n\n\t\tself.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))\n\n\tdef forward(self, x):\n\n\t\tres = []\n\t\th = x\n\t\tfor conv in self.convs:\n\t\t\th = conv(h)\n\t\t\tres.append(h)\n\n\t\treturn res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]\n\n\nclass Encoder(nn.Module):\n\tdef __init__(self, size, dim=512, dim_motion=20):\n\t\tsuper(Encoder, self).__init__()\n\n\t\t# appearance network\n\t\tself.net_app = EncoderApp(size, dim)\n\n\t\t# motion network\n\t\tfc = [EqualLinear(dim, dim)]\n\t\tfor i in range(3):\n\t\t\tfc.append(EqualLinear(dim, dim))\n\n\t\tfc.append(EqualLinear(dim, dim_motion))\n\t\tself.fc = nn.Sequential(*fc)\n\n\tdef enc_app(self, x):\n\t\th_source = self.net_app(x)\n\t\treturn h_source\n\n\tdef enc_motion(self, x):\n\t\th, _ = self.net_app(x)\n\t\th_motion = self.fc(h)\n\t\treturn h_motion\n\n\nclass Direction(nn.Module):\n    def __init__(self, motion_dim):\n        super(Direction, self).__init__()\n        self.weight = nn.Parameter(torch.randn(512, motion_dim))\n\n    def forward(self, input):\n\n        weight = self.weight + 1e-8\n        Q, R = custom_qr(weight)\n        if input is None:\n            return Q\n        else:\n            input_diag = torch.diag_embed(input)  # alpha, diagonal 
matrix\n            out = torch.matmul(input_diag, Q.T)\n            out = torch.sum(out, dim=1)\n            return out\n\n\nclass Synthesis(nn.Module):\n    def __init__(self, motion_dim):\n        super(Synthesis, self).__init__()\n        self.direction = Direction(motion_dim)\n\n\nclass Generator(nn.Module):\n    def __init__(self, size, style_dim=512, motion_dim=20):\n        super().__init__()\n\n        self.enc = Encoder(size, style_dim, motion_dim)\n        self.dec = Synthesis(motion_dim)\n\n    def get_motion(self, img):\n        #motion_feat = self.enc.enc_motion(img)\n        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)\n        with torch.cuda.amp.autocast(dtype=torch.float32):\n            motion = self.dec.direction(motion_feat)\n        return motion"
  },
  {
    "path": "wan/modules/animate/preprocess/UserGuider.md",
    "content": "# Wan-animate Preprocessing User Guide\n\n## 1. Introduction\n\n\nWan-animate offers two generation modes: `animation` and `replacement`. Both modes extract the skeleton from the reference video, but each has its own preprocessing pipeline.\n\n### 1.1 Animation Mode\n\nIn this mode, enabling pose retargeting is highly recommended, especially if the body proportions of the reference and driving characters differ.\n\n - A simplified version of the pose retargeting pipeline is provided to help developers quickly implement this functionality.\n\n - **NOTE:** Because input data can be complex, the results of this simplified retargeting are NOT guaranteed to be perfect. Verify the preprocessing results before proceeding.\n\n - Community contributions to improve this feature are welcome.\n\n### 1.2 Replacement Mode\n\n - Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.\n\n - **WARNING:** If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.\n\n - A simplified method for extracting the character's mask is also provided.\n - **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail on multi-person videos (incorrect pose tracking). For multi-person videos, users must either develop their own solution or integrate a suitable open-source tool.\n\n---\n\n## 2. Preprocessing Instructions and Recommendations\n\n### 2.1 Basic Usage\n\n- Preprocessing requires several additional models: pose detection models (mandatory), plus mask extraction and image editing models (optional, as needed). 
Place them according to the following directory structure:\n```\n    /path/to/your/ckpt_path/\n    ├── det/\n    │ └── yolov10m.onnx\n    ├── pose2d/\n    │ └── vitpose_h_wholebody.onnx\n    ├── sam2/\n    │ └── sam2_hiera_large.pt\n    └── FLUX.1-Kontext-dev/\n```\n- `video_path`, `refer_path`, and `save_path` are the paths to the input driving video, the character image, and the preprocessed results, respectively.\n\n- In `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, are generated in `save_path`. In `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, are also generated.\n\n- The `resolution_area` parameter determines the resolution for both preprocessing and the generation model. The target size is determined by pixel area.\n\n- The `fps` parameter sets the frame rate for video processing. A lower frame rate can improve generation efficiency, but may cause stuttering or choppiness.\n\n---\n\n### 2.2 Animation Mode\n\n- Three options are supported: no pose retargeting, basic pose retargeting, and enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are selected via the `retarget_flag` and `use_flux` parameters.\n\n- Basic pose retargeting (specifying `retarget_flag` alone) requires that both the reference character and the character in the first frame of the driving video be in a front-facing, stretched pose.\n\n- In all other cases, we recommend enhanced pose retargeting, enabled by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency may not be maintained, or the pose may be incorrect). It is recommended to check the intermediate results as well as the final generated pose video; both are stored in `save_path`. 
Users can also substitute a better image editing model, or experiment with their own prompts for Flux.\n\n---\n\n### 2.3 Replacement Mode\n\n- Specify `replace_flag` to enable data preprocessing for this mode. Preprocessing additionally extracts a mask for the character in the video; the mask's size and shape can be adjusted with the parameters below.\n- `iterations` and `k` enlarge the mask so that it covers more area.\n- `w_len` and `h_len` adjust the mask's shape. Smaller values make the outline coarser, while larger values make it finer.\n\n- A smaller, finer-contoured mask preserves more of the original background, but may limit the character's generation area (given potential appearance differences, this can lead to some shape leakage). A larger, coarser mask allows character generation to be more flexible and consistent, but because it includes more of the background, it may affect the background's consistency. We recommend that users adjust the relevant parameters based on their specific input data."
  },
  {
    "path": "wan/modules/animate/preprocess/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .process_pipepline import ProcessPipeline\nfrom .video_predictor import SAM2VideoPredictor"
  },
  {
    "path": "wan/modules/animate/preprocess/human_visualization.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nimport time\nimport math\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom typing import Dict, List\nimport random\nfrom pose2d_utils import AAPoseMeta\n\n\ndef draw_handpose(canvas, keypoints, hand_score_th=0.6):\n    \"\"\"\n    Draw keypoints and connections representing hand pose on a given canvas.\n\n    Args:\n        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.\n        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn\n                                          or None if no keypoints are present.\n\n    Returns:\n        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.\n\n    Note:\n        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.\n    \"\"\"\n    eps = 0.01\n\n    H, W, C = canvas.shape\n    stickwidth = max(int(min(H, W) / 200), 1)\n\n    edges = [\n        [0, 1],\n        [1, 2],\n        [2, 3],\n        [3, 4],\n        [0, 5],\n        [5, 6],\n        [6, 7],\n        [7, 8],\n        [0, 9],\n        [9, 10],\n        [10, 11],\n        [11, 12],\n        [0, 13],\n        [13, 14],\n        [14, 15],\n        [15, 16],\n        [0, 17],\n        [17, 18],\n        [18, 19],\n        [19, 20],\n    ]\n\n    for ie, (e1, e2) in enumerate(edges):\n        k1 = keypoints[e1]\n        k2 = keypoints[e2]\n        if k1 is None or k2 is None:\n            continue\n        if k1[2] < hand_score_th or k2[2] < hand_score_th:\n            continue\n\n        x1 = int(k1[0])\n        y1 = int(k1[1])\n        x2 = int(k2[0])\n        y2 = int(k2[1])\n        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:\n            cv2.line(\n                canvas,\n                (x1, y1),\n                (x2, y2),\n         
       matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,\n                thickness=stickwidth,\n            )\n\n    for keypoint in keypoints:\n\n        if keypoint is None:\n            continue\n        if keypoint[2] < hand_score_th:\n            continue\n\n        x, y = keypoint[0], keypoint[1]\n        x = int(x)\n        y = int(y)\n        if x > eps and y > eps:\n            cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)\n    return canvas\n\n\ndef draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_score_th=0.6):\n    \"\"\"\n    Draw keypoints and connections representing hand pose on a given canvas.\n\n    Args:\n        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.\n        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn\n                                          or None if no keypoints are present.\n\n    Returns:\n        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.\n\n    Note:\n        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.\n    \"\"\"\n    eps = 0.01\n\n    H, W, C = canvas.shape\n    if stickwidth_type == 'v1':\n        stickwidth = max(int(min(H, W) / 200), 1)\n    elif stickwidth_type == 'v2':\n        stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1)\n\n    edges = [\n        [0, 1],\n        [1, 2],\n        [2, 3],\n        [3, 4],\n        [0, 5],\n        [5, 6],\n        [6, 7],\n        [7, 8],\n        [0, 9],\n        [9, 10],\n        [10, 11],\n        [11, 12],\n        [0, 13],\n        [13, 14],\n        [14, 15],\n        [15, 16],\n        [0, 17],\n        [17, 18],\n        [18, 19],\n        [19, 20],\n    ]\n\n    for ie, (e1, e2) in enumerate(edges):\n        k1 = keypoints[e1]\n        k2 = keypoints[e2]\n        if k1 is None or k2 is None:\n       
     continue\n        if k1[2] < hand_score_th or k2[2] < hand_score_th:\n            continue\n\n        x1 = int(k1[0])\n        y1 = int(k1[1])\n        x2 = int(k2[0])\n        y2 = int(k2[1])\n        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:\n            cv2.line(\n                canvas,\n                (x1, y1),\n                (x2, y2),\n                matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,\n                thickness=stickwidth,\n            )\n\n    for keypoint in keypoints:\n\n        if keypoint is None:\n            continue\n        if keypoint[2] < hand_score_th:\n            continue\n\n        x, y = keypoint[0], keypoint[1]\n        x = int(x)\n        y = int(y)\n        if x > eps and y > eps:\n            cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)\n    return canvas\n\n\ndef draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6):\n    H, W, C = img.shape\n    stickwidth = max(int(min(H, W) / 200), 1)\n\n    if keypoint1[-1] < threshold or keypoint2[-1] < threshold:\n        return img\n\n    Y = np.array([keypoint1[0], keypoint2[0]])\n    X = np.array([keypoint1[1], keypoint2[1]])\n    mX = np.mean(X)\n    mY = np.mean(Y)\n    length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n    angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n    polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)\n    cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])\n    return img\n\n\ndef split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]:\n    \"\"\"Convert the 133 keypoints from pose2d to body and hands keypoints.\n\n    Args:\n        kp2ds (np.ndarray): [133, 2]\n\n    Returns:\n        List[np.ndarray]: _description_\n    \"\"\"\n    kp2ds_body = (\n        kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]]\n        + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 
16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]\n    ) / 2\n    kp2ds_lhand = kp2ds[91:112]\n    kp2ds_rhand = kp2ds[112:133]\n    return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()\n\n\ndef draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):\n    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)\n    kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)\n    kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)\n    pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)\n    return pose_img\n\ndef draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type='v2', draw_hand=True, draw_head=True):\n    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)\n    kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)\n    kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)\n    pose_img = draw_aapose_new(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,\n                               stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head)\n    return pose_img\n\ndef draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200):\n    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1)\n    kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)\n    kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)\n    pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False)\n    return pose_img\n\n\ndef draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, 
draw_hand=False, draw_head=True):\n    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)\n    # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)\n    # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)\n    pose_img = draw_M(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)\n    return pose_img\n\n\ndef draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False):\n    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)\n    # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)\n    # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)\n    pose_img = draw_nose(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand)\n    return pose_img\n\n\ndef gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200):\n\n    return \n\n\ndef draw_M(\n    img,\n    kp2ds,\n    threshold=0.6,\n    data_to_json=None,\n    idx=-1,\n    kp2ds_lhand=None,\n    kp2ds_rhand=None,\n    draw_hand=False,\n    stick_width_norm=200,\n    draw_head=True\n):\n    \"\"\"\n    Draw keypoints and connections representing hand pose on a given canvas.\n\n    Args:\n        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.\n        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn\n                                          or None if no keypoints are present.\n\n    Returns:\n        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.\n\n    Note:\n        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.\n    \"\"\"\n\n    new_kep_list = 
[\n        \"Nose\",\n        \"Neck\",\n        \"RShoulder\",\n        \"RElbow\",\n        \"RWrist\",  # No.4\n        \"LShoulder\",\n        \"LElbow\",\n        \"LWrist\",  # No.7\n        \"RHip\",\n        \"RKnee\",\n        \"RAnkle\",  # No.10\n        \"LHip\",\n        \"LKnee\",\n        \"LAnkle\",  # No.13\n        \"REye\",\n        \"LEye\",\n        \"REar\",\n        \"LEar\",\n        \"LToe\",\n        \"RToe\",\n    ]\n    # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \\\n    #              kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2\n    kp2ds = kp2ds.copy()\n    # import ipdb; ipdb.set_trace()\n    kp2ds[[1,2,3,4,5,6,7,8,9,10,11,12,13,18,19], 2] = 0\n    if not draw_head:\n        kp2ds[[0,14,15,16,17], 2] = 0\n    kp2ds_body = kp2ds\n    # kp2ds_body = kp2ds_body[:18]\n\n    # kp2ds_lhand = kp2ds.copy()[91:112]\n    # kp2ds_rhand = kp2ds.copy()[112:133]\n\n    limbSeq = [\n        # [2, 3],\n        # [2, 6],  # shoulders\n        # [3, 4],\n        # [4, 5],  # left arm\n        # [6, 7],\n        # [7, 8],  # right arm\n        # [2, 9],\n        # [9, 10],\n        # [10, 11],  # right leg\n        # [2, 12],\n        # [12, 13],\n        # [13, 14],  # left leg\n        # [2, 1],\n        [1, 15],\n        [15, 17],\n        [1, 16],\n        [16, 18],  # face (nose, eyes, ears)\n        # [14, 19],\n        # [11, 20],  # foot\n    ]\n\n    colors = [\n        # [255, 0, 0],\n        # [255, 85, 0],\n        # [255, 170, 0],\n        # [255, 255, 0],\n        # [170, 255, 0],\n        # [85, 255, 0],\n        # [0, 255, 0],\n        # [0, 255, 85],\n        # [0, 255, 170],\n        # [0, 255, 255],\n        # [0, 170, 255],\n        # [0, 85, 255],\n        # [0, 0, 255],\n        # [85, 0, 255],\n        [170, 0, 255],\n        [255, 0, 255],\n        [255, 0, 170],\n        [255, 0, 85],\n        # foot\n        # [200, 200, 0],\n 
       # [100, 100, 0],\n    ]\n\n    H, W, C = img.shape\n    stickwidth = max(int(min(H, W) / stick_width_norm), 1)\n\n    for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):\n        keypoint1 = kp2ds_body[k1_index - 1]\n        keypoint2 = kp2ds_body[k2_index - 1]\n\n        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:\n            continue\n\n        Y = np.array([keypoint1[0], keypoint2[0]])\n        X = np.array([keypoint1[1], keypoint2[1]])\n        mX = np.mean(X)\n        mY = np.mean(Y)\n        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)\n        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])\n\n    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):\n        if keypoint[-1] < threshold:\n            continue\n        x, y = keypoint[0], keypoint[1]\n        # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)\n        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)\n\n    if draw_hand:\n        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)\n        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)\n\n    kp2ds_body[:, 0] /= W\n    kp2ds_body[:, 1] /= H\n\n    if data_to_json is not None:\n        if idx == -1:\n            data_to_json.append(\n                {\n                    \"image_id\": \"frame_{:05d}.jpg\".format(len(data_to_json) + 1),\n                    \"height\": H,\n                    \"width\": W,\n                    \"category_id\": 1,\n                    \"keypoints_body\": kp2ds_body.tolist(),\n                    \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                    \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n                }\n            )\n        else:\n            data_to_json[idx] 
= {\n                \"image_id\": \"frame_{:05d}.jpg\".format(idx + 1),\n                \"height\": H,\n                \"width\": W,\n                \"category_id\": 1,\n                \"keypoints_body\": kp2ds_body.tolist(),\n                \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n            }\n    return img\n\n\ndef draw_nose(\n    img,\n    kp2ds,\n    threshold=0.6,\n    data_to_json=None,\n    idx=-1,\n    kp2ds_lhand=None,\n    kp2ds_rhand=None,\n    draw_hand=False,\n    stick_width_norm=200,\n):\n    \"\"\"\n    Draw keypoints and connections representing hand pose on a given canvas.\n\n    Args:\n        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.\n        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn\n                                          or None if no keypoints are present.\n\n    Returns:\n        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.\n\n    Note:\n        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.\n    \"\"\"\n\n    new_kep_list = [\n        \"Nose\",\n        \"Neck\",\n        \"RShoulder\",\n        \"RElbow\",\n        \"RWrist\",  # No.4\n        \"LShoulder\",\n        \"LElbow\",\n        \"LWrist\",  # No.7\n        \"RHip\",\n        \"RKnee\",\n        \"RAnkle\",  # No.10\n        \"LHip\",\n        \"LKnee\",\n        \"LAnkle\",  # No.13\n        \"REye\",\n        \"LEye\",\n        \"REar\",\n        \"LEar\",\n        \"LToe\",\n        \"RToe\",\n    ]\n    # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \\\n    #              kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2\n    kp2ds = kp2ds.copy()\n    kp2ds[1:, 2] = 0\n    # kp2ds[0, 2] = 
1\n    kp2ds_body = kp2ds\n    # kp2ds_body = kp2ds_body[:18]\n\n    # kp2ds_lhand = kp2ds.copy()[91:112]\n    # kp2ds_rhand = kp2ds.copy()[112:133]\n\n    limbSeq = [\n        # [2, 3],\n        # [2, 6],  # shoulders\n        # [3, 4],\n        # [4, 5],  # left arm\n        # [6, 7],\n        # [7, 8],  # right arm\n        # [2, 9],\n        # [9, 10],\n        # [10, 11],  # right leg\n        # [2, 12],\n        # [12, 13],\n        # [13, 14],  # left leg\n        # [2, 1],\n        [1, 15],\n        [15, 17],\n        [1, 16],\n        [16, 18],  # face (nose, eyes, ears)\n        # [14, 19],\n        # [11, 20],  # foot\n    ]\n\n    colors = [\n        # [255, 0, 0],\n        # [255, 85, 0],\n        # [255, 170, 0],\n        # [255, 255, 0],\n        # [170, 255, 0],\n        # [85, 255, 0],\n        # [0, 255, 0],\n        # [0, 255, 85],\n        # [0, 255, 170],\n        # [0, 255, 255],\n        # [0, 170, 255],\n        # [0, 85, 255],\n        # [0, 0, 255],\n        # [85, 0, 255],\n        [170, 0, 255],\n        # [255, 0, 255],\n        # [255, 0, 170],\n        # [255, 0, 85],\n        # foot\n        # [200, 200, 0],\n        # [100, 100, 0],\n    ]\n\n    H, W, C = img.shape\n    stickwidth = max(int(min(H, W) / stick_width_norm), 1)\n\n    # for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):\n    #     keypoint1 = kp2ds_body[k1_index - 1]\n    #     keypoint2 = kp2ds_body[k2_index - 1]\n\n    #     if keypoint1[-1] < threshold or keypoint2[-1] < threshold:\n    #         continue\n\n    #     Y = np.array([keypoint1[0], keypoint2[0]])\n    #     X = np.array([keypoint1[1], keypoint2[1]])\n    #     mX = np.mean(X)\n    #     mY = np.mean(Y)\n    #     length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n    #     angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n    #     polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)\n    #     
cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])\n\n    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):\n        if keypoint[-1] < threshold:\n            continue\n        x, y = keypoint[0], keypoint[1]\n        # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)\n        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)\n\n    if draw_hand:\n        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)\n        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)\n\n    kp2ds_body[:, 0] /= W\n    kp2ds_body[:, 1] /= H\n\n    if data_to_json is not None:\n        if idx == -1:\n            data_to_json.append(\n                {\n                    \"image_id\": \"frame_{:05d}.jpg\".format(len(data_to_json) + 1),\n                    \"height\": H,\n                    \"width\": W,\n                    \"category_id\": 1,\n                    \"keypoints_body\": kp2ds_body.tolist(),\n                    \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                    \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n                }\n            )\n        else:\n            data_to_json[idx] = {\n                \"image_id\": \"frame_{:05d}.jpg\".format(idx + 1),\n                \"height\": H,\n                \"width\": W,\n                \"category_id\": 1,\n                \"keypoints_body\": kp2ds_body.tolist(),\n                \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n            }\n    return img\n\n\ndef draw_aapose(\n    img,\n    kp2ds,\n    threshold=0.6,\n    data_to_json=None,\n    idx=-1,\n    kp2ds_lhand=None,\n    kp2ds_rhand=None,\n    draw_hand=False,\n    stick_width_norm=200,\n    draw_head=True\n):\n    \"\"\"\n    Draw keypoints and connections representing hand pose on a given canvas.\n\n    Args:\n        canvas (np.ndarray): A 3D numpy array representing the canvas (image) 
on which to draw the hand pose.\n        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn\n                                          or None if no keypoints are present.\n\n    Returns:\n        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.\n\n    Note:\n        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.\n    \"\"\"\n\n    new_kep_list = [\n        \"Nose\",\n        \"Neck\",\n        \"RShoulder\",\n        \"RElbow\",\n        \"RWrist\",  # No.4\n        \"LShoulder\",\n        \"LElbow\",\n        \"LWrist\",  # No.7\n        \"RHip\",\n        \"RKnee\",\n        \"RAnkle\",  # No.10\n        \"LHip\",\n        \"LKnee\",\n        \"LAnkle\",  # No.13\n        \"REye\",\n        \"LEye\",\n        \"REar\",\n        \"LEar\",\n        \"LToe\",\n        \"RToe\",\n    ]\n    # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \\\n    #              kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2\n    kp2ds = kp2ds.copy()\n    if not draw_head:\n        kp2ds[[0,14,15,16,17], 2] = 0\n    kp2ds_body = kp2ds\n\n    # kp2ds_lhand = kp2ds.copy()[91:112]\n    # kp2ds_rhand = kp2ds.copy()[112:133]\n\n    limbSeq = [\n        [2, 3],\n        [2, 6],  # shoulders\n        [3, 4],\n        [4, 5],  # left arm\n        [6, 7],\n        [7, 8],  # right arm\n        [2, 9],\n        [9, 10],\n        [10, 11],  # right leg\n        [2, 12],\n        [12, 13],\n        [13, 14],  # left leg\n        [2, 1],\n        [1, 15],\n        [15, 17],\n        [1, 16],\n        [16, 18],  # face (nose, eyes, ears)\n        [14, 19],\n        [11, 20],  # foot\n    ]\n\n    colors = [\n        [255, 0, 0],\n        [255, 85, 0],\n        [255, 170, 0],\n        [255, 255, 0],\n        [170, 255, 0],\n        [85, 255, 0],\n        [0, 255, 
0],\n        [0, 255, 85],\n        [0, 255, 170],\n        [0, 255, 255],\n        [0, 170, 255],\n        [0, 85, 255],\n        [0, 0, 255],\n        [85, 0, 255],\n        [170, 0, 255],\n        [255, 0, 255],\n        [255, 0, 170],\n        [255, 0, 85],\n        # foot\n        [200, 200, 0],\n        [100, 100, 0],\n    ]\n\n    H, W, C = img.shape\n    stickwidth = max(int(min(H, W) / stick_width_norm), 1)\n\n    for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):\n        keypoint1 = kp2ds_body[k1_index - 1]\n        keypoint2 = kp2ds_body[k2_index - 1]\n\n        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:\n            continue\n\n        Y = np.array([keypoint1[0], keypoint2[0]])\n        X = np.array([keypoint1[1], keypoint2[1]])\n        mX = np.mean(X)\n        mY = np.mean(Y)\n        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)\n        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])\n\n    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):\n        if keypoint[-1] < threshold:\n            continue\n        x, y = keypoint[0], keypoint[1]\n        # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)\n        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)\n\n    if draw_hand:\n        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)\n        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)\n\n    kp2ds_body[:, 0] /= W\n    kp2ds_body[:, 1] /= H\n\n    if data_to_json is not None:\n        if idx == -1:\n            data_to_json.append(\n                {\n                    \"image_id\": \"frame_{:05d}.jpg\".format(len(data_to_json) + 1),\n                    \"height\": H,\n                    \"width\": W,\n                 
   \"category_id\": 1,\n                    \"keypoints_body\": kp2ds_body.tolist(),\n                    \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                    \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n                }\n            )\n        else:\n            data_to_json[idx] = {\n                \"image_id\": \"frame_{:05d}.jpg\".format(idx + 1),\n                \"height\": H,\n                \"width\": W,\n                \"category_id\": 1,\n                \"keypoints_body\": kp2ds_body.tolist(),\n                \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n            }\n    return img\n\n\ndef draw_aapose_new(\n    img,\n    kp2ds,\n    threshold=0.6,\n    data_to_json=None,\n    idx=-1,\n    kp2ds_lhand=None,\n    kp2ds_rhand=None,\n    draw_hand=False,\n    stickwidth_type='v2',\n    draw_head=True\n):\n    \"\"\"\n    Draw keypoints and connections representing hand pose on a given canvas.\n\n    Args:\n        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.\n        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn\n                                          or None if no keypoints are present.\n\n    Returns:\n        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.\n\n    Note:\n        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.\n    \"\"\"\n\n    new_kep_list = [\n        \"Nose\",\n        \"Neck\",\n        \"RShoulder\",\n        \"RElbow\",\n        \"RWrist\",  # No.4\n        \"LShoulder\",\n        \"LElbow\",\n        \"LWrist\",  # No.7\n        \"RHip\",\n        \"RKnee\",\n        \"RAnkle\",  # No.10\n        \"LHip\",\n        \"LKnee\",\n        \"LAnkle\",  # No.13\n        \"REye\",\n        \"LEye\",\n        \"REar\",\n        \"LEar\",\n       
 \"LToe\",\n        \"RToe\",\n    ]\n    # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \\\n    #              kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2\n    kp2ds = kp2ds.copy()\n    if not draw_head:\n        kp2ds[[0,14,15,16,17], 2] = 0\n    kp2ds_body = kp2ds\n\n    # kp2ds_lhand = kp2ds.copy()[91:112]\n    # kp2ds_rhand = kp2ds.copy()[112:133]\n\n    limbSeq = [\n        [2, 3],\n        [2, 6],  # shoulders\n        [3, 4],\n        [4, 5],  # left arm\n        [6, 7],\n        [7, 8],  # right arm\n        [2, 9],\n        [9, 10],\n        [10, 11],  # right leg\n        [2, 12],\n        [12, 13],\n        [13, 14],  # left leg\n        [2, 1],\n        [1, 15],\n        [15, 17],\n        [1, 16],\n        [16, 18],  # face (nose, eyes, ears)\n        [14, 19],\n        [11, 20],  # foot\n    ]\n\n    colors = [\n        [255, 0, 0],\n        [255, 85, 0],\n        [255, 170, 0],\n        [255, 255, 0],\n        [170, 255, 0],\n        [85, 255, 0],\n        [0, 255, 0],\n        [0, 255, 85],\n        [0, 255, 170],\n        [0, 255, 255],\n        [0, 170, 255],\n        [0, 85, 255],\n        [0, 0, 255],\n        [85, 0, 255],\n        [170, 0, 255],\n        [255, 0, 255],\n        [255, 0, 170],\n        [255, 0, 85],\n        # foot\n        [200, 200, 0],\n        [100, 100, 0],\n    ]\n\n    H, W, C = img.shape\n    H, W, C = img.shape\n\n    if stickwidth_type == 'v1':\n        stickwidth = max(int(min(H, W) / 200), 1)\n    elif stickwidth_type == 'v2':\n        stickwidth = max(int(min(H, W) / 200) - 1, 1)\n    else:\n        raise\n\n    for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):\n        keypoint1 = kp2ds_body[k1_index - 1]\n        keypoint2 = kp2ds_body[k2_index - 1]\n\n        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:\n            continue\n\n        Y = np.array([keypoint1[0], 
keypoint2[0]])\n        X = np.array([keypoint1[1], keypoint2[1]])\n        mX = np.mean(X)\n        mY = np.mean(Y)\n        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)\n        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])\n\n    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):\n        if keypoint[-1] < threshold:\n            continue\n        x, y = keypoint[0], keypoint[1]\n        # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)\n        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)\n\n    if draw_hand:\n        img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)\n        img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)\n\n    kp2ds_body[:, 0] /= W\n    kp2ds_body[:, 1] /= H\n\n    if data_to_json is not None:\n        if idx == -1:\n            data_to_json.append(\n                {\n                    \"image_id\": \"frame_{:05d}.jpg\".format(len(data_to_json) + 1),\n                    \"height\": H,\n                    \"width\": W,\n                    \"category_id\": 1,\n                    \"keypoints_body\": kp2ds_body.tolist(),\n                    \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                    \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n                }\n            )\n        else:\n            data_to_json[idx] = {\n                \"image_id\": \"frame_{:05d}.jpg\".format(idx + 1),\n                \"height\": H,\n                \"width\": W,\n                \"category_id\": 1,\n                \"keypoints_body\": kp2ds_body.tolist(),\n                \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n                \"keypoints_right_hand\": 
kp2ds_rhand.tolist(),\n            }\n    return img\n\n\ndef draw_bbox(img, bbox, color=(255, 0, 0)):\n    img = load_image(img)\n    bbox = [int(bbox_tmp) for bbox_tmp in bbox]\n    cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)\n    return img\n\n\ndef draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False):\n    img = load_image(img, reverse)\n\n    if skeleton is not None:\n        if skeleton == \"coco17\":\n            skeleton_list = [\n                [6, 8],\n                [8, 10],\n                [5, 7],\n                [7, 9],\n                [11, 13],\n                [13, 15],\n                [12, 14],\n                [14, 16],\n                [5, 6],\n                [6, 12],\n                [12, 11],\n                [11, 5],\n            ]\n            color_list = [\n                (255, 0, 0),\n                (0, 255, 0),\n                (0, 0, 255),\n                (255, 255, 0),\n                (255, 0, 255),\n                (0, 255, 255),\n            ]\n        elif skeleton == \"cocowholebody\":\n            skeleton_list = [\n                [6, 8],\n                [8, 10],\n                [5, 7],\n                [7, 9],\n                [11, 13],\n                [13, 15],\n                [12, 14],\n                [14, 16],\n                [5, 6],\n                [6, 12],\n                [12, 11],\n                [11, 5],\n                [15, 17],\n                [15, 18],\n                [15, 19],\n                [16, 20],\n                [16, 21],\n                [16, 22],\n                [91, 92, 93, 94, 95],\n                [91, 96, 97, 98, 99],\n                [91, 100, 101, 102, 103],\n                [91, 104, 105, 106, 107],\n                [91, 108, 109, 110, 111],\n                [112, 113, 114, 115, 116],\n                [112, 117, 118, 119, 120],\n                [112, 121, 122, 123, 124],\n                [112, 125, 126, 127, 128],\n         
       [112, 129, 130, 131, 132],\n            ]\n            color_list = [\n                (255, 0, 0),\n                (0, 255, 0),\n                (0, 0, 255),\n                (255, 255, 0),\n                (255, 0, 255),\n                (0, 255, 255),\n            ]\n        else:\n            color_list = [color]\n        for _idx, _skeleton in enumerate(skeleton_list):\n            for i in range(len(_skeleton) - 1):\n                cv2.line(\n                    img,\n                    (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])),\n                    (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])),\n                    color_list[_idx % len(color_list)],\n                    3,\n                )\n\n    for _idx, kp2d in enumerate(kp2ds):\n        if kp2d[2] > threshold:\n            cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1)\n            # cv2.putText(img,\n            #         str(_idx),\n            #         (int(kp2d[0, i, 0])*1,\n            #             int(kp2d[0, i, 1])*1),\n            #         cv2.FONT_HERSHEY_SIMPLEX,\n            #         0.75,\n            #         color,\n            #         2\n            #         )\n\n    return img\n\n\ndef draw_mask(img, mask, background=0, return_rgba=False):\n    img = load_image(img)\n    h, w, _ = img.shape\n    if isinstance(background, int):\n        background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background\n    # Resize the background to the image size (the result was previously assigned to a misspelled name and discarded)\n    background = cv2.resize(background, (w, h))\n    img_rgba = np.concatenate([img, mask], -1)\n    return alphaMerge(img_rgba, background, 0, 0, return_rgba=True)\n\n\ndef draw_pcd(pcd_list, save_path=None):\n    fig = plt.figure()\n    ax = fig.add_subplot(111, projection=\"3d\")\n\n    # \"p\" is not a valid matplotlib color code, so the named color \"purple\" is used instead\n    color_list = [\"r\", \"g\", \"b\", \"y\", \"purple\"]\n\n    for _idx, _pcd in enumerate(pcd_list):\n        # Cycle through the palette so more than five point clouds do not raise an IndexError\n        ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx % len(color_list)], marker=\"o\")\n\n    ax.set_xlabel(\"X\")\n    
ax.set_ylabel(\"Y\")\n    ax.set_zlabel(\"Z\")\n\n    if save_path is not None:\n        plt.savefig(save_path)\n    else:\n        plt.savefig(\"tmp.png\")\n\n\ndef load_image(img, reverse=False):\n    if type(img) == str:\n        img = cv2.imread(img)\n    if reverse:\n        img = img.astype(np.float32)\n        img = img[:, :, ::-1]\n        img = img.astype(np.uint8)\n    return img\n\n\ndef draw_skeleten(meta):\n    kps = []\n    for i, kp in enumerate(meta[\"keypoints_body\"]):\n        if kp is None:\n            # if kp is None:\n            kps.append([0, 0, 0])\n        else:\n            kps.append([*kp, 1])\n    kps = np.array(kps)\n\n    kps[:, 0] *= meta[\"width\"]\n    kps[:, 1] *= meta[\"height\"]\n    pose_img = np.zeros([meta[\"height\"], meta[\"width\"], 3], dtype=np.uint8)\n\n    pose_img = draw_aapose(\n        pose_img,\n        kps,\n        draw_hand=True,\n        kp2ds_lhand=meta[\"keypoints_left_hand\"],\n        kp2ds_rhand=meta[\"keypoints_right_hand\"],\n    )\n    return pose_img\n\n\ndef draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray:\n    \"\"\"\n    Args:\n        pncc: [H,W,3]\n        meta: required keys: keypoints_body: [N, 3] keypoints_left_hand, keypoints_right_hand\n    Return:\n        np.ndarray [H, W, 3]\n    \"\"\"\n    # preprocess keypoints\n    kps = []\n    for i, kp in enumerate(meta[\"keypoints_body\"]):\n        if kp is None:\n            # if kp is None:\n            kps.append([0, 0, 0])\n        elif i in [14, 15, 16, 17]:\n            kps.append([0, 0, 0])\n        else:\n            kps.append([*kp])\n    kps = np.stack(kps)\n\n    kps[:, 0] *= pncc.shape[1]\n    kps[:, 1] *= pncc.shape[0]\n\n    # draw neck\n    canvas = np.zeros_like(pncc)\n    if kps[0][2] > 0.6 and kps[1][2] > 0.6:\n        canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255])\n\n    # draw pncc\n    mask = (pncc > 0).max(axis=2)\n    canvas[mask] = pncc[mask]\n    pncc = canvas\n\n    # draw other 
skeleten\n    kps[0] = 0\n\n    meta[\"keypoints_left_hand\"][:, 0] *= meta[\"width\"]\n    meta[\"keypoints_left_hand\"][:, 1] *= meta[\"height\"]\n\n    meta[\"keypoints_right_hand\"][:, 0] *= meta[\"width\"]\n    meta[\"keypoints_right_hand\"][:, 1] *= meta[\"height\"]\n    pose_img = draw_aapose(\n        pncc,\n        kps,\n        draw_hand=True,\n        kp2ds_lhand=meta[\"keypoints_left_hand\"],\n        kp2ds_rhand=meta[\"keypoints_right_hand\"],\n    )\n    return pose_img\n\n\nFACE_CUSTOM_STYLE = {\n    \"eyeball\": {\"indexs\": [68, 69], \"color\": [255, 255, 255], \"connect\": False},\n    \"left_eyebrow\": {\"indexs\": [17, 18, 19, 20, 21], \"color\": [0, 255, 0]},\n    \"right_eyebrow\": {\"indexs\": [22, 23, 24, 25, 26], \"color\": [0, 0, 255]},\n    \"left_eye\": {\"indexs\": [36, 37, 38, 39, 40, 41], \"color\": [255, 255, 0], \"close\": True},\n    \"right_eye\": {\"indexs\": [42, 43, 44, 45, 46, 47], \"color\": [255, 0, 255], \"close\": True},\n    \"mouth_outside\": {\"indexs\": list(range(48, 60)), \"color\": [100, 255, 50], \"close\": True},\n    \"mouth_inside\": {\"indexs\": [60, 61, 62, 63, 64, 65, 66, 67], \"color\": [255, 100, 50], \"close\": True},\n}\n\n\ndef draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE):\n    \"\"\"\n    Args:\n        img: [H, W, 3]\n        kps: [70, 2]\n    \"\"\"\n    img = img.copy()\n    for key, item in style.items():\n        pts = np.array(kps[item[\"indexs\"]]).astype(np.int32)\n        connect = item.get(\"connect\", True)\n        color = item[\"color\"]\n        close = item.get(\"close\", False)\n        if connect:\n            cv2.polylines(img, [pts], close, color, thickness=thickness)\n        else:\n            for kp in pts:\n                kp = np.array(kp).astype(np.int32)\n                cv2.circle(img, kp, thickness * 2, color=color, thickness=-1)\n    return img\n\n\ndef draw_traj(metas: List[AAPoseMeta], threshold=0.6):\n\n    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 
0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \\\n                [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \\\n                [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [100, 255, 50], [255, 100, 50], \n                # foot\n                [200, 200, 0],\n                [100, 100, 0]\n                ]\n    limbSeq = [\n                    [1, 2], [1, 5],     # shoulders\n                    [2, 3], [3, 4],     # left arm\n                    [5, 6], [6, 7],     # right arm\n                    [1, 8], [8, 9], [9, 10],    # right leg \n                    [1, 11], [11, 12], [12, 13],  # left leg\n                     # face (nose, eyes, ears)\n                    [13, 18], [10, 19] # foot\n                ]\n    \n    face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]] \n    kp_body = np.array([meta.kps_body for meta in metas])\n    kp_body_p = np.array([meta.kps_body_p for meta in metas])\n    \n    \n    face_seq = random.sample(face_seq, 2)\n\n    kp_lh = np.array([meta.kps_lhand for meta in metas])\n    kp_rh = np.array([meta.kps_rhand for meta in metas])\n\n    kp_lh_p = np.array([meta.kps_lhand_p for meta in metas])\n    kp_rh_p = np.array([meta.kps_rhand_p for meta in metas])\n\n    # kp_lh = np.concatenate([kp_lh, kp_lh_p], axis=-1)\n    # kp_rh = np.concatenate([kp_rh, kp_rh_p], axis=-1)\n    \n    new_limbSeq = []\n    key_point_list = []\n    for _idx, ((k1_index, k2_index)) in enumerate(limbSeq):\n        \n        vis = (kp_body_p[:, k1_index] > threshold) * (kp_body_p[:, k2_index] > threshold) * 1\n        if vis.sum() * 1.0 / vis.shape[0] > 0.4:\n            new_limbSeq.append([k1_index, k2_index])\n\n    for _idx, ((k1_index, k2_index)) in enumerate(limbSeq):\n\n        keypoint1 = kp_body[:, k1_index - 1]\n        keypoint2 = kp_body[:, k2_index - 1]\n        interleave = random.randint(4, 7)\n        randind = random.randint(0, interleave - 1)\n    
    # randind = random.rand(range(interleave), sampling_num)\n\n        Y = np.array([keypoint1[:, 0], keypoint2[:, 0]]) \n        X = np.array([keypoint1[:, 1], keypoint2[:, 1]])\n\n        vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1\n\n        # for randidx in randind:\n        t = randind / interleave\n        x = (1-t)*Y[0, :] + t*Y[1, :]\n        y = (1-t)*X[0, :] + t*X[1, :]\n\n        # np.array([1])\n        x = x.astype(int)\n        y = y.astype(int)\n\n        new_array = np.array([x, y, vis]).T\n        \n        key_point_list.append(new_array)\n    \n    indx_lh = random.randint(0, kp_lh.shape[1] - 1)\n    lh = kp_lh[:, indx_lh, :]\n    lh_p = kp_lh_p[:, indx_lh:indx_lh+1]\n    lh = np.concatenate([lh, lh_p], axis=-1)\n    \n    indx_rh = random.randint(0, kp_rh.shape[1] - 1)\n    # Use the same sampled index for the coordinates and the confidence\n    rh = kp_rh[:, indx_rh, :]\n    rh_p = kp_rh_p[:, indx_rh:indx_rh+1]\n    rh = np.concatenate([rh, rh_p], axis=-1)\n\n\n\n    # Binarize the confidence column (last axis) for every frame\n    lh[:, -1] = (lh[:, -1] > threshold) * 1\n    rh[:, -1] = (rh[:, -1] > threshold) * 1\n\n    # print(rh.shape, new_array.shape)\n    # exit()\n    key_point_list.append(lh.astype(int))\n    key_point_list.append(rh.astype(int))\n\n    \n    key_points_list = np.stack(key_point_list)\n    num_points = len(key_points_list)\n    sample_colors = random.sample(colors, num_points)\n\n    stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2)\n\n    image_list_ori = []\n    for i in range(key_points_list.shape[-2]):\n        _image_vis = np.zeros((metas[0].width, metas[0].height, 3))\n        points = key_points_list[:, i, :]\n        for idx, point in enumerate(points):\n            x, y, vis = point\n            if vis == 1:\n                cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1)\n        \n        image_list_ori.append(_image_vis)\n    \n    return image_list_ori\n\n    return [np.zeros([meta.width, meta.height, 3], dtype=np.uint8) for meta in 
metas]\n\n\nif __name__ == \"__main__\":\n    meta = {\n        \"image_id\": \"00472.jpg\",\n        \"height\": 540,\n        \"width\": 414,\n        \"category_id\": 1,\n        \"keypoints_body\": [\n            [0.5084776947463768, 0.11350188078703703],\n            [0.504467655495169, 0.20419560185185184],\n            [0.3982016153381642, 0.198046875],\n            [0.3841664779589372, 0.34869068287037036],\n            [0.3901815368357488, 0.4670536747685185],\n            [0.610733695652174, 0.2103443287037037],\n            [0.6167487545289855, 0.3517650462962963],\n            [0.6448190292874396, 0.4762767650462963],\n            [0.4523371452294686, 0.47320240162037036],\n            [0.4503321256038647, 0.6776475694444445],\n            [0.47639738073671495, 0.8544234664351852],\n            [0.5766483620169082, 0.47320240162037036],\n            [0.5666232638888888, 0.6761103877314815],\n            [0.534542949879227, 0.863646556712963],\n            [0.4864224788647343, 0.09505570023148148],\n            [0.5285278910024155, 0.09351851851851851],\n            [0.46236224335748793, 0.10581597222222222],\n            [0.5586031853864735, 0.10274160879629629],\n            [0.4994551064311594, 0.9405056423611111],\n            [0.4152442821557971, 0.9312825520833333],\n        ],\n        \"keypoints_left_hand\": [\n            [267.78515625, 263.830078125, 1.2840936183929443],\n            [265.294921875, 269.640625, 1.2546794414520264],\n            [263.634765625, 277.111328125, 1.2863062620162964],\n            [262.8046875, 285.412109375, 1.267038345336914],\n            [261.14453125, 292.8828125, 1.280144453048706],\n            [273.595703125, 281.26171875, 1.2592815160751343],\n            [271.10546875, 291.22265625, 1.3256099224090576],\n            [265.294921875, 294.54296875, 1.2368024587631226],\n            [261.14453125, 294.54296875, 0.9771889448165894],\n            [274.42578125, 282.091796875, 1.250044584274292],\n            
[269.4453125, 291.22265625, 1.2571144104003906],\n            [264.46484375, 292.8828125, 1.177802324295044],\n            [260.314453125, 292.052734375, 0.9283463358879089],\n            [273.595703125, 282.091796875, 1.1834490299224854],\n            [269.4453125, 290.392578125, 1.188171625137329],\n            [265.294921875, 290.392578125, 1.192609429359436],\n            [261.974609375, 289.5625, 0.9366656541824341],\n            [271.935546875, 281.26171875, 1.0946396589279175],\n            [268.615234375, 287.072265625, 0.9906131029129028],\n            [265.294921875, 287.90234375, 1.0219476222991943],\n            [262.8046875, 287.072265625, 0.9240120053291321],\n        ],\n        \"keypoints_right_hand\": [\n            [161.53515625, 258.849609375, 1.2069408893585205],\n            [168.17578125, 263.0, 1.1846840381622314],\n            [173.986328125, 269.640625, 1.1435924768447876],\n            [173.986328125, 277.94140625, 1.1802611351013184],\n            [173.986328125, 286.2421875, 1.2599592208862305],\n            [165.685546875, 275.451171875, 1.0633569955825806],\n            [167.345703125, 286.2421875, 1.1693341732025146],\n            [169.8359375, 291.22265625, 1.2698509693145752],\n            [170.666015625, 294.54296875, 1.0619274377822876],\n            [160.705078125, 276.28125, 1.0995020866394043],\n            [163.1953125, 287.90234375, 1.2735884189605713],\n            [166.515625, 291.22265625, 1.339503526687622],\n            [169.005859375, 294.54296875, 1.0835273265838623],\n            [157.384765625, 277.111328125, 1.0866981744766235],\n            [161.53515625, 287.072265625, 1.2468621730804443],\n            [164.025390625, 289.5625, 1.2817761898040771],\n            [166.515625, 292.052734375, 1.099466323852539],\n            [155.724609375, 277.111328125, 1.1065717935562134],\n            [159.044921875, 285.412109375, 1.1924479007720947],\n            [160.705078125, 287.072265625, 1.1304771900177002],\n            
[162.365234375, 287.90234375, 1.0040509700775146],\n        ],\n    }\n    demo_meta = AAPoseMeta(meta)\n    res = draw_traj([demo_meta]*5)\n    cv2.imwrite(\"traj.png\", res[0][..., ::-1])\n"
  },
  {
    "path": "wan/modules/animate/preprocess/pose2d.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nfrom typing import Union, List\n\nimport numpy as np\nimport torch\nimport onnxruntime\n\nfrom pose2d_utils import (\n    read_img,\n    box_convert_simple,\n    bbox_from_detector,\n    crop,\n    keypoints_from_heatmaps,\n    load_pose_metas_from_kp2ds_seq\n)\n\n\nclass SimpleOnnxInference(object):\n    def __init__(self, checkpoint, device='cuda', reverse_input=False, **kwargs):\n        if isinstance(device, str):\n            device = torch.device(device)\n        if device.type == 'cuda':\n            device = '{}:{}'.format(device.type, device.index)\n            providers = [(\"CUDAExecutionProvider\", {\"device_id\": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else \"0\"}), \"CPUExecutionProvider\"]\n        else:\n            providers = [\"CPUExecutionProvider\"]\n        self.device = device\n        if not os.path.exists(checkpoint):\n            raise RuntimeError(\"{} is not existed!\".format(checkpoint))\n        \n        if os.path.isdir(checkpoint):\n            checkpoint = os.path.join(checkpoint, 'end2end.onnx')\n\n        self.session = onnxruntime.InferenceSession(checkpoint,\n                                                    providers=providers\n                                                    )\n        self.input_name = self.session.get_inputs()[0].name\n        self.output_name = self.session.get_outputs()[0].name\n        self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input else self.session.get_inputs()[0].shape[2:][::-1]\n        self.input_resolution = np.array(self.input_resolution)\n        \n\n    def __call__(self, *args, **kwargs):\n        return self.forward(*args, **kwargs)\n    \n\n    def get_output_names(self):\n        output_names = []\n        for node in self.session.get_outputs():\n            output_names.append(node.name)\n        return 
output_names\n\n\n    def set_device(self, device):\n        if isinstance(device, str):\n            device = torch.device(device)\n        if device.type == 'cuda':\n            device = '{}:{}'.format(device.type, device.index)\n            providers = [(\"CUDAExecutionProvider\", {\"device_id\": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else \"0\"}), \"CPUExecutionProvider\"]\n        else:\n            providers = [\"CPUExecutionProvider\"]\n        self.session.set_providers(providers)\n        self.device = device\n\n\nclass Yolo(SimpleOnnxInference):\n    def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, threshold_multi_persons=0.1, input_resolution=(640, 640), threshold_iou=0.5, threshold_bbox_shape_ratio=0.4, cat_id=[1], select_type='max', strict=True, sorted_func=None, **kwargs):\n        super(Yolo, self).__init__(checkpoint, device=device, **kwargs)\n        \n        model_inputs = self.session.get_inputs()\n        input_shape = model_inputs[0].shape\n\n        self.input_width = 640\n        self.input_height = 640\n        \n        self.threshold_multi_persons = threshold_multi_persons\n        self.threshold_conf = threshold_conf\n        self.threshold_iou = threshold_iou\n        self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio\n        self.input_resolution = input_resolution\n        self.cat_id = cat_id\n        self.select_type = select_type\n        self.strict = strict\n        self.sorted_func = sorted_func\n        \n        \n    def preprocess(self, input_image):\n        \"\"\"\n        Preprocesses the input image before performing inference.\n\n        Returns:\n            image_data: Preprocessed image data ready for inference.\n        \"\"\"\n        img = read_img(input_image)\n        # Get the height and width of the input image\n        img_height, img_width = img.shape[:2]\n        # Resize the image to match the input shape\n        img = cv2.resize(img, 
(self.input_resolution[1], self.input_resolution[0]))\n        # Normalize the image data by dividing it by 255.0\n        image_data = np.array(img) / 255.0\n        # Transpose the image to have the channel dimension as the first dimension\n        image_data = np.transpose(image_data, (2, 0, 1))  # Channel first\n        # Expand the dimensions of the image data to match the expected input shape\n        # image_data = np.expand_dims(image_data, axis=0).astype(np.float32)\n        image_data = image_data.astype(np.float32)\n        # Return the preprocessed image data\n        return image_data, np.array([img_height, img_width])\n\n    \n    def postprocess(self, output, shape_raw, cat_id=[1]):\n        \"\"\"\n        Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.\n\n        Args:\n            output (numpy.ndarray): The output of the model.\n            shape_raw (numpy.ndarray): The [height, width] of the original image.\n            cat_id (list): Category IDs; not used in this method.\n\n        Returns:\n            numpy.ndarray: One row per detection kept after non-maximum suppression:\n                the box converted to xyxy form, followed by the score and the class ID.\n        \"\"\"\n        # Transpose and squeeze the output to match the expected shape\n\n        outputs = np.squeeze(output)\n        if len(outputs.shape) == 1:\n            outputs = outputs[None]\n        if output.shape[-1] != 6 and output.shape[1] == 84:\n            outputs = np.transpose(outputs)\n        \n        # Get the number of rows in the outputs array\n        rows = outputs.shape[0]\n\n        # Calculate the scaling factors for the bounding box coordinates\n        x_factor = shape_raw[1] / self.input_width\n        y_factor = shape_raw[0] / self.input_height\n\n        # Lists to store the bounding boxes, scores, and class IDs of the detections\n        boxes = []\n        scores = []\n        class_ids = []\n\n        if outputs.shape[-1] == 6:\n            max_scores = outputs[:, 4]\n            classid = outputs[:, -1]\n            \n            threshold_conf_masks = max_scores >= 
self.threshold_conf\n            classid_masks = classid[threshold_conf_masks] != 3.14159\n\n            max_scores = max_scores[threshold_conf_masks][classid_masks]\n            classid = classid[threshold_conf_masks][classid_masks]\n\n            boxes = outputs[:, :4][threshold_conf_masks][classid_masks]\n            boxes[:, [0, 2]] *= x_factor\n            boxes[:, [1, 3]] *= y_factor\n            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]\n            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]\n            boxes = boxes.astype(np.int32)\n\n        else:\n            classes_scores = outputs[:, 4:]\n            max_scores = np.amax(classes_scores, -1)\n            threshold_conf_masks = max_scores >= self.threshold_conf\n\n            classid = np.argmax(classes_scores[threshold_conf_masks], -1)\n\n            classid_masks = classid!=3.14159\n            \n            classes_scores = classes_scores[threshold_conf_masks][classid_masks]\n            max_scores = max_scores[threshold_conf_masks][classid_masks]\n            classid = classid[classid_masks]\n    \n            xywh = outputs[:, :4][threshold_conf_masks][classid_masks]\n\n            x = xywh[:, 0:1]\n            y = xywh[:, 1:2]\n            w = xywh[:, 2:3]\n            h = xywh[:, 3:4]\n    \n            left = ((x - w / 2) * x_factor)\n            top = ((y - h / 2) * y_factor)\n            width = (w * x_factor)\n            height = (h * y_factor)\n            boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32)\n\n        boxes = boxes.tolist()\n        scores = max_scores.tolist()\n        class_ids = classid.tolist()\n\n        # Apply non-maximum suppression to filter out overlapping bounding boxes\n        indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou)\n        # Iterate over the selected indices after non-maximum suppression\n        \n        results = []\n        for i in indices:\n            # Get the box, score, and class ID 
corresponding to the index\n            box = box_convert_simple(boxes[i], 'xywh2xyxy')\n            score = scores[i]\n            class_id = class_ids[i]\n            results.append(box + [score] + [class_id])\n            # # Draw the detection on the input image\n\n        # Return the modified input image\n        return np.array(results)\n\n    \n    def process_results(self, results, shape_raw, cat_id=[1], single_person=True):\n        if isinstance(results, tuple):\n            det_results = results[0]\n        else:\n            det_results = results\n\n        person_results = []\n        person_count = 0\n        if len(results):\n            max_idx = -1\n            max_bbox_size = shape_raw[0] * shape_raw[1] * -10\n            max_bbox_shape = -1\n            \n            bboxes = []\n            idx_list = []\n            for i in range(results.shape[0]):\n                bbox = results[i]\n                if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):\n                    idx_list.append(i)\n                    bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))\n                    if bbox_shape > max_bbox_shape:\n                        max_bbox_shape = bbox_shape\n            \n            results = results[idx_list]\n\n            for i in range(results.shape[0]):\n                bbox = results[i]\n                bboxes.append(bbox)\n                if self.select_type == 'max':\n                    bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))\n                elif self.select_type == 'center':\n                    bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1\n                bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))\n                if bbox_size > max_bbox_size:\n                    if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio:\n                        continue\n                    max_bbox_size = bbox_size\n   
                 max_bbox_shape = bbox_shape\n                    max_idx = i\n\n            if self.sorted_func is not None and len(bboxes) > 0:\n                max_idx = self.sorted_func(bboxes, shape_raw)\n                bbox = bboxes[max_idx]\n                if self.select_type == 'max':\n                    max_bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))\n                elif self.select_type == 'center':\n                    max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1\n                \n            if max_idx != -1:\n                person_count = 1\n\n            if max_idx != -1:\n                person = {}\n                person['bbox'] = results[max_idx, :5]\n                person['track_id'] = int(0)\n                person_results.append(person)\n\n            for i in range(results.shape[0]):\n                bbox = results[i]\n                if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):\n                    if self.select_type == 'max':\n                        bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))\n                    elif self.select_type == 'center':\n                        bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1\n                    if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size:\n                        person_count += 1\n                        if not single_person:\n                            person = {}\n                            person['bbox'] = results[i, :5]\n                            person['track_id'] = int(person_count - 1)\n                            person_results.append(person)                   \n            return person_results\n        else:\n            return None\n        \n\n    def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs):\n        result = self.postprocess(outputs[i], shape_raw[i], 
cat_id=self.cat_id)\n        result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person)\n        if result is not None and len(result) != 0:\n            person_results[i] = result\n\n\n    def forward(self, img, shape_raw, **kwargs):\n        \"\"\"\n        Performs inference using an ONNX model and returns the selected person boxes for each image.\n\n        Returns:\n            person_results: One list per image of dicts with keys 'bbox' ([x1, y1, x2, y2, score])\n                and 'track_id'. Images without a valid detection keep a full-frame box with\n                score -1 and track_id -1.\n        \"\"\"\n        if isinstance(img, torch.Tensor):\n            img = img.cpu().numpy()\n            shape_raw = shape_raw.cpu().numpy()\n\n        outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0]\n        person_results = [[{'bbox': np.array([0., 0., 1.*shape_raw[i][1], 1.*shape_raw[i][0], -1]), 'track_id': -1}] for i in range(len(outputs))]\n\n        for i in range(len(outputs)):\n            self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs)         \n        return person_results\n\n\nclass ViTPose(SimpleOnnxInference):\n    def __init__(self, checkpoint, device='cuda', **kwargs):\n        super(ViTPose, self).__init__(checkpoint, device=device)\n\n    def forward(self, img, center, scale, **kwargs):\n        heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0]\n        points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,\n                                            center=center,\n                                            scale=scale*200,\n                                            unbiased=True, \n                                            use_udp=False)\n        return np.concatenate([points, prob], axis=2)\n\n\n    @staticmethod\n    def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs):\n        if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:\n            bbox = np.array([0, 0, img.shape[1], img.shape[0]])\n     
   \n        bbox_xywh = bbox\n        if mask is not None:\n            img = np.where(mask>128, img, mask)\n\n        if isinstance(input_resolution, int):\n            center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale)\n            img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution))\n        else:\n            center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)\n            img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1]))\n\n        IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])\n        IMG_NORM_STD = np.array([0.229, 0.224, 0.225])\n        img_norm = (img / 255. - IMG_NORM_MEAN) / IMG_NORM_STD\n        img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)\n        return img_norm, np.array(center), np.array(scale)\n\n\nclass Pose2d:\n    def __init__(self, checkpoint, detector_checkpoint=None, device='cuda', **kwargs):\n\n        if detector_checkpoint is not None:\n            self.detector = Yolo(detector_checkpoint, device)\n        else:\n            self.detector = None\n\n        self.model = ViTPose(checkpoint, device)\n        self.device = device\n\n    def load_images(self, inputs):\n        \"\"\"\n        Load images from various input types.\n        \n        Args:\n            inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path, \n                     single image array, or list of image arrays\n            \n        Returns:\n            List[np.ndarray]: List of RGB image arrays\n            \n        Raises:\n            ValueError: If file format is unsupported or image cannot be read\n        \"\"\"\n        if isinstance(inputs, str):\n            if inputs.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):\n                cap = cv2.VideoCapture(inputs)\n                frames = []\n                while True:\n                    ret, 
frame = cap.read()\n                    if not ret:\n                        break\n                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n                cap.release()\n                images = frames\n            elif inputs.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):\n                # cv2.imread returns None on failure, so check before converting\n                img = cv2.imread(inputs)\n                if img is None:\n                    raise ValueError(f\"Cannot read image: {inputs}\")\n                images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)]\n            else:\n                raise ValueError(f\"Unsupported file format: {inputs}\")\n                \n        elif isinstance(inputs, np.ndarray):\n            # Wrap a single HxWxC image so that iteration yields whole images\n            batch = inputs[None] if inputs.ndim == 3 else inputs\n            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch]\n        elif isinstance(inputs, list):\n            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]\n        else:\n            raise ValueError(f\"Unsupported input type: {type(inputs)}\")\n        return images\n\n    def __call__(\n        self, \n        inputs: Union[str, np.ndarray, List[np.ndarray]],\n        return_image: bool = False,\n        **kwargs\n    ):\n        \"\"\"\n        Process input and estimate 2D keypoints.\n        \n        Args:\n            inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,\n                     single image array, or list of image arrays\n            **kwargs: Additional arguments for processing\n            \n        Returns:\n            Pose metas for all input images, as built by\n                load_pose_metas_from_kp2ds_seq from the detected 2D keypoints\n        \"\"\"\n        images = self.load_images(inputs)\n        H, W = images[0].shape[:2]\n        if self.detector is not None:\n            bboxes = []\n            for _image in images:\n                img, shape = self.detector.preprocess(_image)\n                bboxes.append(self.detector(img[None], shape[None])[0][0][\"bbox\"])\n        else:\n            bboxes = [None] * len(images)\n\n        kp2ds = []\n        for _image, _bbox in zip(images, bboxes):\n            img, center, scale = 
self.model.preprocess(_image, _bbox)\n            kp2ds.append(self.model(img[None], center[None], scale[None]))\n        kp2ds = np.concatenate(kp2ds, 0)\n        metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)\n        return metas"
  },
  {
    "path": "wan/modules/animate/preprocess/pose2d_utils.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os.path as osp\nimport warnings\nimport cv2\nimport numpy as np\nfrom typing import List\nfrom PIL import Image\n\n\ndef box_convert_simple(box, convert_type='xyxy2xywh'):\n    if convert_type == 'xyxy2xywh':\n        return [box[0], box[1], box[2] - box[0], box[3] - box[1]]\n    elif convert_type == 'xywh2xyxy':\n        return [box[0], box[1], box[2] + box[0], box[3] + box[1]]\n    elif convert_type == 'xyxy2ctwh':\n        return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]]\n    elif convert_type == 'ctwh2xyxy':\n        return [box[0] - box[2] // 2, box[1] - box[3] // 2, box[0] + (box[2] - box[2] // 2), box[1] + (box[3] - box[3] // 2)]\n\ndef read_img(image, convert='RGB', check_exist=False):\n    if isinstance(image, str):\n        if check_exist and not osp.exists(image):\n            return None\n        try:\n            img = Image.open(image)\n            if convert:\n                img = img.convert(convert)\n        except Exception:\n            raise IOError('File error: ', image)\n        return np.asarray(img)\n    else:\n        if isinstance(image, np.ndarray):\n            # Reverse the channel order (e.g. BGR <-> RGB) when a conversion is requested\n            if convert:\n                return image[..., ::-1]\n            return image\n        else:\n            # Otherwise treat the input as a PIL image\n            img = image\n            if convert:\n                img = img.convert(convert)\n            return np.asarray(img)\n\nclass AAPoseMeta:\n    def __init__(self, meta=None, kp2ds=None):\n        self.image_id = \"\"\n        self.height = 0\n        self.width = 0\n\n        self.kps_body: np.ndarray = None\n        self.kps_lhand: np.ndarray = None\n        self.kps_rhand: np.ndarray = None\n        self.kps_face: np.ndarray = None\n        self.kps_body_p: np.ndarray = None\n        self.kps_lhand_p: np.ndarray = None\n        self.kps_rhand_p: np.ndarray = None\n        self.kps_face_p: np.ndarray = None\n\n\n        if meta is not None:\n            self.load_from_meta(meta)\n        elif kp2ds is not None:\n            
self.load_from_kp2ds(kp2ds)\n    \n    def is_valid(self, kp, p, threshold):\n        x, y = kp\n        if x < 0 or y < 0 or x > self.width or y > self.height or p < threshold: \n            return False\n        else:\n            return True\n    \n    def get_bbox(self, kp, kp_p, threshold=0.5):\n        kps = kp[kp_p > threshold]\n        if kps.size == 0:\n            return 0, 0, 0, 0\n        x0, y0 = kps.min(axis=0)\n        x1, y1 = kps.max(axis=0)\n        return x0, y0, x1, y1\n    \n    def crop(self, x0, y0, x1, y1):\n        all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]\n        for kps in all_kps:\n            if kps is not None:\n                kps[:, 0] -= x0\n                kps[:, 1] -= y0\n        self.width = x1 - x0\n        self.height = y1 - y0\n        return self\n    \n    def resize(self, width, height):\n        scale_x = width / self.width\n        scale_y = height / self.height\n        all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]\n        for kps in all_kps:\n            if kps is not None:\n                kps[:, 0] *= scale_x\n                kps[:, 1] *= scale_y\n        self.width = width\n        self.height = height\n        return self\n\n    \n    def get_kps_body_with_p(self, normalize=False):\n        kps_body = self.kps_body.copy()\n        if normalize:\n            kps_body = kps_body / np.array([self.width, self.height])\n\n        # Append the confidence as a third column: (K, 2) + (K, 1) -> (K, 3)\n        return np.concatenate([kps_body, self.kps_body_p[:, None]], axis=1)\n    \n    @staticmethod\n    def from_kps_face(kps_face: np.ndarray, height: int, width: int):\n\n        pose_meta = AAPoseMeta()\n        pose_meta.kps_face = kps_face[:, :2]\n        if kps_face.shape[1] == 3:\n            pose_meta.kps_face_p = kps_face[:, 2]\n        else:\n            pose_meta.kps_face_p = kps_face[:, 0] * 0 + 1\n        pose_meta.height = height\n        pose_meta.width = width\n        return pose_meta\n\n    @staticmethod\n    def 
from_kps_body(kps_body: np.ndarray, height: int, width: int):\n\n        pose_meta = AAPoseMeta()\n        pose_meta.kps_body = kps_body[:, :2]\n        pose_meta.kps_body_p = kps_body[:, 2]\n        pose_meta.height = height\n        pose_meta.width = width\n        return pose_meta\n    @staticmethod\n    def from_humanapi_meta(meta):\n        pose_meta = AAPoseMeta()\n        width, height = meta[\"width\"], meta[\"height\"]\n        pose_meta.width = width\n        pose_meta.height = height\n        pose_meta.kps_body = meta[\"keypoints_body\"][:, :2] * (width, height)\n        pose_meta.kps_body_p = meta[\"keypoints_body\"][:, 2]\n        pose_meta.kps_lhand = meta[\"keypoints_left_hand\"][:, :2] * (width, height)\n        pose_meta.kps_lhand_p = meta[\"keypoints_left_hand\"][:, 2]\n        pose_meta.kps_rhand = meta[\"keypoints_right_hand\"][:, :2] * (width, height)\n        pose_meta.kps_rhand_p = meta[\"keypoints_right_hand\"][:, 2]\n        if 'keypoints_face' in meta:\n            pose_meta.kps_face = meta[\"keypoints_face\"][:, :2] * (width, height)\n            pose_meta.kps_face_p = meta[\"keypoints_face\"][:, 2]\n        return pose_meta\n    \n    def load_from_meta(self, meta, norm_body=True, norm_hand=False):\n        \n        self.image_id = meta.get(\"image_id\", \"00000.png\")\n        self.height = meta[\"height\"]\n        self.width = meta[\"width\"]\n        kps_body_p = []\n        kps_body = []\n        for kp in meta[\"keypoints_body\"]:\n            if kp is None:\n                kps_body.append([0, 0])\n                kps_body_p.append(0)\n            else:\n                kps_body.append(kp)\n                kps_body_p.append(1)\n\n        self.kps_body = np.array(kps_body)\n        self.kps_body[:, 0] *= self.width\n        self.kps_body[:, 1] *= self.height\n        self.kps_body_p = np.array(kps_body_p)\n\n        self.kps_lhand = np.array(meta[\"keypoints_left_hand\"])[:, :2]\n        self.kps_lhand_p = 
np.array(meta[\"keypoints_left_hand\"])[:, 2]\n        self.kps_rhand = np.array(meta[\"keypoints_right_hand\"])[:, :2]\n        self.kps_rhand_p = np.array(meta[\"keypoints_right_hand\"])[:, 2]\n\n    @staticmethod\n    def load_from_kp2ds(kp2ds: List[np.ndarray], width: int, height: int): \n        \"\"\"input 133x3 numpy keypoints and output AAPoseMeta\n\n        Args:\n            kp2ds (List[np.ndarray]): _description_\n            width (int): _description_\n            height (int): _description_\n\n        Returns:\n            _type_: _description_\n        \"\"\"\n        pose_meta = AAPoseMeta()\n        pose_meta.width = width\n        pose_meta.height = height\n        kps_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2\n        kps_lhand = kp2ds[91:112]\n        kps_rhand = kp2ds[112:133]\n        kps_face = np.concatenate([kp2ds[23:23+68], kp2ds[1:3]], axis=0)\n        pose_meta.kps_body = kps_body[:, :2]\n        pose_meta.kps_body_p = kps_body[:, 2]\n        pose_meta.kps_lhand = kps_lhand[:, :2]\n        pose_meta.kps_lhand_p = kps_lhand[:, 2]\n        pose_meta.kps_rhand = kps_rhand[:, :2]\n        pose_meta.kps_rhand_p = kps_rhand[:, 2]\n        pose_meta.kps_face = kps_face[:, :2]\n        pose_meta.kps_face_p = kps_face[:, 2]\n        return pose_meta\n    \n    @staticmethod\n    def from_dwpose(dwpose_det_res, height, width):\n        pose_meta = AAPoseMeta()\n        pose_meta.kps_body = dwpose_det_res[\"bodies\"][\"candidate\"]\n        pose_meta.kps_body_p = dwpose_det_res[\"bodies\"][\"score\"]\n        pose_meta.kps_body[:, 0] *= width\n        pose_meta.kps_body[:, 1] *= height\n\n        pose_meta.kps_lhand, pose_meta.kps_rhand = dwpose_det_res[\"hands\"]\n        pose_meta.kps_lhand[:, 0] *= width\n        pose_meta.kps_lhand[:, 1] *= height\n        pose_meta.kps_rhand[:, 0] *= width\n        
pose_meta.kps_rhand[:, 1] *= height\n        pose_meta.kps_lhand_p, pose_meta.kps_rhand_p = dwpose_det_res[\"hands_score\"]\n\n        pose_meta.kps_face = dwpose_det_res[\"faces\"][0]\n        pose_meta.kps_face[:, 0] *= width\n        pose_meta.kps_face[:, 1] *= height\n        pose_meta.kps_face_p = dwpose_det_res[\"faces_score\"][0]\n        return pose_meta\n\n    def save_json(self):\n        pass\n\n    def draw_aapose(self, img, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):\n        from .human_visualization import draw_aapose_by_meta\n        return draw_aapose_by_meta(img, self, threshold, stick_width_norm, draw_hand, draw_head)\n\n    \n    def translate(self, x0, y0):\n        all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]\n        for kps in all_kps:\n            if kps is not None:\n                kps[:, 0] -= x0\n                kps[:, 1] -= y0\n\n    def scale(self, sx, sy):\n        all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]\n        for kps in all_kps:\n            if kps is not None:\n                kps[:, 0] *= sx\n                kps[:, 1] *= sy\n    \n    def padding_resize2(self, height=512, width=512):\n        \"\"\"kps will be changed inplace\n\n        \"\"\"\n\n        all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face]\n\n        ori_height, ori_width = self.height, self.width\n\n        if (ori_height / ori_width) > (height / width):\n            new_width = int(height / ori_height * ori_width)\n            padding = int((width - new_width) / 2)\n            padding_width = padding\n            padding_height = 0\n            scale = height / ori_height \n\n            for kps in all_kps:\n                if kps is not None:\n                    kps[:, 0] = kps[:, 0] * scale + padding\n                    kps[:, 1] = kps[:, 1] * scale\n\n        else:\n            new_height = int(width / ori_width * ori_height)\n            padding = 
int((height - new_height) / 2)\n            padding_width = 0\n            padding_height = padding \n            scale = width / ori_width\n            for kps in all_kps:\n                if kps is not None:\n                    kps[:, 1] = kps[:, 1] * scale + padding\n                    kps[:, 0] = kps[:, 0] * scale\n\n\n        self.width = width\n        self.height = height\n        return self\n        \n\ndef transform_preds(coords, center, scale, output_size, use_udp=False):\n    \"\"\"Get final keypoint predictions from heatmaps and apply scaling and\n    translation to map them back to the image.\n\n    Note:\n        num_keypoints: K\n\n    Args:\n        coords (np.ndarray[K, ndims]):\n\n            * If ndims=2, coords are predicted keypoint location.\n            * If ndims=4, coords are composed of (x, y, scores, tags)\n            * If ndims=5, coords are composed of (x, y, scores, tags,\n              flipped_tags)\n\n        center (np.ndarray[2, ]): Center of the bounding box (x, y).\n        scale (np.ndarray[2, ]): Scale of the bounding box\n            wrt [width, height].\n        output_size (np.ndarray[2, ] | list(2,)): Size of the\n            destination heatmaps.\n        use_udp (bool): Use unbiased data processing\n\n    Returns:\n        np.ndarray: Predicted coordinates in the images.\n    \"\"\"\n    assert coords.shape[1] in (2, 4, 5)\n    assert len(center) == 2\n    assert len(scale) == 2\n    assert len(output_size) == 2\n\n    # Recover the scale which is normalized by a factor of 200.\n    # scale = scale * 200.0\n\n    if use_udp:\n        scale_x = scale[0] / (output_size[0] - 1.0)\n        scale_y = scale[1] / (output_size[1] - 1.0)\n    else:\n        scale_x = scale[0] / output_size[0]\n        scale_y = scale[1] / output_size[1]\n\n    target_coords = np.ones_like(coords)\n    target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5\n    target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - 
scale[1] * 0.5\n\n    return target_coords\n\n\ndef _calc_distances(preds, targets, mask, normalize):\n    \"\"\"Calculate the normalized distances between preds and target.\n\n    Note:\n        batch_size: N\n        num_keypoints: K\n        dimension of keypoints: D (normally, D=2 or D=3)\n\n    Args:\n        preds (np.ndarray[N, K, D]): Predicted keypoint location.\n        targets (np.ndarray[N, K, D]): Groundtruth keypoint location.\n        mask (np.ndarray[N, K]): Visibility of the target. False for invisible\n            joints, and True for visible. Invisible joints will be ignored for\n            accuracy calculation.\n        normalize (np.ndarray[N, D]): Typical value is heatmap_size\n\n    Returns:\n        np.ndarray[K, N]: The normalized distances. \\\n            If target keypoints are missing, the distance is -1.\n    \"\"\"\n    N, K, _ = preds.shape\n    # set mask=0 when normalize==0\n    _mask = mask.copy()\n    _mask[np.where((normalize == 0).sum(1))[0], :] = False\n    distances = np.full((N, K), -1, dtype=np.float32)\n    # handle invalid values\n    normalize[np.where(normalize <= 0)] = 1e6\n    distances[_mask] = np.linalg.norm(\n        ((preds - targets) / normalize[:, None, :])[_mask], axis=-1)\n    return distances.T\n\n\ndef _distance_acc(distances, thr=0.5):\n    \"\"\"Return the percentage below the distance threshold, while ignoring\n    distances values with -1.\n\n    Note:\n        batch_size: N\n    Args:\n        distances (np.ndarray[N, ]): The normalized distances.\n        thr (float): Threshold of the distances.\n\n    Returns:\n        float: Percentage of distances below the threshold. 
\\\n            If all target keypoints are missing, return -1.\n    \"\"\"\n    distance_valid = distances != -1\n    num_distance_valid = distance_valid.sum()\n    if num_distance_valid > 0:\n        return (distances[distance_valid] < thr).sum() / num_distance_valid\n    return -1\n\n\ndef _get_max_preds(heatmaps):\n    \"\"\"Get keypoint predictions from score maps.\n\n    Note:\n        batch_size: N\n        num_keypoints: K\n        heatmap height: H\n        heatmap width: W\n\n    Args:\n        heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.\n\n    Returns:\n        tuple: A tuple containing aggregated results.\n\n        - preds (np.ndarray[N, K, 2]): Predicted keypoint location.\n        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.\n    \"\"\"\n    assert isinstance(heatmaps,\n                      np.ndarray), ('heatmaps should be numpy.ndarray')\n    assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'\n\n    N, K, _, W = heatmaps.shape\n    heatmaps_reshaped = heatmaps.reshape((N, K, -1))\n    idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))\n    maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))\n\n    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)\n    preds[:, :, 0] = preds[:, :, 0] % W\n    preds[:, :, 1] = preds[:, :, 1] // W\n\n    preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1)\n    return preds, maxvals\n\n\ndef _get_max_preds_3d(heatmaps):\n    \"\"\"Get keypoint predictions from 3D score maps.\n\n    Note:\n        batch size: N\n        num keypoints: K\n        heatmap depth size: D\n        heatmap height: H\n        heatmap width: W\n\n    Args:\n        heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.\n\n    Returns:\n        tuple: A tuple containing aggregated results.\n\n        - preds (np.ndarray[N, K, 3]): Predicted keypoint location.\n        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.\n    \"\"\"\n    assert 
isinstance(heatmaps, np.ndarray), \\\n        ('heatmaps should be numpy.ndarray')\n    assert heatmaps.ndim == 5, 'heatmaps should be 5-ndim'\n\n    N, K, D, H, W = heatmaps.shape\n    heatmaps_reshaped = heatmaps.reshape((N, K, -1))\n    idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1))\n    maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1))\n\n    preds = np.zeros((N, K, 3), dtype=np.float32)\n    _idx = idx[..., 0]\n    preds[..., 2] = _idx // (H * W)\n    preds[..., 1] = (_idx // W) % H\n    preds[..., 0] = _idx % W\n\n    preds = np.where(maxvals > 0.0, preds, -1)\n    return preds, maxvals\n\n\ndef pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None):\n    \"\"\"Calculate the pose accuracy of PCK for each individual keypoint and the\n    averaged accuracy across all keypoints from heatmaps.\n\n    Note:\n        PCK metric measures accuracy of the localization of the body joints.\n        The distances between predicted positions and the ground-truth ones\n        are typically normalized by the bounding box size.\n        The threshold (thr) of the normalized distance is commonly set\n        as 0.05, 0.1 or 0.2 etc.\n\n        - batch_size: N\n        - num_keypoints: K\n        - heatmap height: H\n        - heatmap width: W\n\n    Args:\n        output (np.ndarray[N, K, H, W]): Model output heatmaps.\n        target (np.ndarray[N, K, H, W]): Groundtruth heatmaps.\n        mask (np.ndarray[N, K]): Visibility of the target. False for invisible\n            joints, and True for visible. Invisible joints will be ignored for\n            accuracy calculation.\n        thr (float): Threshold of PCK calculation. 
Default 0.05.\n        normalize (np.ndarray[N, 2]): Normalization factor for H&W.\n\n    Returns:\n        tuple: A tuple containing keypoint accuracy.\n\n        - np.ndarray[K]: Accuracy of each keypoint.\n        - float: Averaged accuracy across all keypoints.\n        - int: Number of valid keypoints.\n    \"\"\"\n    N, K, H, W = output.shape\n    if K == 0:\n        return None, 0, 0\n    if normalize is None:\n        normalize = np.tile(np.array([[H, W]]), (N, 1))\n\n    pred, _ = _get_max_preds(output)\n    gt, _ = _get_max_preds(target)\n    return keypoint_pck_accuracy(pred, gt, mask, thr, normalize)\n\n\ndef keypoint_pck_accuracy(pred, gt, mask, thr, normalize):\n    \"\"\"Calculate the pose accuracy of PCK for each individual keypoint and the\n    averaged accuracy across all keypoints for coordinates.\n\n    Note:\n        PCK metric measures accuracy of the localization of the body joints.\n        The distances between predicted positions and the ground-truth ones\n        are typically normalized by the bounding box size.\n        The threshold (thr) of the normalized distance is commonly set\n        as 0.05, 0.1 or 0.2 etc.\n\n        - batch_size: N\n        - num_keypoints: K\n\n    Args:\n        pred (np.ndarray[N, K, 2]): Predicted keypoint location.\n        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.\n        mask (np.ndarray[N, K]): Visibility of the target. False for invisible\n            joints, and True for visible. 
Invisible joints will be ignored for\n            accuracy calculation.\n        thr (float): Threshold of PCK calculation.\n        normalize (np.ndarray[N, 2]): Normalization factor for H&W.\n\n    Returns:\n        tuple: A tuple containing keypoint accuracy.\n\n        - acc (np.ndarray[K]): Accuracy of each keypoint.\n        - avg_acc (float): Averaged accuracy across all keypoints.\n        - cnt (int): Number of valid keypoints.\n    \"\"\"\n    distances = _calc_distances(pred, gt, mask, normalize)\n\n    acc = np.array([_distance_acc(d, thr) for d in distances])\n    valid_acc = acc[acc >= 0]\n    cnt = len(valid_acc)\n    avg_acc = valid_acc.mean() if cnt > 0 else 0\n    return acc, avg_acc, cnt\n\n\ndef keypoint_auc(pred, gt, mask, normalize, num_step=20):\n    \"\"\"Calculate the area under the curve (AUC) of the keypoint PCK accuracy,\n    averaged over num_step evenly spaced thresholds in [0, 1).\n\n    Note:\n        - batch_size: N\n        - num_keypoints: K\n\n    Args:\n        pred (np.ndarray[N, K, 2]): Predicted keypoint location.\n        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.\n        mask (np.ndarray[N, K]): Visibility of the target. False for invisible\n            joints, and True for visible. 
Invisible joints will be ignored for\n            accuracy calculation.\n        normalize (float): Normalization factor.\n\n    Returns:\n        float: Area under curve.\n    \"\"\"\n    nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1))\n    x = [1.0 * i / num_step for i in range(num_step)]\n    y = []\n    for thr in x:\n        _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)\n        y.append(avg_acc)\n\n    auc = 0\n    for i in range(num_step):\n        auc += 1.0 / num_step * y[i]\n    return auc\n\n\ndef keypoint_nme(pred, gt, mask, normalize_factor):\n    \"\"\"Calculate the normalized mean error (NME).\n\n    Note:\n        - batch_size: N\n        - num_keypoints: K\n\n    Args:\n        pred (np.ndarray[N, K, 2]): Predicted keypoint location.\n        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.\n        mask (np.ndarray[N, K]): Visibility of the target. False for invisible\n            joints, and True for visible. Invisible joints will be ignored for\n            accuracy calculation.\n        normalize_factor (np.ndarray[N, 2]): Normalization factor.\n\n    Returns:\n        float: normalized mean error\n    \"\"\"\n    distances = _calc_distances(pred, gt, mask, normalize_factor)\n    distance_valid = distances[distances != -1]\n    return distance_valid.sum() / max(1, len(distance_valid))\n\n\ndef keypoint_epe(pred, gt, mask):\n    \"\"\"Calculate the end-point error.\n\n    Note:\n        - batch_size: N\n        - num_keypoints: K\n\n    Args:\n        pred (np.ndarray[N, K, 2]): Predicted keypoint location.\n        gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.\n        mask (np.ndarray[N, K]): Visibility of the target. False for invisible\n            joints, and True for visible. 
Invisible joints will be ignored for\n            accuracy calculation.\n\n    Returns:\n        float: Average end-point error.\n    \"\"\"\n\n    distances = _calc_distances(\n        pred, gt, mask,\n        np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))\n    distance_valid = distances[distances != -1]\n    return distance_valid.sum() / max(1, len(distance_valid))\n\n\ndef _taylor(heatmap, coord):\n    \"\"\"Distribution aware coordinate decoding method.\n\n    Note:\n        - heatmap height: H\n        - heatmap width: W\n\n    Args:\n        heatmap (np.ndarray[H, W]): Heatmap of a particular joint type.\n        coord (np.ndarray[2,]): Coordinates of the predicted keypoints.\n\n    Returns:\n        np.ndarray[2,]: Updated coordinates.\n    \"\"\"\n    H, W = heatmap.shape[:2]\n    px, py = int(coord[0]), int(coord[1])\n    if 1 < px < W - 2 and 1 < py < H - 2:\n        dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1])\n        dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px])\n        dxx = 0.25 * (\n            heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2])\n        dxy = 0.25 * (\n            heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] -\n            heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1])\n        dyy = 0.25 * (\n            heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] +\n            heatmap[py - 2 * 1][px])\n        derivative = np.array([[dx], [dy]])\n        hessian = np.array([[dxx, dxy], [dxy, dyy]])\n        if dxx * dyy - dxy**2 != 0:\n            hessianinv = np.linalg.inv(hessian)\n            offset = -hessianinv @ derivative\n            offset = np.squeeze(np.array(offset.T), axis=0)\n            coord += offset\n    return coord\n\n\ndef post_dark_udp(coords, batch_heatmaps, kernel=3):\n    \"\"\"DARK post-processing. Implemented by UDP. Paper ref: Huang et al. The\n    Devil is in the Details: Delving into Unbiased Data Processing for Human\n    Pose Estimation (CVPR 2020). 
Zhang et al. Distribution-Aware Coordinate\n    Representation for Human Pose Estimation (CVPR 2020).\n\n    Note:\n        - batch size: B\n        - num keypoints: K\n        - num persons: N\n        - height of heatmaps: H\n        - width of heatmaps: W\n\n        B=1 for bottom_up paradigm where all persons share the same heatmap.\n        B=N for top_down paradigm where each person has its own heatmaps.\n\n    Args:\n        coords (np.ndarray[N, K, 2]): Initial coordinates of human pose.\n        batch_heatmaps (np.ndarray[B, K, H, W]): batch_heatmaps\n        kernel (int): Gaussian kernel size (K) for modulation.\n\n    Returns:\n        np.ndarray([N, K, 2]): Refined coordinates.\n    \"\"\"\n    if not isinstance(batch_heatmaps, np.ndarray):\n        batch_heatmaps = batch_heatmaps.cpu().numpy()\n    B, K, H, W = batch_heatmaps.shape\n    N = coords.shape[0]\n    assert (B == 1 or B == N)\n    for heatmaps in batch_heatmaps:\n        for heatmap in heatmaps:\n            cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap)\n    np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps)\n    np.log(batch_heatmaps, batch_heatmaps)\n\n    batch_heatmaps_pad = np.pad(\n        batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)),\n        mode='edge').flatten()\n\n    index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2)\n    index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K)\n    index = index.astype(int).reshape(-1, 1)\n    i_ = batch_heatmaps_pad[index]\n    ix1 = batch_heatmaps_pad[index + 1]\n    iy1 = batch_heatmaps_pad[index + W + 2]\n    ix1y1 = batch_heatmaps_pad[index + W + 3]\n    ix1_y1_ = batch_heatmaps_pad[index - W - 3]\n    ix1_ = batch_heatmaps_pad[index - 1]\n    iy1_ = batch_heatmaps_pad[index - 2 - W]\n\n    dx = 0.5 * (ix1 - ix1_)\n    dy = 0.5 * (iy1 - iy1_)\n    derivative = np.concatenate([dx, dy], axis=1)\n    derivative = derivative.reshape(N, K, 2, 1)\n    dxx = ix1 - 2 * i_ + ix1_\n    dyy = iy1 - 2 * i_ + iy1_\n    dxy 
= 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)\n    hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)\n    hessian = hessian.reshape(N, K, 2, 2)\n    hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))\n    coords -= np.einsum('ijmn,ijnk->ijmk', hessian, derivative).squeeze()\n    return coords\n\n\ndef _gaussian_blur(heatmaps, kernel=11):\n    \"\"\"Modulate heatmap distribution with Gaussian.\n     sigma = 0.3*((kernel_size-1)*0.5-1)+0.8\n     sigma~=3 if k=17\n     sigma=2 if k=11;\n     sigma~=1.5 if k=7;\n     sigma~=1 if k=3;\n\n    Note:\n        - batch_size: N\n        - num_keypoints: K\n        - heatmap height: H\n        - heatmap width: W\n\n    Args:\n        heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.\n        kernel (int): Gaussian kernel size (K) for modulation, which should\n            match the heatmap gaussian sigma when training.\n            K=17 for sigma=3 and k=11 for sigma=2.\n\n    Returns:\n        np.ndarray ([N, K, H, W]): Modulated heatmap distribution.\n    \"\"\"\n    assert kernel % 2 == 1\n\n    border = (kernel - 1) // 2\n    batch_size = heatmaps.shape[0]\n    num_joints = heatmaps.shape[1]\n    height = heatmaps.shape[2]\n    width = heatmaps.shape[3]\n    for i in range(batch_size):\n        for j in range(num_joints):\n            origin_max = np.max(heatmaps[i, j])\n            dr = np.zeros((height + 2 * border, width + 2 * border),\n                          dtype=np.float32)\n            dr[border:-border, border:-border] = heatmaps[i, j].copy()\n            dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)\n            heatmaps[i, j] = dr[border:-border, border:-border].copy()\n            heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j])\n    return heatmaps\n\n\ndef keypoints_from_regression(regression_preds, center, scale, img_size):\n    \"\"\"Get final keypoint predictions from regression vectors and transform\n    them back to the image.\n\n    Note:\n     
   - batch_size: N\n        - num_keypoints: K\n\n    Args:\n        regression_preds (np.ndarray[N, K, 2]): model prediction.\n        center (np.ndarray[N, 2]): Center of the bounding box (x, y).\n        scale (np.ndarray[N, 2]): Scale of the bounding box\n            wrt height/width.\n        img_size (list(img_width, img_height)): model input image size.\n\n    Returns:\n        tuple:\n\n        - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.\n        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.\n    \"\"\"\n    N, K, _ = regression_preds.shape\n    preds, maxvals = regression_preds, np.ones((N, K, 1), dtype=np.float32)\n\n    preds = preds * img_size\n\n    # Transform back to the image\n    for i in range(N):\n        preds[i] = transform_preds(preds[i], center[i], scale[i], img_size)\n\n    return preds, maxvals\n\n\ndef keypoints_from_heatmaps(heatmaps,\n                            center,\n                            scale,\n                            unbiased=False,\n                            post_process='default',\n                            kernel=11,\n                            valid_radius_factor=0.0546875,\n                            use_udp=False,\n                            target_type='GaussianHeatmap'):\n    \"\"\"Get final keypoint predictions from heatmaps and transform them back to\n    the image.\n\n    Note:\n        - batch size: N\n        - num keypoints: K\n        - heatmap height: H\n        - heatmap width: W\n\n    Args:\n        heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps.\n        center (np.ndarray[N, 2]): Center of the bounding box (x, y).\n        scale (np.ndarray[N, 2]): Scale of the bounding box\n            wrt height/width.\n        post_process (str/None): Choice of methods to post-process\n            heatmaps. Currently supported: None, 'default', 'unbiased',\n            'megvii'.\n        unbiased (bool): Option to use unbiased decoding. 
Mutually\n            exclusive with megvii.\n            Note: this arg is deprecated and unbiased=True can be replaced\n            by post_process='unbiased'\n            Paper ref: Zhang et al. Distribution-Aware Coordinate\n            Representation for Human Pose Estimation (CVPR 2020).\n        kernel (int): Gaussian kernel size (K) for modulation, which should\n            match the heatmap gaussian sigma when training.\n            K=17 for sigma=3 and k=11 for sigma=2.\n        valid_radius_factor (float): The radius factor of the positive area\n            in classification heatmap for UDP.\n        use_udp (bool): Use unbiased data processing.\n        target_type (str): 'GaussianHeatmap' or 'CombinedTarget'.\n            GaussianHeatmap: Classification target with gaussian distribution.\n            CombinedTarget: The combination of classification target\n            (response map) and regression target (offset map).\n            Paper ref: Huang et al. The Devil is in the Details: Delving into\n            Unbiased Data Processing for Human Pose Estimation (CVPR 2020).\n\n    Returns:\n        tuple: A tuple containing keypoint predictions and scores.\n\n        - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images.\n        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.\n    \"\"\"\n    # Avoid being affected\n    heatmaps = heatmaps.copy()\n\n    # detect conflicts\n    if unbiased:\n        assert post_process not in [False, None, 'megvii']\n    if post_process in ['megvii', 'unbiased']:\n        assert kernel > 0\n    if use_udp:\n        assert not post_process == 'megvii'\n\n    # normalize configs\n    if post_process is False:\n        warnings.warn(\n            'post_process=False is deprecated, '\n            'please use post_process=None instead', DeprecationWarning)\n        post_process = None\n    elif post_process is True:\n        if unbiased is True:\n            warnings.warn(\n               
 'post_process=True, unbiased=True is deprecated,'\n                \" please use post_process='unbiased' instead\",\n                DeprecationWarning)\n            post_process = 'unbiased'\n        else:\n            warnings.warn(\n                'post_process=True, unbiased=False is deprecated, '\n                \"please use post_process='default' instead\",\n                DeprecationWarning)\n            post_process = 'default'\n    elif post_process == 'default':\n        if unbiased is True:\n            warnings.warn(\n                'unbiased=True is deprecated, please use '\n                \"post_process='unbiased' instead\", DeprecationWarning)\n            post_process = 'unbiased'\n\n    # start processing\n    if post_process == 'megvii':\n        heatmaps = _gaussian_blur(heatmaps, kernel=kernel)\n\n    N, K, H, W = heatmaps.shape\n    if use_udp:\n        if target_type.lower() == 'GaussianHeatMap'.lower():\n            preds, maxvals = _get_max_preds(heatmaps)\n            preds = post_dark_udp(preds, heatmaps, kernel=kernel)\n        elif target_type.lower() == 'CombinedTarget'.lower():\n            for person_heatmaps in heatmaps:\n                for i, heatmap in enumerate(person_heatmaps):\n                    kt = 2 * kernel + 1 if i % 3 == 0 else kernel\n                    cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap)\n            # valid radius is in direct proportion to the height of heatmap.\n            valid_radius = valid_radius_factor * H\n            offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius\n            offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius\n            heatmaps = heatmaps[:, ::3, :]\n            preds, maxvals = _get_max_preds(heatmaps)\n            index = preds[..., 0] + preds[..., 1] * W\n            index += W * H * np.arange(0, N * K / 3)\n            index = index.astype(int).reshape(N, K // 3, 1)\n            preds += np.concatenate((offset_x[index], offset_y[index]), axis=2)\n   
     else:\n            raise ValueError('target_type should be either '\n                             \"'GaussianHeatmap' or 'CombinedTarget'\")\n    else:\n        preds, maxvals = _get_max_preds(heatmaps)\n        if post_process == 'unbiased':  # alleviate biased coordinate\n            # apply Gaussian distribution modulation.\n            heatmaps = np.log(\n                np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10))\n            for n in range(N):\n                for k in range(K):\n                    preds[n][k] = _taylor(heatmaps[n][k], preds[n][k])\n        elif post_process is not None:\n            # add +/-0.25 shift to the predicted locations for higher acc.\n            for n in range(N):\n                for k in range(K):\n                    heatmap = heatmaps[n][k]\n                    px = int(preds[n][k][0])\n                    py = int(preds[n][k][1])\n                    if 1 < px < W - 1 and 1 < py < H - 1:\n                        diff = np.array([\n                            heatmap[py][px + 1] - heatmap[py][px - 1],\n                            heatmap[py + 1][px] - heatmap[py - 1][px]\n                        ])\n                        preds[n][k] += np.sign(diff) * .25\n                        if post_process == 'megvii':\n                            preds[n][k] += 0.5\n\n    # Transform back to the image\n    for i in range(N):\n        preds[i] = transform_preds(\n            preds[i], center[i], scale[i], [W, H], use_udp=use_udp)\n\n    if post_process == 'megvii':\n        maxvals = maxvals / 255.0 + 0.5\n\n    return preds, maxvals\n\n\ndef keypoints_from_heatmaps3d(heatmaps, center, scale):\n    \"\"\"Get final keypoint predictions from 3d heatmaps and transform them back\n    to the image.\n\n    Note:\n        - batch size: N\n        - num keypoints: K\n        - heatmap depth size: D\n        - heatmap height: H\n        - heatmap width: W\n\n    Args:\n        heatmaps (np.ndarray[N, K, D, H, W]): model 
predicted heatmaps.\n        center (np.ndarray[N, 2]): Center of the bounding box (x, y).\n        scale (np.ndarray[N, 2]): Scale of the bounding box\n            wrt height/width.\n\n    Returns:\n        tuple: A tuple containing keypoint predictions and scores.\n\n        - preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \\\n            in images.\n        - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.\n    \"\"\"\n    N, K, D, H, W = heatmaps.shape\n    preds, maxvals = _get_max_preds_3d(heatmaps)\n    # Transform back to the image\n    for i in range(N):\n        preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i],\n                                          [W, H])\n    return preds, maxvals\n\n\ndef multilabel_classification_accuracy(pred, gt, mask, thr=0.5):\n    \"\"\"Get multi-label classification accuracy.\n\n    Note:\n        - batch size: N\n        - label number: L\n\n    Args:\n        pred (np.ndarray[N, L, 2]): model predicted labels.\n        gt (np.ndarray[N, L, 2]): ground-truth labels.\n        mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of\n        ground-truth labels.\n\n    Returns:\n        float: multi-label classification accuracy.\n    \"\"\"\n    # we only compute accuracy on the samples with ground-truth of all labels.\n    valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0)\n    pred, gt = pred[valid], gt[valid]\n\n    if pred.shape[0] == 0:\n        acc = 0.0  # when no sample is with gt labels, set acc to 0.\n    else:\n        # The classification of a sample is regarded as correct\n        # only if it's correct for all labels.\n        acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean()\n    return acc\n\n\n\ndef get_transform(center, scale, res, rot=0):\n    \"\"\"Generate transformation matrix.\"\"\"\n    # res: (height, width), (rows, cols)\n    crop_aspect_ratio = res[0] / float(res[1])\n    h = 200 * scale\n    w = h / 
crop_aspect_ratio\n    t = np.zeros((3, 3))\n    t[0, 0] = float(res[1]) / w\n    t[1, 1] = float(res[0]) / h\n    t[0, 2] = res[1] * (-float(center[0]) / w + .5)\n    t[1, 2] = res[0] * (-float(center[1]) / h + .5)\n    t[2, 2] = 1\n    if not rot == 0:\n        rot = -rot  # To match direction of rotation from cropping\n        rot_mat = np.zeros((3, 3))\n        rot_rad = rot * np.pi / 180\n        sn, cs = np.sin(rot_rad), np.cos(rot_rad)\n        rot_mat[0, :2] = [cs, -sn]\n        rot_mat[1, :2] = [sn, cs]\n        rot_mat[2, 2] = 1\n        # Need to rotate around center\n        t_mat = np.eye(3)\n        t_mat[0, 2] = -res[1] / 2\n        t_mat[1, 2] = -res[0] / 2\n        t_inv = t_mat.copy()\n        t_inv[:2, 2] *= -1\n        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))\n    return t\n\n\ndef transform(pt, center, scale, res, invert=0, rot=0):\n    \"\"\"Transform pixel location to different reference.\"\"\"\n    t = get_transform(center, scale, res, rot=rot)\n    if invert:\n        t = np.linalg.inv(t)\n    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T\n    new_pt = np.dot(t, new_pt)\n    return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1\n\n\ndef bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25):\n    \"\"\"\n    Get center and scale of bounding box from bounding box.\n    The expected format is [min_x, min_y, max_x, max_y].\n    \"\"\"\n    CROP_IMG_HEIGHT, CROP_IMG_WIDTH = input_resolution\n    CROP_ASPECT_RATIO = CROP_IMG_HEIGHT / float(CROP_IMG_WIDTH)\n\n    # center\n    center_x = (bbox[0] + bbox[2]) / 2.0\n    center_y = (bbox[1] + bbox[3]) / 2.0\n    center = np.array([center_x, center_y])\n\n    # scale\n    bbox_w = bbox[2] - bbox[0]\n    bbox_h = bbox[3] - bbox[1]\n    bbox_size = max(bbox_w * CROP_ASPECT_RATIO, bbox_h)\n\n    scale = np.array([bbox_size / CROP_ASPECT_RATIO, bbox_size]) / 200.0\n    # scale = bbox_size / 200.0\n    # adjust bounding box tightness\n    scale *= rescale\n    
return center, scale\n\n\ndef crop(img, center, scale, res):\n    \"\"\"\n    Crop image according to the supplied bounding box.\n    res: [rows, cols]\n    \"\"\"\n    # Upper left point\n    ul = np.array(transform([1, 1], center, max(scale), res, invert=1)) - 1\n    # Bottom right point\n    br = np.array(transform([res[1] + 1, res[0] + 1], center, max(scale), res, invert=1)) - 1\n\n    # Padding so that when rotated proper amount of context is included\n    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)\n\n    new_shape = [br[1] - ul[1], br[0] - ul[0]]\n    if len(img.shape) > 2:\n        new_shape += [img.shape[2]]\n    new_img = np.zeros(new_shape, dtype=np.float32)\n\n    # Range to fill new array\n    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]\n    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]\n    # Range to sample from original image\n    old_x = max(0, ul[0]), min(len(img[0]), br[0])\n    old_y = max(0, ul[1]), min(len(img), br[1])\n    try:\n        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]\n    except Exception as e:\n        print(e)\n\n    new_img = cv2.resize(new_img, (res[1], res[0]))  # (cols, rows)\n    return new_img, new_shape, (old_x, old_y), (new_x, new_y)  # , ul, br\n\n\ndef split_kp2ds_for_aa(kp2ds, ret_face=False):\n    kp2ds_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2\n    kp2ds_lhand = kp2ds[91:112]\n    kp2ds_rhand = kp2ds[112:133]\n    kp2ds_face = kp2ds[22:91]\n    if ret_face:\n        return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy(), kp2ds_face.copy()\n    return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()\n\ndef load_pose_metas_from_kp2ds_seq_list(kp2ds_seq, width, height):\n    metas = []\n    for kps in kp2ds_seq:\n        if len(kps) != 1:\n            return None\n        kps = 
kps[0].copy()\n        kps[:, 0] /= width\n        kps[:, 1] /= height\n        kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True)\n\n        if kp2ds_body[:, :2].min(axis=1).max() < 0:\n            kp2ds_body = last_kp2ds_body\n        last_kp2ds_body = kp2ds_body\n\n        meta = {\n            \"width\": width,\n            \"height\": height,\n            \"keypoints_body\": kp2ds_body.tolist(),\n            \"keypoints_left_hand\": kp2ds_lhand.tolist(),\n            \"keypoints_right_hand\": kp2ds_rhand.tolist(),\n            \"keypoints_face\": kp2ds_face.tolist(),\n        }\n        metas.append(meta)\n    return metas\n\n\ndef load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height):\n    metas = []\n    for kps in kp2ds_seq:\n        kps = kps.copy()\n        kps[:, 0] /= width\n        kps[:, 1] /= height\n        kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True)\n\n        # Exclude the case where all values are below 0 by reusing the previous frame's body keypoints\n        if kp2ds_body[:, :2].min(axis=1).max() < 0:\n            kp2ds_body = last_kp2ds_body\n        last_kp2ds_body = kp2ds_body\n\n        meta = {\n            \"width\": width,\n            \"height\": height,\n            \"keypoints_body\": kp2ds_body,\n            \"keypoints_left_hand\": kp2ds_lhand,\n            \"keypoints_right_hand\": kp2ds_rhand,\n            \"keypoints_face\": kp2ds_face,\n        }\n        metas.append(meta)\n    return metas"
  },
  {
    "path": "wan/modules/animate/preprocess/preprocess_data.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport argparse\nfrom process_pipepline import ProcessPipeline\n\n\ndef _parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"The preprocessing pipeline for Wan-animate.\"\n    )\n\n    parser.add_argument(\n        \"--ckpt_path\",\n        type=str,\n        default=None,\n        help=\"The path to the preprocessing model's checkpoint directory. \")\n\n    parser.add_argument(\n        \"--video_path\",\n        type=str,\n        default=None,\n        help=\"The path to the driving video.\")\n    parser.add_argument(\n        \"--refer_path\",\n        type=str,\n        default=None,\n        help=\"The path to the refererence image.\")\n    parser.add_argument(\n        \"--save_path\",\n        type=str,\n        default=None,\n        help=\"The path to save the processed results.\")\n    \n    parser.add_argument(\n        \"--resolution_area\",\n        type=int,\n        nargs=2,\n        default=[1280, 720],\n        help=\"The target resolution for processing, specified as [width, height]. To handle different aspect ratios, the video is resized to have a total area equivalent to width * height, while preserving the original aspect ratio.\"\n    )\n    parser.add_argument(\n        \"--fps\",\n        type=int,\n        default=30,\n        help=\"The target FPS for processing the driving video. Set to -1 to use the video's original FPS.\"\n    )\n\n    parser.add_argument(\n        \"--replace_flag\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use replacement mode.\")\n    parser.add_argument(\n        \"--retarget_flag\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use pose retargeting. 
Currently only supported in animation mode.\")\n    parser.add_argument(\n        \"--use_flux\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to use image editing in pose retargeting. Recommended if the character in the reference image or the first frame of the driving video is not in a standard, front-facing pose.\")\n    \n    # Parameters for the mask strategy in replacement mode. These control the mask's size and shape. Refer to https://arxiv.org/pdf/2502.06145\n    parser.add_argument(\n        \"--iterations\",\n        type=int,\n        default=3,\n        help=\"Number of iterations for mask dilation.\"\n    )\n    parser.add_argument(\n        \"--k\",\n        type=int,\n        default=7,\n        help=\"Kernel size for mask dilation.\"\n    )\n    parser.add_argument(\n        \"--w_len\",\n        type=int,\n        default=1,\n        help=\"The number of subdivisions for the grid along the 'w' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.\"\n    )\n    parser.add_argument(\n        \"--h_len\",\n        type=int,\n        default=1,\n        help=\"The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. 
A value of 1 means no subdivision is performed.\"\n    )\n    args = parser.parse_args()\n\n    return args\n\n\nif __name__ == '__main__':\n    args = _parse_args()\n    args_dict = vars(args)\n    print(args_dict)\n\n    assert len(args.resolution_area) == 2, \"resolution_area should be a list of two integers [width, height]\"\n    assert not args.use_flux or args.retarget_flag, \"Image editing with FLUX can only be used when pose retargeting is enabled.\"\n\n    pose2d_checkpoint_path = os.path.join(args.ckpt_path, 'pose2d/vitpose_h_wholebody.onnx')\n    det_checkpoint_path = os.path.join(args.ckpt_path, 'det/yolov10m.onnx')\n\n    sam2_checkpoint_path = os.path.join(args.ckpt_path, 'sam2/sam2_hiera_large.pt') if args.replace_flag else None\n    flux_kontext_path = os.path.join(args.ckpt_path, 'FLUX.1-Kontext-dev') if args.use_flux else None\n    process_pipeline = ProcessPipeline(det_checkpoint_path=det_checkpoint_path, pose2d_checkpoint_path=pose2d_checkpoint_path, sam_checkpoint_path=sam2_checkpoint_path, flux_kontext_path=flux_kontext_path)\n    os.makedirs(args.save_path, exist_ok=True)\n    process_pipeline(video_path=args.video_path, \n                     refer_image_path=args.refer_path, \n                     output_path=args.save_path,\n                     resolution_area=args.resolution_area,\n                     fps=args.fps,\n                     iterations=args.iterations,\n                     k=args.k,\n                     w_len=args.w_len,\n                     h_len=args.h_len,\n                     retarget_flag=args.retarget_flag,\n                     use_flux=args.use_flux,\n                     replace_flag=args.replace_flag)\n\n"
  },
  {
    "path": "wan/modules/animate/preprocess/process_pipepline.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport numpy as np\nimport shutil\nimport torch\nfrom diffusers import FluxKontextPipeline\nimport cv2\nfrom loguru import logger\nfrom PIL import Image\ntry:\n    import moviepy.editor as mpy\nexcept:\n    import moviepy as mpy\n\nfrom decord import VideoReader\nfrom pose2d import Pose2d\nfrom pose2d_utils import AAPoseMeta\nfrom utils import resize_by_area, get_frame_indices, padding_resize, get_face_bboxes, get_aug_mask, get_mask_body_img\nfrom human_visualization import draw_aapose_by_meta_new\nfrom retarget_pose import get_retarget_pose\nimport sam2.modeling.sam.transformer as transformer\ntransformer.USE_FLASH_ATTN = False\ntransformer.MATH_KERNEL_ON = True\ntransformer.OLD_GPU = True\nfrom sam_utils import build_sam2_video_predictor\n\n\nclass ProcessPipeline():\n    def __init__(self, det_checkpoint_path, pose2d_checkpoint_path, sam_checkpoint_path, flux_kontext_path):\n        self.pose2d = Pose2d(checkpoint=pose2d_checkpoint_path, detector_checkpoint=det_checkpoint_path)\n\n        model_cfg = \"sam2_hiera_l.yaml\"\n        if sam_checkpoint_path is not None:\n            self.predictor = build_sam2_video_predictor(model_cfg, sam_checkpoint_path)\n        if flux_kontext_path is not None:\n            self.flux_kontext = FluxKontextPipeline.from_pretrained(flux_kontext_path, torch_dtype=torch.bfloat16).to(\"cuda\")\n\n    def __call__(self, video_path, refer_image_path, output_path, resolution_area=[1280, 720], fps=30, iterations=3, k=7, w_len=1, h_len=1, retarget_flag=False, use_flux=False, replace_flag=False):\n        if replace_flag:\n\n            video_reader = VideoReader(video_path)\n            frame_num = len(video_reader)\n            print('frame_num: {}'.format(frame_num))\n            \n            video_fps = video_reader.get_avg_fps()\n            print('video_fps: {}'.format(video_fps))\n            print('fps: {}'.format(fps))\n\n         
   # TODO: Maybe we can switch to PyAV later, which can get accurate frame num\n            duration = video_reader.get_frame_timestamp(-1)[-1]      \n            expected_frame_num = int(duration * video_fps + 0.5) \n            ratio = abs((frame_num - expected_frame_num)/frame_num)         \n            if ratio > 0.1:\n                print(\"Warning: The difference between the actual number of frames and the expected number of frames is two large\")\n                frame_num = expected_frame_num\n\n            if fps == -1:\n                fps = video_fps\n\n            target_num = int(frame_num / video_fps * fps)\n            print('target_num: {}'.format(target_num))\n            idxs = get_frame_indices(frame_num, video_fps, target_num, fps)\n            frames = video_reader.get_batch(idxs).asnumpy()\n\n            frames = [resize_by_area(frame, resolution_area[0] * resolution_area[1], divisor=16) for frame in frames]\n            height, width = frames[0].shape[:2]\n            logger.info(f\"Processing pose meta\")\n\n\n            tpl_pose_metas = self.pose2d(frames)\n\n            face_images = []\n            for idx, meta in enumerate(tpl_pose_metas):\n                face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,\n                                                    image_shape=(frames[0].shape[0], frames[0].shape[1]))\n\n                x1, x2, y1, y2 = face_bbox_for_image\n                face_image = frames[idx][y1:y2, x1:x2]\n                face_image = cv2.resize(face_image, (512, 512))\n                face_images.append(face_image)\n\n            logger.info(f\"Processing reference image: {refer_image_path}\")\n            refer_img = cv2.imread(refer_image_path)\n            src_ref_path = os.path.join(output_path, 'src_ref.png')\n            shutil.copy(refer_image_path, src_ref_path)\n            refer_img = refer_img[..., ::-1]\n\n            refer_img = padding_resize(refer_img, height, width)\n       
     logger.info(f\"Processing template video: {video_path}\")\n            tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]\n            cond_images = []\n\n            for idx, meta in enumerate(tpl_retarget_pose_metas):\n                canvas = np.zeros_like(refer_img)\n                conditioning_image = draw_aapose_by_meta_new(canvas, meta)\n                cond_images.append(conditioning_image)\n            masks = self.get_mask(frames, 400, tpl_pose_metas)\n\n            bg_images = []\n            aug_masks = []\n\n            for frame, mask in zip(frames, masks):\n                if iterations > 0:\n                    _, each_mask = get_mask_body_img(frame, mask, iterations=iterations, k=k)\n                    each_aug_mask = get_aug_mask(each_mask, w_len=w_len, h_len=h_len)\n                else:\n                    each_aug_mask = mask\n\n                each_bg_image = frame * (1 - each_aug_mask[:, :, None])\n                bg_images.append(each_bg_image)\n                aug_masks.append(each_aug_mask)\n\n            src_face_path = os.path.join(output_path, 'src_face.mp4')\n            mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)\n\n            src_pose_path = os.path.join(output_path, 'src_pose.mp4')\n            mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)\n\n            src_bg_path = os.path.join(output_path, 'src_bg.mp4')\n            mpy.ImageSequenceClip(bg_images, fps=fps).write_videofile(src_bg_path)\n\n            aug_masks_new = [np.stack([mask * 255, mask * 255, mask * 255], axis=2) for mask in aug_masks]\n            src_mask_path = os.path.join(output_path, 'src_mask.mp4')\n            mpy.ImageSequenceClip(aug_masks_new, fps=fps).write_videofile(src_mask_path)\n            return True\n        else:\n            logger.info(f\"Processing reference image: {refer_image_path}\")\n            refer_img = cv2.imread(refer_image_path)\n 
           src_ref_path = os.path.join(output_path, 'src_ref.png')\n            shutil.copy(refer_image_path, src_ref_path)\n            refer_img = refer_img[..., ::-1]\n            \n            refer_img = resize_by_area(refer_img, resolution_area[0] * resolution_area[1], divisor=16)\n            \n            refer_pose_meta = self.pose2d([refer_img])[0]\n\n\n            logger.info(f\"Processing template video: {video_path}\")\n            video_reader = VideoReader(video_path)\n            frame_num = len(video_reader)\n            print('frame_num: {}'.format(frame_num))\n\n            video_fps = video_reader.get_avg_fps()\n            print('video_fps: {}'.format(video_fps))\n            print('fps: {}'.format(fps))\n\n            # TODO: Maybe we can switch to PyAV later, which can get accurate frame num\n            duration = video_reader.get_frame_timestamp(-1)[-1]      \n            expected_frame_num = int(duration * video_fps + 0.5) \n            ratio = abs((frame_num - expected_frame_num)/frame_num)         \n            if ratio > 0.1:\n                print(\"Warning: The difference between the actual number of frames and the expected number of frames is too large\")\n                frame_num = expected_frame_num\n\n            if fps == -1:\n                fps = video_fps\n                \n            target_num = int(frame_num / video_fps * fps)\n            print('target_num: {}'.format(target_num))\n            idxs = get_frame_indices(frame_num, video_fps, target_num, fps)\n            frames = video_reader.get_batch(idxs).asnumpy()\n\n            logger.info(f\"Processing pose meta\")\n\n            tpl_pose_meta0 = self.pose2d(frames[:1])[0]\n            tpl_pose_metas = self.pose2d(frames)\n\n            face_images = []\n            for idx, meta in enumerate(tpl_pose_metas):\n                face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3,\n                                                    
image_shape=(frames[0].shape[0], frames[0].shape[1]))\n\n                x1, x2, y1, y2 = face_bbox_for_image\n                face_image = frames[idx][y1:y2, x1:x2]\n                face_image = cv2.resize(face_image, (512, 512))\n                face_images.append(face_image)\n\n            if retarget_flag:\n                if use_flux:\n                    tpl_prompt, refer_prompt = self.get_editing_prompts(tpl_pose_metas, refer_pose_meta)\n                    refer_input = Image.fromarray(refer_img)\n                    refer_edit = self.flux_kontext(\n                            image=refer_input,\n                            height=refer_img.shape[0],\n                            width=refer_img.shape[1],\n                            prompt=refer_prompt,\n                            guidance_scale=2.5,\n                            num_inference_steps=28,\n                        ).images[0]\n                    \n                    refer_edit = Image.fromarray(padding_resize(np.array(refer_edit), refer_img.shape[0], refer_img.shape[1]))\n                    refer_edit_path = os.path.join(output_path, 'refer_edit.png')\n                    refer_edit.save(refer_edit_path)\n                    refer_edit_pose_meta = self.pose2d([np.array(refer_edit)])[0]\n\n                    tpl_img = frames[1]\n                    tpl_input = Image.fromarray(tpl_img)\n                    \n                    tpl_edit = self.flux_kontext(\n                            image=tpl_input,\n                            height=tpl_img.shape[0],\n                            width=tpl_img.shape[1],\n                            prompt=tpl_prompt,\n                            guidance_scale=2.5,\n                            num_inference_steps=28,\n                        ).images[0]\n                    \n                    tpl_edit = Image.fromarray(padding_resize(np.array(tpl_edit), tpl_img.shape[0], tpl_img.shape[1]))\n                    tpl_edit_path = os.path.join(output_path, 
'tpl_edit.png')\n                    tpl_edit.save(tpl_edit_path)\n                    tpl_edit_pose_meta0 = self.pose2d([np.array(tpl_edit)])[0]\n                    tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tpl_edit_pose_meta0, refer_edit_pose_meta)\n                else:\n                    tpl_retarget_pose_metas = get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, None, None)\n            else:\n               tpl_retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in tpl_pose_metas]\n\n            cond_images = []\n            for idx, meta in enumerate(tpl_retarget_pose_metas):\n                if retarget_flag:\n                    canvas = np.zeros_like(refer_img)\n                    conditioning_image = draw_aapose_by_meta_new(canvas, meta)\n                else:\n                    canvas = np.zeros_like(frames[0])\n                    conditioning_image = draw_aapose_by_meta_new(canvas, meta)\n                    conditioning_image = padding_resize(conditioning_image, refer_img.shape[0], refer_img.shape[1])\n\n                cond_images.append(conditioning_image)\n\n            src_face_path = os.path.join(output_path, 'src_face.mp4')\n            mpy.ImageSequenceClip(face_images, fps=fps).write_videofile(src_face_path)\n\n            src_pose_path = os.path.join(output_path, 'src_pose.mp4')\n            mpy.ImageSequenceClip(cond_images, fps=fps).write_videofile(src_pose_path)\n            return True\n\n    def get_editing_prompts(self, tpl_pose_metas, refer_pose_meta):\n        arm_visible = False\n        leg_visible = False\n        for tpl_pose_meta in tpl_pose_metas:\n            tpl_keypoints = tpl_pose_meta['keypoints_body']\n            if tpl_keypoints[3].all() != 0 or tpl_keypoints[4].all() != 0 or tpl_keypoints[6].all() != 0 or tpl_keypoints[7].all() != 0:\n                if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] 
>= 0.75) or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \\\n                    (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75):\n                    arm_visible = True\n            if tpl_keypoints[9].all() != 0 or tpl_keypoints[12].all() != 0 or tpl_keypoints[10].all() != 0 or tpl_keypoints[13].all() != 0:\n                if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \\\n                    (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75):\n                    leg_visible = True\n            if arm_visible and leg_visible:\n                break\n        \n        if leg_visible:\n            if tpl_pose_meta['width'] > tpl_pose_meta['height']:\n                tpl_prompt = \"Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image.\"\n            else:\n                tpl_prompt = \"Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image.\"\n\n            if refer_pose_meta['width'] > refer_pose_meta['height']:\n                refer_prompt = \"Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image.\"\n            else:\n                refer_prompt = \"Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. 
Feet and Hands are visible in the image.\"\n        elif arm_visible:\n            if tpl_pose_meta['width'] > tpl_pose_meta['height']:\n                tpl_prompt = \"Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image.\"\n            else:\n                tpl_prompt = \"Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image.\"\n\n            if refer_pose_meta['width'] > refer_pose_meta['height']:\n                refer_prompt = \"Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image.\"\n            else:\n                refer_prompt = \"Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image.\"\n        else:\n            tpl_prompt = \"Change the person to face forward.\"\n            refer_prompt = \"Change the person to face forward.\"\n\n        return tpl_prompt, refer_prompt\n    \n\n    def get_mask(self, frames, th_step, kp2ds_all):\n        frame_num = len(frames)\n        if frame_num < th_step:\n            num_step = 1\n        else:\n            num_step = (frame_num + th_step) // th_step\n\n        all_mask = []\n        for index in range(num_step):\n            each_frames = frames[index * th_step:(index + 1) * th_step]\n    \n            kp2ds = kp2ds_all[index * th_step:(index + 1) * th_step]\n            if len(each_frames) > 4:\n                key_frame_num = 4\n            elif 4 >= len(each_frames) > 0:\n                key_frame_num = 1\n            else:\n                continue\n\n            key_frame_step = len(kp2ds) // key_frame_num\n            key_frame_index_list = list(range(0, len(kp2ds), key_frame_step))\n\n            key_points_index = [0, 1, 2, 5, 8, 11, 10, 13]\n            key_frame_body_points_list = []\n            for 
key_frame_index in key_frame_index_list:\n                keypoints_body_list = []\n                body_key_points = kp2ds[key_frame_index]['keypoints_body']\n                for each_index in key_points_index:\n                    each_keypoint = body_key_points[each_index]\n                    if None is each_keypoint:\n                        continue\n                    keypoints_body_list.append(each_keypoint)\n\n                keypoints_body = np.array(keypoints_body_list)[:, :2]\n                wh = np.array([[kp2ds[0]['width'], kp2ds[0]['height']]])\n                points = (keypoints_body * wh).astype(np.int32)\n                key_frame_body_points_list.append(points)\n\n            inference_state = self.predictor.init_state_v2(frames=each_frames)\n            self.predictor.reset_state(inference_state)\n            ann_obj_id = 1\n            for ann_frame_idx, points in zip(key_frame_index_list, key_frame_body_points_list):\n                labels = np.array([1] * points.shape[0], np.int32)\n                _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(\n                    inference_state=inference_state,\n                    frame_idx=ann_frame_idx,\n                    obj_id=ann_obj_id,\n                    points=points,\n                    labels=labels,\n                )\n\n            video_segments = {}\n            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):\n                video_segments[out_frame_idx] = {\n                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()\n                    for i, out_obj_id in enumerate(out_obj_ids)\n                }\n\n            for out_frame_idx in range(len(video_segments)):\n                for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n                    out_mask = out_mask[0].astype(np.uint8)\n                    all_mask.append(out_mask)\n\n        return all_mask\n    \n    def 
convert_list_to_array(self, metas):\n        metas_list = []\n        for meta in metas:\n            for key, value in meta.items():\n                if isinstance(value, list):\n                    value = np.array(value)\n                meta[key] = value\n            metas_list.append(meta)\n        return metas_list\n\n"
  },
  {
    "path": "wan/modules/animate/preprocess/retarget_pose.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nimport numpy as np\nimport json\nfrom tqdm import tqdm\nimport math \nfrom typing import NamedTuple, List\nimport copy\nfrom pose2d_utils import AAPoseMeta\n\n\n# load skeleton name and bone lines\nkeypoint_list = [\n        \"Nose\",\n        \"Neck\",\n        \"RShoulder\",\n        \"RElbow\",\n        \"RWrist\", # No.4\n        \"LShoulder\",\n        \"LElbow\",\n        \"LWrist\", # No.7\n        \"RHip\",\n        \"RKnee\",\n        \"RAnkle\", # No.10\n        \"LHip\",\n        \"LKnee\",\n        \"LAnkle\", # No.13\n        \"REye\",\n        \"LEye\",\n        \"REar\",\n        \"LEar\",\n        \"LToe\",\n        \"RToe\",\n]\n\n\nlimbSeq = [\n    [2, 3], [2, 6],     # shoulders\n    [3, 4], [4, 5],     # left arm\n    [6, 7], [7, 8],     # right arm\n    [2, 9], [9, 10], [10, 11],    # right leg \n    [2, 12], [12, 13], [13, 14],  # left leg\n    [2, 1], [1, 15], [15, 17], [1, 16], [16, 18], # face (nose, eyes, ears)\n    [14, 19], # left foot\n    [11, 20] #  right foot\n]\n\neps = 0.01\n\nclass Keypoint(NamedTuple):\n    x: float\n    y: float\n    score: float = 1.0\n    id: int = -1\n\n\n# for each limb, calculate src & dst bone's length\n# and calculate their ratios \ndef get_length(skeleton, limb):\n    \n    k1_index, k2_index = limb\n    \n    H, W = skeleton['height'], skeleton['width']\n    keypoints = skeleton['keypoints_body']\n    keypoint1 = keypoints[k1_index - 1]\n    keypoint2 = keypoints[k2_index - 1]\n\n    if keypoint1 is None or keypoint2 is None:\n        return None, None, None\n    \n    X = np.array([keypoint1[0], keypoint2[0]]) * float(W)\n    Y = np.array([keypoint1[1], keypoint2[1]]) * float(H)\n    length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n    \n    return X, Y, length\n\n\n\ndef get_handpose_meta(keypoints, delta, src_H, src_W):\n\n    new_keypoints = []\n\n    for idx, keypoint in 
enumerate(keypoints):\n        if keypoint is None:\n            new_keypoints.append(None)\n            continue\n        if keypoint.score == 0:\n            new_keypoints.append(None)\n            continue\n\n        x, y = keypoint.x, keypoint.y\n        x = int(x * src_W + delta[0])\n        y = int(y * src_H + delta[1])\n\n        new_keypoints.append(                \n                Keypoint(\n                    x=x,\n                    y=y,\n                    score=keypoint.score,\n                ))\n\n    return new_keypoints\n\n\ndef deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th = 0.5):\n\n    left_hand = []\n    right_hand = []\n\n    left_delta_x = hand_res['left'][0][0] * (l_ratio - 1) \n    left_delta_y = hand_res['left'][0][1] * (l_ratio - 1)\n\n    right_delta_x = hand_res['right'][0][0] * (r_ratio - 1)\n    right_delta_y = hand_res['right'][0][1] * (r_ratio - 1)\n\n    length = len(hand_res['left'])\n\n    for i in range(length):\n        # left hand\n        if hand_res['left'][i][2] < hand_score_th:\n            left_hand.append(\n                Keypoint(\n                    x=-1,\n                    y=-1,\n                    score=0,\n                )\n            )\n        else:\n            left_hand.append(\n                Keypoint(\n                    x=hand_res['left'][i][0] * l_ratio - left_delta_x,\n                    y=hand_res['left'][i][1] * l_ratio - left_delta_y,\n                    score = hand_res['left'][i][2]\n                )\n            )\n\n        # right hand\n        if hand_res['right'][i][2] < hand_score_th:\n            right_hand.append(\n                Keypoint(\n                    x=-1,\n                    y=-1,\n                    score=0,\n                )\n            )\n        else:\n            right_hand.append(\n                Keypoint(\n                    x=hand_res['right'][i][0] * r_ratio - right_delta_x,\n                    y=hand_res['right'][i][1] * r_ratio - 
right_delta_y,\n                    score = hand_res['right'][i][2]\n                )\n            )\n\n    return right_hand, left_hand\n\n\ndef get_scaled_pose(canvas, src_canvas, keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y,\n                                       rescaled_src_ground_x, body_flag, id, scale_min, threshold = 0.4):\n\n    H, W = canvas\n    src_H, src_W = src_canvas\n\n    new_length_list = [ ] \n    angle_list = [ ]\n\n    # keypoints from 0-1 to H/W range\n    for idx in range(len(keypoints)):\n        if keypoints[idx] is None or len(keypoints[idx]) == 0:\n            continue\n\n        keypoints[idx] = [keypoints[idx][0] * src_W, keypoints[idx][1] * src_H, keypoints[idx][2]]\n\n    # first traverse, get new_length_list and angle_list\n    for idx, (k1_index, k2_index) in enumerate(limbSeq):\n        keypoint1 = keypoints[k1_index - 1]\n        keypoint2 = keypoints[k2_index - 1]\n\n        if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0:\n            new_length_list.append(None)\n            angle_list.append(None)\n            continue\n\n        Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W)\n        X = np.array([keypoint1[1], keypoint2[1]]) #* float(H)\n\n        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n\n        new_length = length * bone_ratio_list[idx]\n        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n\n        new_length_list.append(new_length)\n        angle_list.append(angle)\n\n    # Keep foot length within 0.5x calf length\n    foot_lower_leg_ratio = 0.5\n    if new_length_list[8] != None and new_length_list[18] != None:\n        if new_length_list[18] > new_length_list[8] * foot_lower_leg_ratio:\n            new_length_list[18] = new_length_list[8] * foot_lower_leg_ratio\n\n    if new_length_list[11] != None and new_length_list[17] != None:\n        if new_length_list[17] > new_length_list[11] * 
foot_lower_leg_ratio:\n            new_length_list[17] = new_length_list[11] * foot_lower_leg_ratio\n\n    # second traverse, calculate new keypoints\n    rescale_keypoints = keypoints.copy()\n\n    for idx, (k1_index, k2_index) in enumerate(limbSeq):\n        # update dst_keypoints\n        start_keypoint = rescale_keypoints[k1_index - 1]\n        new_length = new_length_list[idx]\n        angle = angle_list[idx]\n\n        if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \\\n            len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0:\n            continue\n\n        # calculate end_keypoint\n        delta_x = new_length * math.cos(math.radians(angle))\n        delta_y = new_length * math.sin(math.radians(angle))\n        \n        end_keypoint_x = start_keypoint[0] - delta_x\n        end_keypoint_y = start_keypoint[1] - delta_y\n\n        # update keypoints\n        rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y, rescale_keypoints[k2_index - 1][2]]\n\n    if id == 0:\n        if body_flag == 'full_body' and rescale_keypoints[8] != None and rescale_keypoints[11] != None:\n            delta_ground_x_offset_first_frame = (rescale_keypoints[8][0] + rescale_keypoints[11][0]) / 2 - rescaled_src_ground_x\n            delta_ground_x += delta_ground_x_offset_first_frame\n        elif body_flag == 'half_body' and rescale_keypoints[1] != None:\n            delta_ground_x_offset_first_frame = rescale_keypoints[1][0] - rescaled_src_ground_x\n            delta_ground_x += delta_ground_x_offset_first_frame\n\n    # offset all keypoints\n    for idx in range(len(rescale_keypoints)):\n        if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0 :\n            continue\n        rescale_keypoints[idx][0] -= delta_ground_x\n        rescale_keypoints[idx][1] -= delta_ground_y\n\n        # rescale keypoints to original size\n        rescale_keypoints[idx][0] /= 
scale_min\n        rescale_keypoints[idx][1] /= scale_min\n\n    # Scale hand proportions based on body skeletal ratios\n    r_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min\n    l_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min\n    left_hand, right_hand = deal_hand_keypoints(keypoints_hand, r_ratio, l_ratio, hand_score_th = threshold)\n\n    left_hand_new = left_hand.copy()\n    right_hand_new = right_hand.copy()\n\n    if rescale_keypoints[4] == None and rescale_keypoints[7] == None:\n        pass\n\n    elif rescale_keypoints[4] == None and rescale_keypoints[7] != None:\n        right_hand_delta =  np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2])\n        right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W)\n\n    elif rescale_keypoints[4] != None and rescale_keypoints[7] == None:\n        left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2]) \n        left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W)\n\n    else:\n        # get left_hand and right_hand offset \n        left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2]) \n        right_hand_delta =  np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2])\n\n        if keypoints[4][0] != None and left_hand[0].x != -1:\n            left_hand_root_offset = np.array( ( keypoints[4][0] - left_hand[0].x * src_W,  keypoints[4][1] - left_hand[0].y * src_H))  \n            left_hand_delta += left_hand_root_offset\n\n        if keypoints[7][0] != None and right_hand[0].x != -1:\n            right_hand_root_offset = np.array( ( keypoints[7][0] - right_hand[0].x * src_W, keypoints[7][1] - right_hand[0].y * src_H))  \n            right_hand_delta += right_hand_root_offset\n\n        dis_left_hand = ((keypoints[4][0] - left_hand[0].x * src_W) ** 2 + (keypoints[4][1] - left_hand[0].y * src_H) ** 2) ** 0.5\n        dis_right_hand = ((keypoints[7][0] - 
left_hand[0].x * src_W) ** 2 + (keypoints[7][1] - left_hand[0].y * src_H) ** 2) ** 0.5\n\n        if dis_left_hand > dis_right_hand: \n            right_hand_new = get_handpose_meta(left_hand, right_hand_delta, src_H, src_W)\n            left_hand_new = get_handpose_meta(right_hand, left_hand_delta, src_H, src_W)\n        else:\n            left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W)\n            right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W)\n\n    # get normalized keypoints_body\n    norm_body_keypoints = [ ]\n    for body_keypoint in rescale_keypoints:\n        if body_keypoint != None:\n            norm_body_keypoints.append([body_keypoint[0] / W , body_keypoint[1] / H, body_keypoint[2]])\n        else:\n            norm_body_keypoints.append(None)\n\n    frame_info = {\n                    'height': H,\n                    'width': W,\n                    'keypoints_body': norm_body_keypoints,\n                    'keypoints_left_hand' : left_hand_new,\n                    'keypoints_right_hand' : right_hand_new,\n                }\n\n    return frame_info\n\n\ndef rescale_skeleton(H, W, keypoints, bone_ratio_list):\n\n    rescale_keypoints = keypoints.copy()\n\n    new_length_list = [ ] \n    angle_list = [ ]\n\n    # keypoints from 0-1 to H/W range\n    for idx in range(len(rescale_keypoints)):\n        if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0:\n            continue\n\n        rescale_keypoints[idx] = [rescale_keypoints[idx][0] * W, rescale_keypoints[idx][1] * H]\n\n    # first traverse, get new_length_list and angle_list\n    for idx, (k1_index, k2_index) in enumerate(limbSeq):\n        keypoint1 = rescale_keypoints[k1_index - 1]\n        keypoint2 = rescale_keypoints[k2_index - 1]\n\n        if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0:\n            new_length_list.append(None)\n            angle_list.append(None)\n        
    continue\n\n        Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W)\n        X = np.array([keypoint1[1], keypoint2[1]]) #* float(H)\n\n        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n\n\n        new_length = length * bone_ratio_list[idx]\n        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n\n        new_length_list.append(new_length)\n        angle_list.append(angle)\n\n    # # second traverse, calculate new keypoints\n    for idx, (k1_index, k2_index) in enumerate(limbSeq):\n        # update dst_keypoints\n        start_keypoint = rescale_keypoints[k1_index - 1]\n        new_length = new_length_list[idx]\n        angle = angle_list[idx]\n\n        if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \\\n            len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0:\n            continue\n\n        # calculate end_keypoint\n        delta_x = new_length * math.cos(math.radians(angle))\n        delta_y = new_length * math.sin(math.radians(angle))\n        \n        end_keypoint_x = start_keypoint[0] - delta_x\n        end_keypoint_y = start_keypoint[1] - delta_y\n\n        # update keypoints\n        rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y]\n\n    return rescale_keypoints\n\n\ndef fix_lack_keypoints_use_sym(skeleton):\n\n    keypoints = skeleton['keypoints_body']\n    H, W = skeleton['height'], skeleton['width']\n\n    limb_points_list = [\n                        [3, 4, 5],\n                        [6, 7, 8],\n                        [12, 13, 14, 19],\n                        [9, 10, 11, 20],\n    ]\n\n    for limb_points in limb_points_list:\n        miss_flag = False\n        for point in limb_points:\n            if keypoints[point - 1] is None:\n                miss_flag = True\n                continue\n            if miss_flag:\n                skeleton['keypoints_body'][point - 1] = None\n\n    repair_limb_seq_left = 
[\n        [3, 4], [4, 5],     # left arm\n        [12, 13], [13, 14],  # left leg\n        [14, 19] # left foot\n    ]\n\n    repair_limb_seq_right = [\n        [6, 7], [7, 8],     # right arm\n        [9, 10], [10, 11],    # right leg \n        [11, 20] # right foot\n    ]\n\n    repair_limb_seq = [repair_limb_seq_left, repair_limb_seq_right]\n\n    for idx_part, part in enumerate(repair_limb_seq):\n        for idx, limb in enumerate(part):\n\n            k1_index, k2_index = limb\n            keypoint1 = keypoints[k1_index - 1]\n            keypoint2 = keypoints[k2_index - 1]\n\n            if keypoint1 != None and keypoint2 is None:\n                # reference to symmetric limb\n                sym_limb = repair_limb_seq[1-idx_part][idx]\n                k1_index_sym, k2_index_sym = sym_limb\n                keypoint1_sym = keypoints[k1_index_sym - 1]\n                keypoint2_sym = keypoints[k2_index_sym - 1]\n                ref_length = 0\n\n                if keypoint1_sym != None and keypoint2_sym != None:\n                    X = np.array([keypoint1_sym[0], keypoint2_sym[0]]) * float(W)\n                    Y = np.array([keypoint1_sym[1], keypoint2_sym[1]]) * float(H)\n                    ref_length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n                else:\n                    ref_length_left, ref_length_right = 0, 0\n                    if keypoints[1] != None and keypoints[8] != None:\n                        X = np.array([keypoints[1][0], keypoints[8][0]]) * float(W)\n                        Y = np.array([keypoints[1][1], keypoints[8][1]]) * float(H)\n                        ref_length_left = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n                        if idx <= 1: # arms\n                            ref_length_left /= 2\n                    \n                    if keypoints[1] != None and keypoints[11] != None:\n                        X = np.array([keypoints[1][0], keypoints[11][0]]) * float(W)\n                        
Y = np.array([keypoints[1][1], keypoints[11][1]]) * float(H)\n                        ref_length_right = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n                        if idx <= 1: # arms\n                            ref_length_right /= 2\n                        elif idx == 4: # foot\n                            ref_length_right /= 5\n\n                    ref_length = max(ref_length_left, ref_length_right)\n                    \n                if ref_length != 0:\n                    skeleton['keypoints_body'][k2_index - 1] = [0, 0] #init\n                    skeleton['keypoints_body'][k2_index - 1][0] = skeleton['keypoints_body'][k1_index - 1][0]\n                    skeleton['keypoints_body'][k2_index - 1][1] = skeleton['keypoints_body'][k1_index - 1][1] + ref_length / H\n    return skeleton\n\n\ndef rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list):\n\n    modify_bone_list = [\n        [0, 1],\n        [2, 4],\n        [3, 5],\n        [6, 9],\n        [7, 10],\n        [8, 11],\n        [17, 18]\n    ]\n\n    for modify_bone in modify_bone_list:\n        new_ratio = max(ratio_list[modify_bone[0]], ratio_list[modify_bone[1]])\n        ratio_list[modify_bone[0]] = new_ratio\n        ratio_list[modify_bone[1]] = new_ratio\n    \n    if ratio_list[13]!= None and ratio_list[15]!= None:\n        ratio_eye_avg = (ratio_list[13] + ratio_list[15]) / 2\n        ratio_list[13] = ratio_eye_avg\n        ratio_list[15] = ratio_eye_avg\n\n    if ratio_list[14]!= None and ratio_list[16]!= None:\n        ratio_eye_avg = (ratio_list[14] + ratio_list[16]) / 2\n        ratio_list[14] = ratio_eye_avg\n        ratio_list[16] = ratio_eye_avg\n\n    return ratio_list, src_length_list, dst_length_list\n\n\n\ndef check_full_body(keypoints, threshold = 0.4):\n\n    body_flag = 'half_body'\n\n    # 1. 
If ankle points exist, confidence is greater than the threshold, and points do not exceed the frame, return full_body\n    if keypoints[10] != None and keypoints[13] != None and keypoints[8] != None and keypoints[11] != None:\n        if (keypoints[10][1] <= 1 and keypoints[13][1] <= 1) and (keypoints[10][2] >= threshold and keypoints[13][2] >= threshold) and \\\n            (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold):\n            body_flag = 'full_body'\n            return body_flag\n\n    # 2. If hip points exist, return three_quarter_body\n    if (keypoints[8] != None and keypoints[11] != None):\n        if (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold):\n            body_flag = 'three_quarter_body'\n            return body_flag\n    \n    return body_flag\n\n\ndef check_full_body_both(flag1, flag2):\n    body_flag_dict = {\n        'full_body': 2,\n        'three_quarter_body' : 1,\n        'half_body': 0\n    }\n\n    body_flag_dict_reverse = {\n        2: 'full_body', \n        1: 'three_quarter_body',\n        0: 'half_body'\n    }\n\n    flag1_num = body_flag_dict[flag1]\n    flag2_num = body_flag_dict[flag2]\n    flag_both_num = min(flag1_num, flag2_num)\n    return body_flag_dict_reverse[flag_both_num]\n\n\ndef write_to_poses(data_to_json, none_idx, dst_shape, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, scale_min):\n    outputs = []\n    length = len(data_to_json)\n    for id in tqdm(range(length)):\n\n        src_height, src_width = data_to_json[id]['height'], data_to_json[id]['width']\n        width, height = dst_shape\n        keypoints = data_to_json[id]['keypoints_body']\n        for idx in range(len(keypoints)):\n            if idx in none_idx:\n                keypoints[idx] = None\n        new_keypoints = keypoints.copy()\n\n        # get hand keypoints\n        
keypoints_hand = {'left' : data_to_json[id]['keypoints_left_hand'], 'right' : data_to_json[id]['keypoints_right_hand']}\n        # Normalize hand coordinates to 0-1 range\n        for hand_idx in range(len(data_to_json[id]['keypoints_left_hand'])):\n            data_to_json[id]['keypoints_left_hand'][hand_idx][0] = data_to_json[id]['keypoints_left_hand'][hand_idx][0] / src_width\n            data_to_json[id]['keypoints_left_hand'][hand_idx][1] = data_to_json[id]['keypoints_left_hand'][hand_idx][1] / src_height\n\n        for hand_idx in range(len(data_to_json[id]['keypoints_right_hand'])):\n            data_to_json[id]['keypoints_right_hand'][hand_idx][0] = data_to_json[id]['keypoints_right_hand'][hand_idx][0] / src_width\n            data_to_json[id]['keypoints_right_hand'][hand_idx][1] = data_to_json[id]['keypoints_right_hand'][hand_idx][1] / src_height\n\n        \n        frame_info = get_scaled_pose((height, width), (src_height, src_width), new_keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, id, scale_min)\n        outputs.append(frame_info)\n\n    return outputs\n\n\ndef calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag):\n    if scale_ratio_flag:\n\n        headw = max(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0]) - \\\n                    min(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0])\n        headw_edit = max(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0]) - \\\n                    min(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], 
skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0])\n        headw_ratio = headw / headw_edit\n\n        _, _, shoulder = get_length(skeleton, [6,3])\n        _, _, shoulder_edit = get_length(skeleton_edit, [6,3])\n        shoulder_ratio = shoulder / shoulder_edit\n\n        return max(headw_ratio, shoulder_ratio)\n    \n    else:\n        return 1\n\n\n\ndef retarget_pose(src_skeleton, dst_skeleton, all_src_skeleton, src_skeleton_edit, dst_skeleton_edit, threshold=0.4):\n\n    if src_skeleton_edit is not None and dst_skeleton_edit is not None:\n        use_edit_for_base = True\n    else:\n        use_edit_for_base = False\n\n    src_skeleton_ori = copy.deepcopy(src_skeleton)\n\n    dst_skeleton_ori_h, dst_skeleton_ori_w = dst_skeleton['height'], dst_skeleton['width']\n    if src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][10] != None and src_skeleton['keypoints_body'][13] != None and \\\n        dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][10] != None and dst_skeleton['keypoints_body'][13] != None and \\\n            src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][10][2] > 0.5 and src_skeleton['keypoints_body'][13][2] > 0.5 and \\\n        dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][10][2] > 0.5 and dst_skeleton['keypoints_body'][13][2] > 0.5:\n\n        src_height = src_skeleton['height'] * abs(\n            (src_skeleton['keypoints_body'][10][1] + src_skeleton['keypoints_body'][13][1]) / 2 -\n            src_skeleton['keypoints_body'][0][1])\n        dst_height = dst_skeleton['height'] * abs(\n            (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][13][1]) / 2 -\n            dst_skeleton['keypoints_body'][0][1])\n        scale_min = 1.0 * src_height / dst_height\n    elif src_skeleton['keypoints_body'][0] != None and 
src_skeleton['keypoints_body'][8] != None and src_skeleton['keypoints_body'][11] != None and \\\n        dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][8] != None and dst_skeleton['keypoints_body'][11] != None and \\\n            src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][8][2] > 0.5 and src_skeleton['keypoints_body'][11][2] > 0.5 and \\\n        dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][8][2] > 0.5 and dst_skeleton['keypoints_body'][11][2] > 0.5:\n\n        src_height = src_skeleton['height'] * abs(\n            (src_skeleton['keypoints_body'][8][1] + src_skeleton['keypoints_body'][11][1]) / 2 -\n            src_skeleton['keypoints_body'][0][1])\n        dst_height = dst_skeleton['height'] * abs(\n            (dst_skeleton['keypoints_body'][8][1] + dst_skeleton['keypoints_body'][11][1]) / 2 -\n            dst_skeleton['keypoints_body'][0][1])\n        scale_min = 1.0 * src_height / dst_height\n    else:\n        scale_min = np.sqrt(src_skeleton['height'] * src_skeleton['width']) / np.sqrt(dst_skeleton['height'] * dst_skeleton['width'])\n    \n    if use_edit_for_base:\n        scale_ratio_flag = False\n        if src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][10] != None and src_skeleton_edit['keypoints_body'][13] != None and \\\n            dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][10] != None and dst_skeleton_edit['keypoints_body'][13] != None and \\\n                src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][10][2] > 0.5 and src_skeleton_edit['keypoints_body'][13][2] > 0.5 and \\\n            dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][10][2] > 0.5 and dst_skeleton_edit['keypoints_body'][13][2] > 0.5:\n\n            src_height_edit = src_skeleton_edit['height'] * abs(\n                
(src_skeleton_edit['keypoints_body'][10][1] + src_skeleton_edit['keypoints_body'][13][1]) / 2 -\n                src_skeleton_edit['keypoints_body'][0][1])\n            dst_height_edit = dst_skeleton_edit['height'] * abs(\n                (dst_skeleton_edit['keypoints_body'][10][1] + dst_skeleton_edit['keypoints_body'][13][1]) / 2 -\n                dst_skeleton_edit['keypoints_body'][0][1])\n            scale_min_edit = 1.0 * src_height_edit / dst_height_edit\n        elif src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][8] != None and src_skeleton_edit['keypoints_body'][11] != None and \\\n            dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][8] != None and dst_skeleton_edit['keypoints_body'][11] != None and \\\n                src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][8][2] > 0.5 and src_skeleton_edit['keypoints_body'][11][2] > 0.5 and \\\n            dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][8][2] > 0.5 and dst_skeleton_edit['keypoints_body'][11][2] > 0.5:\n\n            src_height_edit = src_skeleton_edit['height'] * abs(\n                (src_skeleton_edit['keypoints_body'][8][1] + src_skeleton_edit['keypoints_body'][11][1]) / 2 -\n                src_skeleton_edit['keypoints_body'][0][1])\n            dst_height_edit = dst_skeleton_edit['height'] * abs(\n                (dst_skeleton_edit['keypoints_body'][8][1] + dst_skeleton_edit['keypoints_body'][11][1]) / 2 -\n                dst_skeleton_edit['keypoints_body'][0][1])\n            scale_min_edit = 1.0 * src_height_edit / dst_height_edit\n        else:\n            scale_min_edit = np.sqrt(src_skeleton_edit['height'] * src_skeleton_edit['width']) / np.sqrt(dst_skeleton_edit['height'] * dst_skeleton_edit['width'])\n            scale_ratio_flag = True\n        \n        # Flux may change the scale, compensate for it here\n        
ratio_src = calculate_scale_ratio(src_skeleton, src_skeleton_edit, scale_ratio_flag)\n        ratio_dst = calculate_scale_ratio(dst_skeleton, dst_skeleton_edit, scale_ratio_flag)\n\n        dst_skeleton_edit['height'] = int(dst_skeleton_edit['height'] * scale_min_edit)\n        dst_skeleton_edit['width'] = int(dst_skeleton_edit['width'] * scale_min_edit)\n        for idx in range(len(dst_skeleton_edit['keypoints_left_hand'])):\n            dst_skeleton_edit['keypoints_left_hand'][idx][0] *= scale_min_edit\n            dst_skeleton_edit['keypoints_left_hand'][idx][1] *= scale_min_edit\n        for idx in range(len(dst_skeleton_edit['keypoints_right_hand'])):\n            dst_skeleton_edit['keypoints_right_hand'][idx][0] *= scale_min_edit\n            dst_skeleton_edit['keypoints_right_hand'][idx][1] *= scale_min_edit\n    \n\n    dst_skeleton['height'] = int(dst_skeleton['height'] * scale_min)\n    dst_skeleton['width'] = int(dst_skeleton['width'] * scale_min)\n    for idx in range(len(dst_skeleton['keypoints_left_hand'])):\n        dst_skeleton['keypoints_left_hand'][idx][0] *= scale_min\n        dst_skeleton['keypoints_left_hand'][idx][1] *= scale_min\n    for idx in range(len(dst_skeleton['keypoints_right_hand'])):\n        dst_skeleton['keypoints_right_hand'][idx][0] *= scale_min\n        dst_skeleton['keypoints_right_hand'][idx][1] *= scale_min\n\n\n    dst_body_flag = check_full_body(dst_skeleton['keypoints_body'], threshold)\n    src_body_flag = check_full_body(src_skeleton_ori['keypoints_body'], threshold)\n    body_flag = check_full_body_both(dst_body_flag, src_body_flag)\n    #print('body_flag: ', body_flag)\n\n    if use_edit_for_base:\n        src_skeleton_edit = fix_lack_keypoints_use_sym(src_skeleton_edit)\n        dst_skeleton_edit = fix_lack_keypoints_use_sym(dst_skeleton_edit)\n    else:\n        src_skeleton = fix_lack_keypoints_use_sym(src_skeleton)\n        dst_skeleton = fix_lack_keypoints_use_sym(dst_skeleton)\n\n    none_idx = []\n    for idx 
in range(len(dst_skeleton['keypoints_body'])):\n        if dst_skeleton['keypoints_body'][idx] == None or src_skeleton['keypoints_body'][idx] == None:\n            src_skeleton['keypoints_body'][idx] = None\n            dst_skeleton['keypoints_body'][idx] = None\n            none_idx.append(idx)\n\n    # get bone ratio list\n    ratio_list, src_length_list, dst_length_list = [], [], []\n    for idx, limb in enumerate(limbSeq):\n        if use_edit_for_base:\n            src_X, src_Y, src_length = get_length(src_skeleton_edit, limb)\n            dst_X, dst_Y, dst_length = get_length(dst_skeleton_edit, limb)\n\n            if src_X is None or src_Y is None or dst_X is None or dst_Y is None:\n                ratio = -1\n            else:\n                ratio = 1.0 * dst_length * ratio_dst / src_length / ratio_src\n        \n        else:\n            src_X, src_Y, src_length = get_length(src_skeleton, limb)\n            dst_X, dst_Y, dst_length = get_length(dst_skeleton, limb)\n\n            if src_X is None or src_Y is None or dst_X is None or dst_Y is None:\n                ratio = -1\n            else:\n                ratio = 1.0 * dst_length / src_length\n\n        ratio_list.append(ratio)\n        src_length_list.append(src_length)\n        dst_length_list.append(dst_length)\n    \n    for idx, ratio in enumerate(ratio_list):\n        if ratio == -1:\n            if ratio_list[0] != -1 and ratio_list[1] != -1:\n                ratio_list[idx] = (ratio_list[0] + ratio_list[1]) / 2\n\n    # Consider adding constraints when Flux fails to correct head pose, causing neck issues.\n    # if ratio_list[12] > (ratio_list[0]+ratio_list[1])/2*1.25:\n    #     ratio_list[12] = (ratio_list[0]+ratio_list[1])/2*1.25\n    \n    ratio_list, src_length_list, dst_length_list = rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list)\n\n    rescaled_src_skeleton_ori = rescale_skeleton(src_skeleton_ori['height'], src_skeleton_ori['width'],\n                           
                      src_skeleton_ori['keypoints_body'], ratio_list)\n\n    # get global translation offset_x and offset_y\n    if body_flag == 'full_body':\n        #print('use foot mark.')\n        dst_ground_y = max(dst_skeleton['keypoints_body'][10][1], dst_skeleton['keypoints_body'][13][1]) * dst_skeleton[\n            'height']\n        # The midpoint between toe and ankle\n        if dst_skeleton['keypoints_body'][18] != None and dst_skeleton['keypoints_body'][19] != None:\n            right_foot_mid = (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][19][1]) / 2\n            left_foot_mid = (dst_skeleton['keypoints_body'][13][1] + dst_skeleton['keypoints_body'][18][1]) / 2\n            dst_ground_y = max(left_foot_mid, right_foot_mid) * dst_skeleton['height']\n\n        rescaled_src_ground_y = max(rescaled_src_skeleton_ori[10][1], rescaled_src_skeleton_ori[13][1])\n        delta_ground_y = rescaled_src_ground_y - dst_ground_y\n       \n        dst_ground_x = (dst_skeleton['keypoints_body'][8][0] + dst_skeleton['keypoints_body'][11][0]) * dst_skeleton[\n            'width'] / 2\n        rescaled_src_ground_x = (rescaled_src_skeleton_ori[8][0] + rescaled_src_skeleton_ori[11][0]) / 2\n        delta_ground_x = rescaled_src_ground_x - dst_ground_x\n        delta_x, delta_y = delta_ground_x, delta_ground_y\n\n    else:\n        #print('use neck mark.')\n        # use neck keypoint as mark\n        src_neck_y = rescaled_src_skeleton_ori[1][1]\n        dst_neck_y = dst_skeleton['keypoints_body'][1][1]\n        delta_neck_y = src_neck_y - dst_neck_y * dst_skeleton['height']\n\n        src_neck_x = rescaled_src_skeleton_ori[1][0]\n        dst_neck_x = dst_skeleton['keypoints_body'][1][0]\n        delta_neck_x = src_neck_x - dst_neck_x * dst_skeleton['width']\n        delta_x, delta_y = delta_neck_x, delta_neck_y\n        rescaled_src_ground_x = src_neck_x\n\n\n    dst_shape = (dst_skeleton_ori_w, dst_skeleton_ori_h)\n    output = 
write_to_poses(all_src_skeleton, none_idx, dst_shape, ratio_list, delta_x, delta_y,\n                                rescaled_src_ground_x, body_flag, scale_min)\n    return output\n\n\ndef get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tql_edit_pose_meta0, refer_edit_pose_meta):\n\n    for key, value in tpl_pose_meta0.items():\n        if type(value) is np.ndarray:\n            if key in ['keypoints_left_hand', 'keypoints_right_hand']:\n                value = value * np.array([[tpl_pose_meta0[\"width\"], tpl_pose_meta0[\"height\"], 1.0]])\n            if not isinstance(value, list):\n                value = value.tolist()\n        tpl_pose_meta0[key] = value\n\n    for key, value in refer_pose_meta.items():\n        if type(value) is np.ndarray:\n            if key in ['keypoints_left_hand', 'keypoints_right_hand']:\n                value = value * np.array([[refer_pose_meta[\"width\"], refer_pose_meta[\"height\"], 1.0]])\n            if not isinstance(value, list):\n                value = value.tolist()\n        refer_pose_meta[key] = value\n\n    tpl_pose_metas_new = []\n    for meta in tpl_pose_metas:\n        for key, value in meta.items():\n            if type(value) is np.ndarray:\n                if key in ['keypoints_left_hand', 'keypoints_right_hand']:\n                    value = value * np.array([[meta[\"width\"], meta[\"height\"], 1.0]])\n                if not isinstance(value, list):\n                    value = value.tolist()\n            meta[key] = value\n        tpl_pose_metas_new.append(meta)\n\n    if tql_edit_pose_meta0 is not None:\n        for key, value in tql_edit_pose_meta0.items():\n            if type(value) is np.ndarray:\n                if key in ['keypoints_left_hand', 'keypoints_right_hand']:\n                    value = value * np.array([[tql_edit_pose_meta0[\"width\"], tql_edit_pose_meta0[\"height\"], 1.0]])\n                if not isinstance(value, list):\n                    value = value.tolist()\n        
    tql_edit_pose_meta0[key] = value\n    \n    if refer_edit_pose_meta is not None:\n        for key, value in refer_edit_pose_meta.items():\n            if type(value) is np.ndarray:\n                if key in ['keypoints_left_hand', 'keypoints_right_hand']:\n                    value = value * np.array([[refer_edit_pose_meta[\"width\"], refer_edit_pose_meta[\"height\"], 1.0]])\n                if not isinstance(value, list):\n                    value = value.tolist()\n            refer_edit_pose_meta[key] = value\n\n    retarget_tpl_pose_metas = retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas_new, tql_edit_pose_meta0, refer_edit_pose_meta)\n\n    pose_metas = []\n    for meta in retarget_tpl_pose_metas:\n        pose_meta = AAPoseMeta()\n        width, height = meta[\"width\"], meta[\"height\"]\n        pose_meta.width = width\n        pose_meta.height = height\n        pose_meta.kps_body = np.array(meta[\"keypoints_body\"])[:, :2] * (width, height)\n        pose_meta.kps_body_p = np.array(meta[\"keypoints_body\"])[:, 2]\n\n        kps_lhand = []\n        kps_lhand_p = []\n        for each_kps_lhand in meta[\"keypoints_left_hand\"]:\n            if each_kps_lhand is not None:\n                kps_lhand.append([each_kps_lhand.x, each_kps_lhand.y])\n                kps_lhand_p.append(each_kps_lhand.score)\n            else:\n                kps_lhand.append([None, None])\n                kps_lhand_p.append(0.0)\n\n        pose_meta.kps_lhand = np.array(kps_lhand)\n        pose_meta.kps_lhand_p = np.array(kps_lhand_p)\n\n        kps_rhand = []\n        kps_rhand_p = []\n        for each_kps_rhand in meta[\"keypoints_right_hand\"]:\n            if each_kps_rhand is not None:\n                kps_rhand.append([each_kps_rhand.x, each_kps_rhand.y])\n                kps_rhand_p.append(each_kps_rhand.score)\n            else:\n                kps_rhand.append([None, None])\n                kps_rhand_p.append(0.0)\n\n        pose_meta.kps_rhand = 
np.array(kps_rhand)\n        pose_meta.kps_rhand_p = np.array(kps_rhand_p)\n\n        pose_metas.append(pose_meta)\n\n    return pose_metas\n\n"
  },
  {
    "path": "wan/modules/animate/preprocess/sam_utils.py",
    "content": "# Copyright (c) 2025. Your modifications here.\n# This file wraps and extends sam2.utils.misc for custom modifications.\n\nfrom sam2.utils import misc as sam2_misc\nfrom sam2.utils.misc import * \nfrom PIL import Image\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\nimport os\n\nimport logging\n\nimport torch\nfrom hydra import compose\nfrom hydra.utils import instantiate\nfrom omegaconf import OmegaConf\n\nfrom sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor\nfrom sam2.build_sam import _load_checkpoint\n\n\ndef _load_img_v2_as_tensor(img, image_size):\n    img_pil = Image.fromarray(img.astype(np.uint8))\n    img_np = np.array(img_pil.convert(\"RGB\").resize((image_size, image_size)))\n    if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images\n        img_np = img_np / 255.0\n    else:\n        raise RuntimeError(f\"Unknown image dtype: {img_np.dtype}\")\n    img = torch.from_numpy(img_np).permute(2, 0, 1)\n    video_width, video_height = img_pil.size  # the original video size\n    return img, video_height, video_width\n\ndef load_video_frames(\n    video_path,\n    image_size,\n    offload_video_to_cpu,\n    img_mean=(0.485, 0.456, 0.406),\n    img_std=(0.229, 0.224, 0.225),\n    async_loading_frames=False,\n    frame_names=None,\n):\n    \"\"\"\n    Load the video frames from a directory of JPEG files (\"<frame_index>.jpg\" format).\n\n    The frames are resized to image_size x image_size and are loaded to GPU if\n    `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.\n\n    You can load a frame asynchronously by setting `async_loading_frames` to `True`.\n    \"\"\"\n    if isinstance(video_path, str) and os.path.isdir(video_path):\n        jpg_folder = video_path\n    else:\n        raise NotImplementedError(\"Only JPEG frames are supported at this moment\")\n    if frame_names is None:\n        frame_names = [\n            p\n            for p in os.listdir(jpg_folder)\n  
          if os.path.splitext(p)[-1] in [\".jpg\", \".jpeg\", \".JPG\", \".JPEG\", \".png\"]\n        ]\n        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))\n\n    num_frames = len(frame_names)\n    if num_frames == 0:\n        raise RuntimeError(f\"no images found in {jpg_folder}\")\n    img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]\n    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]\n    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]\n\n    if async_loading_frames:\n        lazy_images = AsyncVideoFrameLoader(\n            img_paths, image_size, offload_video_to_cpu, img_mean, img_std\n        )\n        return lazy_images, lazy_images.video_height, lazy_images.video_width\n\n    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)\n    for n, img_path in enumerate(tqdm(img_paths, desc=\"frame loading (JPEG)\")):\n        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)\n    if not offload_video_to_cpu:\n        images = images.cuda()\n        img_mean = img_mean.cuda()\n        img_std = img_std.cuda()\n    # normalize by mean and std\n    images -= img_mean\n    images /= img_std\n    return images, video_height, video_width\n\n\ndef load_video_frames_v2(\n    frames,\n    image_size,\n    offload_video_to_cpu,\n    img_mean=(0.485, 0.456, 0.406),\n    img_std=(0.229, 0.224, 0.225),\n    async_loading_frames=False,\n    frame_names=None,\n):\n    \"\"\"\n    Load the video frames from an in-memory sequence of frames (NumPy image arrays).\n\n    The frames are resized to image_size x image_size and are loaded to GPU if\n    `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.\n\n    `async_loading_frames` and `frame_names` are accepted for signature compatibility\n    with `load_video_frames` but are not used by this function.\n    \"\"\"\n    num_frames = len(frames)\n    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, 
None]\n    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]\n\n    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)\n    for n, frame in enumerate(tqdm(frames, desc=\"video frame\")):\n        images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size)\n    if not offload_video_to_cpu:\n        images = images.cuda()\n        img_mean = img_mean.cuda()\n        img_std = img_std.cuda()\n    # normalize by mean and std\n    images -= img_mean\n    images /= img_std\n    return images, video_height, video_width\n\ndef build_sam2_video_predictor(\n    config_file,\n    ckpt_path=None,\n    device=\"cuda\",\n    mode=\"eval\",\n    hydra_overrides_extra=[],\n    apply_postprocessing=True,\n):\n    hydra_overrides = [\n        \"++model._target_=video_predictor.SAM2VideoPredictor\",\n    ]\n    if apply_postprocessing:\n        hydra_overrides_extra = hydra_overrides_extra.copy()\n        hydra_overrides_extra += [\n            # dynamically fall back to multi-mask if the single mask is not stable\n            \"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true\",\n            \"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05\",\n            \"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98\",\n            # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking\n            \"++model.binarize_mask_from_pts_for_mem_enc=true\",\n            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)\n            \"++model.fill_hole_area=8\",\n        ]\n\n    hydra_overrides.extend(hydra_overrides_extra)\n    # Read config and init model\n    cfg = compose(config_name=config_file, overrides=hydra_overrides)\n    OmegaConf.resolve(cfg)\n    model = instantiate(cfg.model, 
_recursive_=True)\n    _load_checkpoint(model, ckpt_path)\n    model = model.to(device)\n    if mode == \"eval\":\n        model.eval()\n    return model"
  },
  {
    "path": "wan/modules/animate/preprocess/utils.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport os\nimport cv2\nimport math\nimport random\nimport numpy as np\n\ndef get_mask_boxes(mask):\n    \"\"\"\n\n    Args:\n        mask: [h, w]\n    Returns:\n\n    \"\"\"\n    y_coords, x_coords = np.nonzero(mask)\n    x_min = x_coords.min()\n    x_max = x_coords.max()\n    y_min = y_coords.min()\n    y_max = y_coords.max()\n    bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32)\n    return bbox\n\n\ndef get_aug_mask(body_mask, w_len=10, h_len=20):\n    body_bbox = get_mask_boxes(body_mask)\n\n    bbox_wh = body_bbox[2:4] - body_bbox[0:2]\n    w_slice = np.int32(bbox_wh[0] / w_len)\n    h_slice = np.int32(bbox_wh[1] / h_len)\n\n    for each_w in range(body_bbox[0], body_bbox[2], w_slice):\n        w_start = min(each_w, body_bbox[2])\n        w_end = min((each_w + w_slice), body_bbox[2])\n        # print(w_start, w_end)\n        for each_h in range(body_bbox[1], body_bbox[3], h_slice):\n            h_start = min(each_h, body_bbox[3])\n            h_end = min((each_h + h_slice), body_bbox[3])\n            if body_mask[h_start:h_end, w_start:w_end].sum() > 0:\n                body_mask[h_start:h_end, w_start:w_end] = 1\n\n    return body_mask\n    \ndef get_mask_body_img(img_copy, hand_mask, k=7, iterations=1):\n    kernel = np.ones((k, k), np.uint8)\n    dilation = cv2.dilate(hand_mask, kernel, iterations=iterations)\n    mask_hand_img = img_copy * (1 - dilation[:, :, None])\n\n    return mask_hand_img, dilation\n\n\ndef get_face_bboxes(kp2ds, scale, image_shape, ratio_aug):\n    h, w = image_shape\n    kp2ds_face = kp2ds.copy()[23:91, :2]\n\n    min_x, min_y = np.min(kp2ds_face, axis=0)\n    max_x, max_y = np.max(kp2ds_face, axis=0)\n\n\n    initial_width = max_x - min_x\n    initial_height = max_y - min_y\n\n    initial_area = initial_width * initial_height\n\n    expanded_area = initial_area * scale\n\n    new_width = np.sqrt(expanded_area * (initial_width / 
initial_height))\n    new_height = np.sqrt(expanded_area * (initial_height / initial_width))\n\n    delta_width = (new_width - initial_width) / 2\n    delta_height = (new_height - initial_height) / 4\n\n    if ratio_aug:\n        if random.random() > 0.5:\n            delta_width += random.uniform(0, initial_width // 10)\n        else:\n            delta_height += random.uniform(0, initial_height // 10)\n\n    expanded_min_x = max(min_x - delta_width, 0)\n    expanded_max_x = min(max_x + delta_width, w)\n    expanded_min_y = max(min_y - 3 * delta_height, 0)\n    expanded_max_y = min(max_y + delta_height, h)\n\n    return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]\n\n\ndef calculate_new_size(orig_w, orig_h, target_area, divisor=64):\n\n    target_ratio = orig_w / orig_h\n\n    def check_valid(w, h):\n\n        if w <= 0 or h <= 0:\n            return False\n        return (w * h <= target_area and  \n                w % divisor == 0 and  \n                h % divisor == 0)  \n\n    def get_ratio_diff(w, h):\n\n        return abs(w / h - target_ratio)\n\n    def round_to_64(value, round_up=False, divisor=64):\n\n        if round_up:\n            return divisor * ((value + (divisor - 1)) // divisor)\n        return divisor * (value // divisor)\n\n    possible_sizes = []\n\n    max_area_h = int(np.sqrt(target_area / target_ratio))\n    max_area_w = int(max_area_h * target_ratio)\n\n    max_h = round_to_64(max_area_h, round_up=True, divisor=divisor)\n    max_w = round_to_64(max_area_w, round_up=True, divisor=divisor)\n\n    for h in range(divisor, max_h + divisor, divisor):\n        ideal_w = h * target_ratio\n\n        # round the candidate widths with the caller's divisor, not the default of 64\n        w_down = round_to_64(ideal_w, divisor=divisor)\n        w_up = round_to_64(ideal_w, round_up=True, divisor=divisor)\n\n        for w in [w_down, w_up]:\n            # check_valid takes (w, h) and reads divisor from the enclosing scope\n            if check_valid(w, h):\n                possible_sizes.append((w, h, get_ratio_diff(w, h)))\n\n    if not possible_sizes:\n        raise ValueError(\"Can not find suitable 
size\")\n\n    possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2]))\n\n    best_w, best_h, _ = possible_sizes[0]\n    return int(best_w), int(best_h)\n\n\ndef resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0)):\n    h, w = image.shape[:2]\n    try:\n        new_w, new_h = calculate_new_size(w, h, target_area, divisor)\n    except Exception:\n        # fall back to sizing directly from target_area\n        aspect_ratio = w / h\n\n        if keep_aspect_ratio:\n            new_h = math.sqrt(target_area / aspect_ratio)\n            new_w = target_area / new_h\n        else:\n            new_w = new_h = math.sqrt(target_area)\n\n        new_w, new_h = int((new_w // divisor) * divisor), int((new_h // divisor) * divisor)\n\n    interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR\n\n    resized_image = padding_resize(image, height=new_h, width=new_w, padding_color=padding_color,\n                                    interpolation=interpolation)\n    return resized_image\n\n\ndef padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):\n    ori_height = img_ori.shape[0]\n    ori_width = img_ori.shape[1]\n    channel = img_ori.shape[2]\n\n    img_pad = np.zeros((height, width, channel))\n    if channel == 1:\n        img_pad[:, :, 0] = padding_color[0]\n    else:\n        img_pad[:, :, 0] = padding_color[0]\n        img_pad[:, :, 1] = padding_color[1]\n        img_pad[:, :, 2] = padding_color[2]\n\n    if (ori_height / ori_width) > (height / width):\n        new_width = int(height / ori_height * ori_width)\n        img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)\n        padding = int((width - new_width) / 2)\n        if len(img.shape) == 2:\n            img = img[:, :, np.newaxis]  \n        img_pad[:, padding: padding + new_width, :] = img\n    else:\n        new_height = int(width / ori_width * ori_height)\n        img = cv2.resize(img_ori, (width, new_height), 
interpolation=interpolation)\n        padding = int((height - new_height) / 2)\n        if len(img.shape) == 2:\n            img = img[:, :, np.newaxis]  \n        img_pad[padding: padding + new_height, :, :] = img\n\n    img_pad = np.uint8(img_pad)\n\n    return img_pad\n\n\ndef get_frame_indices(frame_num, video_fps, clip_length, train_fps):\n\n    start_frame = 0\n    times = np.arange(0, clip_length) / train_fps\n    frame_indices = start_frame + np.round(times * video_fps).astype(int)\n    frame_indices = np.clip(frame_indices, 0, frame_num - 1)\n\n    return frame_indices.tolist()\n\n\ndef get_face_bboxes(kp2ds, scale, image_shape):\n    h, w = image_shape\n    kp2ds_face = kp2ds.copy()[1:] * (w, h)\n\n    min_x, min_y = np.min(kp2ds_face, axis=0)\n    max_x, max_y = np.max(kp2ds_face, axis=0)\n\n    initial_width = max_x - min_x\n    initial_height = max_y - min_y\n\n    initial_area = initial_width * initial_height\n\n    expanded_area = initial_area * scale\n\n    new_width = np.sqrt(expanded_area * (initial_width / initial_height))\n    new_height = np.sqrt(expanded_area * (initial_height / initial_width))\n\n    delta_width = (new_width - initial_width) / 2\n    delta_height = (new_height - initial_height) / 4\n\n    expanded_min_x = max(min_x - delta_width, 0)\n    expanded_max_x = min(max_x + delta_width, w)\n    expanded_min_y = max(min_y - 3 * delta_height, 0)\n    expanded_max_y = min(max_y + delta_height, h)\n\n    return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)]"
  },
  {
    "path": "wan/modules/animate/preprocess/video_predictor.py",
    "content": "# Copyright (c) 2025. Your modifications here.\n# A wrapper for sam2 functions\nfrom collections import OrderedDict\nimport torch\nfrom tqdm import tqdm\n\nfrom sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base\nfrom sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor\nfrom sam2.utils.misc import concat_points, fill_holes_in_mask_scores\n\nfrom sam_utils import load_video_frames_v2, load_video_frames\n\n\nclass SAM2VideoPredictor(_SAM2VideoPredictor):\n    def __init__(self, *args, **kwargs):\n\n        super().__init__(*args, **kwargs)\n        \n    @torch.inference_mode()\n    def init_state(\n        self,\n        video_path,\n        offload_video_to_cpu=False,\n        offload_state_to_cpu=False,\n        async_loading_frames=False,\n        frame_names=None\n    ):\n        \"\"\"Initialize a inference state.\"\"\"\n        images, video_height, video_width = load_video_frames(\n            video_path=video_path,\n            image_size=self.image_size,\n            offload_video_to_cpu=offload_video_to_cpu,\n            async_loading_frames=async_loading_frames,\n            frame_names=frame_names\n        )\n        inference_state = {}\n        inference_state[\"images\"] = images\n        inference_state[\"num_frames\"] = len(images)\n        # whether to offload the video frames to CPU memory\n        # turning on this option saves the GPU memory with only a very small overhead\n        inference_state[\"offload_video_to_cpu\"] = offload_video_to_cpu\n        # whether to offload the inference state to CPU memory\n        # turning on this option saves the GPU memory at the cost of a lower tracking fps\n        # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object\n        # and from 24 to 21 when tracking two objects)\n        inference_state[\"offload_state_to_cpu\"] = offload_state_to_cpu\n        # the original video height and width, used for resizing final output scores\n        inference_state[\"video_height\"] = video_height\n        inference_state[\"video_width\"] = video_width\n        inference_state[\"device\"] = torch.device(\"cuda\")\n        if offload_state_to_cpu:\n            inference_state[\"storage_device\"] = torch.device(\"cpu\")\n        else:\n            inference_state[\"storage_device\"] = torch.device(\"cuda\")\n        # inputs on each frame\n        inference_state[\"point_inputs_per_obj\"] = {}\n        inference_state[\"mask_inputs_per_obj\"] = {}\n        # visual features on a small number of recently visited frames for quick interactions\n        inference_state[\"cached_features\"] = {}\n        # values that don't change across frames (so we only need to hold one copy of them)\n        inference_state[\"constants\"] = {}\n        # mapping between client-side object id and model-side object index\n        inference_state[\"obj_id_to_idx\"] = OrderedDict()\n        inference_state[\"obj_idx_to_id\"] = OrderedDict()\n        inference_state[\"obj_ids\"] = []\n        # A storage to hold the model's tracking results and states on each frame\n        inference_state[\"output_dict\"] = {\n            \"cond_frame_outputs\": {},  # dict containing {frame_idx: <out>}\n            \"non_cond_frame_outputs\": {},  # dict containing {frame_idx: <out>}\n        }\n        # Slice (view) of each object tracking results, sharing the same memory with \"output_dict\"\n        inference_state[\"output_dict_per_obj\"] = {}\n        # A temporary storage to hold new outputs when user interact with a frame\n        # to add clicks or mask (it's merged into \"output_dict\" before propagation starts)\n        
inference_state[\"temp_output_dict_per_obj\"] = {}\n        # Frames that already hold consolidated outputs from click or mask inputs\n        # (we directly use their consolidated outputs during tracking)\n        inference_state[\"consolidated_frame_inds\"] = {\n            \"cond_frame_outputs\": set(),  # set containing frame indices\n            \"non_cond_frame_outputs\": set(),  # set containing frame indices\n        }\n        # metadata for each tracking frame (e.g. which direction it's tracked)\n        inference_state[\"tracking_has_started\"] = False\n        inference_state[\"frames_already_tracked\"] = {}\n        # per-object tracking metadata read by newer SAM-2 versions (mirrors init_state_v2)\n        inference_state[\"frames_tracked_per_obj\"] = {}\n        # Warm up the visual backbone and cache the image feature on frame 0\n        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)\n        return inference_state\n\n    @torch.inference_mode()\n    def init_state_v2(\n            self,\n            frames,\n            offload_video_to_cpu=False,\n            offload_state_to_cpu=False,\n            async_loading_frames=False,\n            frame_names=None\n    ):\n        \"\"\"Initialize an inference state.\"\"\"\n        images, video_height, video_width = load_video_frames_v2(\n            frames=frames,\n            image_size=self.image_size,\n            offload_video_to_cpu=offload_video_to_cpu,\n            async_loading_frames=async_loading_frames,\n            frame_names=frame_names\n        )\n        inference_state = {}\n        inference_state[\"images\"] = images\n        inference_state[\"num_frames\"] = len(images)\n        # whether to offload the video frames to CPU memory\n        # turning on this option saves the GPU memory with only a very small overhead\n        inference_state[\"offload_video_to_cpu\"] = offload_video_to_cpu\n        # whether to offload the inference state to CPU memory\n        # turning on this option saves the GPU memory at the cost of a lower tracking fps\n        # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object\n        # and from 24 to 21 when tracking two objects)\n        inference_state[\"offload_state_to_cpu\"] = offload_state_to_cpu\n        # the original video height and width, used for resizing final output scores\n        inference_state[\"video_height\"] = video_height\n        inference_state[\"video_width\"] = video_width\n        inference_state[\"device\"] = torch.device(\"cuda\")\n        if offload_state_to_cpu:\n            inference_state[\"storage_device\"] = torch.device(\"cpu\")\n        else:\n            inference_state[\"storage_device\"] = torch.device(\"cuda\")\n        # inputs on each frame\n        inference_state[\"point_inputs_per_obj\"] = {}\n        inference_state[\"mask_inputs_per_obj\"] = {}\n        # visual features on a small number of recently visited frames for quick interactions\n        inference_state[\"cached_features\"] = {}\n        # values that don't change across frames (so we only need to hold one copy of them)\n        inference_state[\"constants\"] = {}\n        # mapping between client-side object id and model-side object index\n        inference_state[\"obj_id_to_idx\"] = OrderedDict()\n        inference_state[\"obj_idx_to_id\"] = OrderedDict()\n        inference_state[\"obj_ids\"] = []\n        # A storage to hold the model's tracking results and states on each frame\n        inference_state[\"output_dict\"] = {\n            \"cond_frame_outputs\": {},  # dict containing {frame_idx: <out>}\n            \"non_cond_frame_outputs\": {},  # dict containing {frame_idx: <out>}\n        }\n        # Slice (view) of each object tracking results, sharing the same memory with \"output_dict\"\n        inference_state[\"output_dict_per_obj\"] = {}\n        # A temporary storage to hold new outputs when user interact with a frame\n        # to add clicks or mask (it's merged into \"output_dict\" before propagation starts)\n        
inference_state[\"temp_output_dict_per_obj\"] = {}\n        # Frames that already hold consolidated outputs from click or mask inputs\n        # (we directly use their consolidated outputs during tracking)\n        inference_state[\"consolidated_frame_inds\"] = {\n            \"cond_frame_outputs\": set(),  # set containing frame indices\n            \"non_cond_frame_outputs\": set(),  # set containing frame indices\n        }\n        # metadata for each tracking frame (e.g. which direction it's tracked)\n        inference_state[\"tracking_has_started\"] = False\n        inference_state[\"frames_already_tracked\"] = {}\n\n        # Newer SAM-2 versions read 'frames_tracked_per_obj'; initializing it here\n        # avoids a KeyError when running preprocessing in 'replacement mode'\n        inference_state[\"frames_tracked_per_obj\"] = {}\n\n        # Warm up the visual backbone and cache the image feature on frame 0\n        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)\n        return inference_state"
  },
  {
    "path": "wan/modules/animate/xlm_roberta.py",
    "content": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n__all__ = ['XLMRoberta', 'xlm_roberta_large']\n\n\nclass SelfAttention(nn.Module):\n\n    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.eps = eps\n\n        # layers\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(dim, dim)\n        self.v = nn.Linear(dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x, mask):\n        \"\"\"\n        x:   [B, L, C].\n        \"\"\"\n        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)\n        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)\n        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)\n\n        # compute attention\n        p = self.dropout.p if self.training else 0.0\n        x = F.scaled_dot_product_attention(q, k, v, mask, p)\n        x = x.permute(0, 2, 1, 3).reshape(b, s, c)\n\n        # output\n        x = self.o(x)\n        x = self.dropout(x)\n        return x\n\n\nclass AttentionBlock(nn.Module):\n\n    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.post_norm = post_norm\n        self.eps = eps\n\n        # layers\n        self.attn = SelfAttention(dim, num_heads, dropout, eps)\n        self.norm1 = nn.LayerNorm(dim, eps=eps)\n        self.ffn = nn.Sequential(\n            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),\n            
nn.Dropout(dropout))\n        self.norm2 = nn.LayerNorm(dim, eps=eps)\n\n    def forward(self, x, mask):\n        if self.post_norm:\n            x = self.norm1(x + self.attn(x, mask))\n            x = self.norm2(x + self.ffn(x))\n        else:\n            x = x + self.attn(self.norm1(x), mask)\n            x = x + self.ffn(self.norm2(x))\n        return x\n\n\nclass XLMRoberta(nn.Module):\n    \"\"\"\n    XLMRobertaModel with no pooler and no LM head.\n    \"\"\"\n\n    def __init__(self,\n                 vocab_size=250002,\n                 max_seq_len=514,\n                 type_size=1,\n                 pad_id=1,\n                 dim=1024,\n                 num_heads=16,\n                 num_layers=24,\n                 post_norm=True,\n                 dropout=0.1,\n                 eps=1e-5):\n        super().__init__()\n        self.vocab_size = vocab_size\n        self.max_seq_len = max_seq_len\n        self.type_size = type_size\n        self.pad_id = pad_id\n        self.dim = dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.post_norm = post_norm\n        self.eps = eps\n\n        # embeddings\n        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)\n        self.type_embedding = nn.Embedding(type_size, dim)\n        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)\n        self.dropout = nn.Dropout(dropout)\n\n        # blocks\n        self.blocks = nn.ModuleList([\n            AttentionBlock(dim, num_heads, post_norm, dropout, eps)\n            for _ in range(num_layers)\n        ])\n\n        # norm layer\n        self.norm = nn.LayerNorm(dim, eps=eps)\n\n    def forward(self, ids):\n        \"\"\"\n        ids: [B, L] of torch.LongTensor.\n        \"\"\"\n        b, s = ids.shape\n        mask = ids.ne(self.pad_id).long()\n\n        # embeddings\n        x = self.token_embedding(ids) + \\\n            self.type_embedding(torch.zeros_like(ids)) + \\\n         
   self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)\n        if self.post_norm:\n            x = self.norm(x)\n        x = self.dropout(x)\n\n        # blocks\n        mask = torch.where(\n            mask.view(b, 1, 1, s).gt(0), 0.0,\n            torch.finfo(x.dtype).min)\n        for block in self.blocks:\n            x = block(x, mask)\n\n        # output\n        if not self.post_norm:\n            x = self.norm(x)\n        return x\n\n\ndef xlm_roberta_large(pretrained=False,\n                      return_tokenizer=False,\n                      device='cpu',\n                      **kwargs):\n    \"\"\"\n    XLMRobertaLarge adapted from Huggingface.\n    \"\"\"\n    # params\n    cfg = dict(\n        vocab_size=250002,\n        max_seq_len=514,\n        type_size=1,\n        pad_id=1,\n        dim=1024,\n        num_heads=16,\n        num_layers=24,\n        post_norm=True,\n        dropout=0.1,\n        eps=1e-5)\n    cfg.update(**kwargs)\n\n    # init a model on device\n    with torch.device(device):\n        model = XLMRoberta(**cfg)\n    return model"
  },
  {
    "path": "wan/modules/attention.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n    import flash_attn_interface\n    FLASH_ATTN_3_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_3_AVAILABLE = False\n\ntry:\n    import flash_attn\n    FLASH_ATTN_2_AVAILABLE = True\nexcept ModuleNotFoundError:\n    FLASH_ATTN_2_AVAILABLE = False\n\nimport warnings\n\n__all__ = [\n    'flash_attention',\n    'attention',\n]\n\n\ndef flash_attention(\n    q,\n    k,\n    v,\n    q_lens=None,\n    k_lens=None,\n    dropout_p=0.,\n    softmax_scale=None,\n    q_scale=None,\n    causal=False,\n    window_size=(-1, -1),\n    deterministic=False,\n    dtype=torch.bfloat16,\n    version=None,\n):\n    \"\"\"\n    q:              [B, Lq, Nq, C1].\n    k:              [B, Lk, Nk, C1].\n    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.\n    q_lens:         [B].\n    k_lens:         [B].\n    dropout_p:      float. Dropout probability.\n    softmax_scale:  float. The scaling of QK^T before applying softmax.\n    causal:         bool. Whether to apply causal attention mask.\n    window_size:    (left right). If not (-1, -1), apply sliding window local attention.\n    deterministic:  bool. If True, slightly slower and uses more memory.\n    dtype:          torch.dtype. 
Apply when dtype of q/k/v is not float16/bfloat16.\n    \"\"\"\n    half_dtypes = (torch.float16, torch.bfloat16)\n    assert dtype in half_dtypes\n    assert q.device.type == 'cuda' and q.size(-1) <= 256\n\n    # params\n    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype\n\n    def half(x):\n        return x if x.dtype in half_dtypes else x.to(dtype)\n\n    # preprocess query\n    if q_lens is None:\n        q = half(q.flatten(0, 1))\n        q_lens = torch.tensor(\n            [lq] * b, dtype=torch.int32).to(\n                device=q.device, non_blocking=True)\n    else:\n        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))\n\n    # preprocess key, value\n    if k_lens is None:\n        k = half(k.flatten(0, 1))\n        v = half(v.flatten(0, 1))\n        k_lens = torch.tensor(\n            [lk] * b, dtype=torch.int32).to(\n                device=k.device, non_blocking=True)\n    else:\n        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))\n        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))\n\n    q = q.to(v.dtype)\n    k = k.to(v.dtype)\n\n    if q_scale is not None:\n        q = q * q_scale\n\n    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:\n        warnings.warn(\n            'Flash attention 3 is not available, use flash attention 2 instead.'\n        )\n\n    # apply attention\n    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:\n        # Note: dropout_p, window_size are not supported in FA3 now.\n        x = flash_attn_interface.flash_attn_varlen_func(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(\n                0, dtype=torch.int32).to(q.device, non_blocking=True),\n            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(\n                0, dtype=torch.int32).to(q.device, non_blocking=True),\n            seqused_q=None,\n            seqused_k=None,\n      
      max_seqlen_q=lq,\n            max_seqlen_k=lk,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            deterministic=deterministic)[0].unflatten(0, (b, lq))\n    else:\n        assert FLASH_ATTN_2_AVAILABLE\n        x = flash_attn.flash_attn_varlen_func(\n            q=q,\n            k=k,\n            v=v,\n            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(\n                0, dtype=torch.int32).to(q.device, non_blocking=True),\n            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(\n                0, dtype=torch.int32).to(q.device, non_blocking=True),\n            max_seqlen_q=lq,\n            max_seqlen_k=lk,\n            dropout_p=dropout_p,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            window_size=window_size,\n            deterministic=deterministic).unflatten(0, (b, lq))\n\n    # output\n    return x.type(out_dtype)\n\n\ndef attention(\n    q,\n    k,\n    v,\n    q_lens=None,\n    k_lens=None,\n    dropout_p=0.,\n    softmax_scale=None,\n    q_scale=None,\n    causal=False,\n    window_size=(-1, -1),\n    deterministic=False,\n    dtype=torch.bfloat16,\n    fa_version=None,\n):\n    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:\n        return flash_attention(\n            q=q,\n            k=k,\n            v=v,\n            q_lens=q_lens,\n            k_lens=k_lens,\n            dropout_p=dropout_p,\n            softmax_scale=softmax_scale,\n            q_scale=q_scale,\n            causal=causal,\n            window_size=window_size,\n            deterministic=deterministic,\n            dtype=dtype,\n            version=fa_version,\n        )\n    else:\n        if q_lens is not None or k_lens is not None:\n            warnings.warn(\n                'Padding mask is disabled when using scaled_dot_product_attention. 
It can have a significant impact on performance.'\n            )\n        attn_mask = None\n\n        q = q.transpose(1, 2).to(dtype)\n        k = k.transpose(1, 2).to(dtype)\n        v = v.transpose(1, 2).to(dtype)\n\n        out = torch.nn.functional.scaled_dot_product_attention(\n            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)\n\n        out = out.transpose(1, 2).contiguous()\n        return out\n"
  },
  {
    "path": "wan/modules/model.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\n\nfrom .attention import flash_attention\n\n__all__ = ['WanModel']\n\n\ndef sinusoidal_embedding_1d(dim, position):\n    # preprocess\n    assert dim % 2 == 0\n    half = dim // 2\n    position = position.type(torch.float64)\n\n    # calculation\n    sinusoid = torch.outer(\n        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))\n    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)\n    return x\n\n\n@torch.amp.autocast('cuda', enabled=False)\ndef rope_params(max_seq_len, dim, theta=10000):\n    assert dim % 2 == 0\n    freqs = torch.outer(\n        torch.arange(max_seq_len),\n        1.0 / torch.pow(theta,\n                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))\n    freqs = torch.polar(torch.ones_like(freqs), freqs)\n    return freqs\n\n\n@torch.amp.autocast('cuda', enabled=False)\ndef rope_apply(x, grid_sizes, freqs):\n    n, c = x.size(2), x.size(3) // 2\n\n    # split freqs\n    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)\n\n    # loop over samples\n    output = []\n    for i, (f, h, w) in enumerate(grid_sizes.tolist()):\n        seq_len = f * h * w\n\n        # precompute multipliers\n        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(\n            seq_len, n, -1, 2))\n        freqs_i = torch.cat([\n            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),\n            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),\n            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)\n        ],\n                            dim=-1).reshape(seq_len, 1, -1)\n\n        # apply rotary embedding\n        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)\n        x_i = torch.cat([x_i, 
x[i, seq_len:]])\n\n        # append to collection\n        output.append(x_i)\n    return torch.stack(output).float()\n\n\nclass WanRMSNorm(nn.Module):\n\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L, C]\n        \"\"\"\n        return self._norm(x.float()).type_as(x) * self.weight\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n\n\nclass WanLayerNorm(nn.LayerNorm):\n\n    def __init__(self, dim, eps=1e-6, elementwise_affine=False):\n        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)\n\n    def forward(self, x):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L, C]\n        \"\"\"\n        return super().forward(x.float()).type_as(x)\n\n\nclass WanSelfAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 num_heads,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 eps=1e-6):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.eps = eps\n\n        # layers\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(dim, dim)\n        self.v = nn.Linear(dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n\n    def forward(self, x, seq_lens, grid_sizes, freqs):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L, C]; split into [B, L, num_heads, C / num_heads] after the q/k/v projections\n            seq_lens(Tensor): Shape [B]\n            
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)\n            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]\n        \"\"\"\n        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n\n        # query, key, value function\n        def qkv_fn(x):\n            q = self.norm_q(self.q(x)).view(b, s, n, d)\n            k = self.norm_k(self.k(x)).view(b, s, n, d)\n            v = self.v(x).view(b, s, n, d)\n            return q, k, v\n\n        q, k, v = qkv_fn(x)\n\n        x = flash_attention(\n            q=rope_apply(q, grid_sizes, freqs),\n            k=rope_apply(k, grid_sizes, freqs),\n            v=v,\n            k_lens=seq_lens,\n            window_size=self.window_size)\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\nclass WanCrossAttention(WanSelfAttention):\n\n    def forward(self, x, context, context_lens):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L1, C]\n            context(Tensor): Shape [B, L2, C]\n            context_lens(Tensor): Shape [B]\n        \"\"\"\n        b, n, d = x.size(0), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.norm_q(self.q(x)).view(b, -1, n, d)\n        k = self.norm_k(self.k(context)).view(b, -1, n, d)\n        v = self.v(context).view(b, -1, n, d)\n\n        # compute attention\n        x = flash_attention(q, k, v, k_lens=context_lens)\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\nclass WanAttentionBlock(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 ffn_dim,\n                 num_heads,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 cross_attn_norm=False,\n                 eps=1e-6):\n        super().__init__()\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        
self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        self.eps = eps\n\n        # layers\n        self.norm1 = WanLayerNorm(dim, eps)\n        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,\n                                          eps)\n        self.norm3 = WanLayerNorm(\n            dim, eps,\n            elementwise_affine=True) if cross_attn_norm else nn.Identity()\n        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,\n                                            eps)\n        self.norm2 = WanLayerNorm(dim, eps)\n        self.ffn = nn.Sequential(\n            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),\n            nn.Linear(ffn_dim, dim))\n\n        # modulation\n        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)\n\n    def forward(\n        self,\n        x,\n        e,\n        seq_lens,\n        grid_sizes,\n        freqs,\n        context,\n        context_lens,\n    ):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L, C]\n            e(Tensor): Shape [B, L1, 6, C]\n            seq_lens(Tensor): Shape [B], length of each sequence in batch\n            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)\n            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]\n        \"\"\"\n        assert e.dtype == torch.float32\n        with torch.amp.autocast('cuda', dtype=torch.float32):\n            e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)\n        assert e[0].dtype == torch.float32\n\n        # self-attention\n        y = self.self_attn(\n            self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),\n            seq_lens, grid_sizes, freqs)\n        with torch.amp.autocast('cuda', dtype=torch.float32):\n            x = x + y * e[2].squeeze(2)\n\n        # cross-attention & ffn function\n        def cross_attn_ffn(x, context, context_lens, e):\n            x = x + 
self.cross_attn(self.norm3(x), context, context_lens)\n            y = self.ffn(\n                self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))\n            with torch.amp.autocast('cuda', dtype=torch.float32):\n                x = x + y * e[5].squeeze(2)\n            return x\n\n        x = cross_attn_ffn(x, context, context_lens, e)\n        return x\n\n\nclass Head(nn.Module):\n\n    def __init__(self, dim, out_dim, patch_size, eps=1e-6):\n        super().__init__()\n        self.dim = dim\n        self.out_dim = out_dim\n        self.patch_size = patch_size\n        self.eps = eps\n\n        # layers\n        out_dim = math.prod(patch_size) * out_dim\n        self.norm = WanLayerNorm(dim, eps)\n        self.head = nn.Linear(dim, out_dim)\n\n        # modulation\n        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)\n\n    def forward(self, x, e):\n        r\"\"\"\n        Args:\n            x(Tensor): Shape [B, L1, C]\n            e(Tensor): Shape [B, L1, C]\n        \"\"\"\n        assert e.dtype == torch.float32\n        with torch.amp.autocast('cuda', dtype=torch.float32):\n            e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)\n            x = (\n                self.head(\n                    self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))\n        return x\n\n\nclass WanModel(ModelMixin, ConfigMixin):\n    r\"\"\"\n    Wan diffusion backbone supporting both text-to-video and image-to-video.\n    \"\"\"\n\n    ignore_for_config = [\n        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'\n    ]\n    _no_split_modules = ['WanAttentionBlock']\n\n    @register_to_config\n    def __init__(self,\n                 model_type='t2v',\n                 patch_size=(1, 2, 2),\n                 text_len=512,\n                 in_dim=16,\n                 dim=2048,\n                 ffn_dim=8192,\n                 freq_dim=256,\n                 text_dim=4096,\n               
  out_dim=16,\n                 num_heads=16,\n                 num_layers=32,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 cross_attn_norm=True,\n                 eps=1e-6):\n        r\"\"\"\n        Initialize the diffusion model backbone.\n\n        Args:\n            model_type (`str`, *optional*, defaults to 't2v'):\n                Model variant - one of 't2v' (text-to-video), 'i2v' (image-to-video), 'ti2v' or 's2v'\n            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):\n                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)\n            text_len (`int`, *optional*, defaults to 512):\n                Fixed length for text embeddings\n            in_dim (`int`, *optional*, defaults to 16):\n                Input video channels (C_in)\n            dim (`int`, *optional*, defaults to 2048):\n                Hidden dimension of the transformer\n            ffn_dim (`int`, *optional*, defaults to 8192):\n                Intermediate dimension in feed-forward network\n            freq_dim (`int`, *optional*, defaults to 256):\n                Dimension for sinusoidal time embeddings\n            text_dim (`int`, *optional*, defaults to 4096):\n                Input dimension for text embeddings\n            out_dim (`int`, *optional*, defaults to 16):\n                Output video channels (C_out)\n            num_heads (`int`, *optional*, defaults to 16):\n                Number of attention heads\n            num_layers (`int`, *optional*, defaults to 32):\n                Number of transformer blocks\n            window_size (`tuple`, *optional*, defaults to (-1, -1)):\n                Window size for local attention (-1 indicates global attention)\n            qk_norm (`bool`, *optional*, defaults to True):\n                Enable query/key normalization\n            cross_attn_norm (`bool`, *optional*, defaults to True):\n                Enable cross-attention normalization\n            eps 
(`float`, *optional*, defaults to 1e-6):\n                Epsilon value for normalization layers\n        \"\"\"\n\n        super().__init__()\n\n        assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']\n        self.model_type = model_type\n\n        self.patch_size = patch_size\n        self.text_len = text_len\n        self.in_dim = in_dim\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.freq_dim = freq_dim\n        self.text_dim = text_dim\n        self.out_dim = out_dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        self.eps = eps\n\n        # embeddings\n        self.patch_embedding = nn.Conv3d(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size)\n        self.text_embedding = nn.Sequential(\n            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),\n            nn.Linear(dim, dim))\n\n        self.time_embedding = nn.Sequential(\n            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))\n\n        # blocks\n        self.blocks = nn.ModuleList([\n            WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,\n                              cross_attn_norm, eps) for _ in range(num_layers)\n        ])\n\n        # head\n        self.head = Head(dim, out_dim, patch_size, eps)\n\n        # buffers (don't use register_buffer otherwise dtype will be changed in to())\n        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0\n        d = dim // num_heads\n        self.freqs = torch.cat([\n            rope_params(1024, d - 4 * (d // 6)),\n            rope_params(1024, 2 * (d // 6)),\n            rope_params(1024, 2 * (d // 6))\n        ],\n                               dim=1)\n\n        # initialize weights\n        self.init_weights()\n\n    def 
forward(\n        self,\n        x,\n        t,\n        context,\n        seq_len,\n        y=None,\n    ):\n        r\"\"\"\n        Forward pass through the diffusion model\n\n        Args:\n            x (List[Tensor]):\n                List of input video tensors, each with shape [C_in, F, H, W]\n            t (Tensor):\n                Diffusion timesteps tensor of shape [B]\n            context (List[Tensor]):\n                List of text embeddings each with shape [L, C]\n            seq_len (`int`):\n                Maximum sequence length for positional encoding\n            y (List[Tensor], *optional*):\n                Conditional video inputs for image-to-video mode, same shape as x\n\n        Returns:\n            List[Tensor]:\n                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]\n        \"\"\"\n        if self.model_type == 'i2v':\n            assert y is not None\n        # params\n        device = self.patch_embedding.weight.device\n        if self.freqs.device != device:\n            self.freqs = self.freqs.to(device)\n\n        if y is not None:\n            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]\n\n        # embeddings\n        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]\n        grid_sizes = torch.stack(\n            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])\n        x = [u.flatten(2).transpose(1, 2) for u in x]\n        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)\n        assert seq_lens.max() <= seq_len\n        x = torch.cat([\n            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],\n                      dim=1) for u in x\n        ])\n\n        # time embeddings\n        if t.dim() == 1:\n            t = t.expand(t.size(0), seq_len)\n        with torch.amp.autocast('cuda', dtype=torch.float32):\n            bt = t.size(0)\n            t = t.flatten()\n            e = self.time_embedding(\n                
sinusoidal_embedding_1d(self.freq_dim,\n                                        t).unflatten(0, (bt, seq_len)).float())\n            e0 = self.time_projection(e).unflatten(2, (6, self.dim))\n            assert e.dtype == torch.float32 and e0.dtype == torch.float32\n\n        # context\n        context_lens = None\n        context = self.text_embedding(\n            torch.stack([\n                torch.cat(\n                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])\n                for u in context\n            ]))\n\n        # arguments\n        kwargs = dict(\n            e=e0,\n            seq_lens=seq_lens,\n            grid_sizes=grid_sizes,\n            freqs=self.freqs,\n            context=context,\n            context_lens=context_lens)\n\n        for block in self.blocks:\n            x = block(x, **kwargs)\n\n        # head\n        x = self.head(x, e)\n\n        # unpatchify\n        x = self.unpatchify(x, grid_sizes)\n        return [u.float() for u in x]\n\n    def unpatchify(self, x, grid_sizes):\n        r\"\"\"\n        Reconstruct video tensors from patch embeddings.\n\n        Args:\n            x (List[Tensor]):\n                List of patchified features, each with shape [L, C_out * prod(patch_size)]\n            grid_sizes (Tensor):\n                Original spatial-temporal grid dimensions before patching,\n                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)\n\n        Returns:\n            List[Tensor]:\n                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]\n        \"\"\"\n\n        c = self.out_dim\n        out = []\n        for u, v in zip(x, grid_sizes.tolist()):\n            u = u[:math.prod(v)].view(*v, *self.patch_size, c)\n            u = torch.einsum('fhwpqrc->cfphqwr', u)\n            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])\n            out.append(u)\n        return out\n\n    def init_weights(self):\n        r\"\"\"\n       
 Initialize model parameters using Xavier initialization.\n        \"\"\"\n\n        # basic init\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.xavier_uniform_(m.weight)\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n\n        # init embeddings\n        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))\n        for m in self.text_embedding.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=.02)\n        for m in self.time_embedding.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=.02)\n\n        # init output layer\n        nn.init.zeros_(self.head.head.weight)\n"
  },
  {
    "path": "wan/modules/s2v/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .audio_encoder import AudioEncoder\nfrom .model_s2v import WanModel_S2V\n\n__all__ = ['WanModel_S2V', 'AudioEncoder']\n"
  },
  {
    "path": "wan/modules/s2v/audio_encoder.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport librosa\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n\n\ndef get_sample_indices(original_fps,\n                       total_frames,\n                       target_fps,\n                       num_sample,\n                       fixed_start=None):\n    required_duration = num_sample / target_fps\n    required_origin_frames = int(np.ceil(required_duration * original_fps))\n    if required_duration > total_frames / original_fps:\n        raise ValueError(\"required_duration must be less than video length\")\n\n    if not fixed_start is None and fixed_start >= 0:\n        start_frame = fixed_start\n    else:\n        max_start = total_frames - required_origin_frames\n        if max_start < 0:\n            raise ValueError(\"video length is too short\")\n        start_frame = np.random.randint(0, max_start + 1)\n    start_time = start_frame / original_fps\n\n    end_time = start_time + required_duration\n    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)\n\n    frame_indices = np.round(np.array(time_points) * original_fps).astype(int)\n    frame_indices = np.clip(frame_indices, 0, total_frames - 1)\n    return frame_indices\n\n\ndef linear_interpolation(features, input_fps, output_fps, output_len=None):\n    \"\"\"\n    features: shape=[1, T, 512]\n    input_fps: fps for audio, f_a\n    output_fps: fps for video, f_m\n    output_len: video length\n    \"\"\"\n    features = features.transpose(1, 2)  # [1, 512, T]\n    seq_len = features.shape[2] / float(input_fps)  # T/f_a\n    if output_len is None:\n        output_len = int(seq_len * output_fps)  # f_m*T/f_a\n    output_features = F.interpolate(\n        features, size=output_len, align_corners=True,\n        mode='linear')  # [1, 512, output_len]\n    return output_features.transpose(1, 2)  # [1, 
output_len, 512]\n\n\nclass AudioEncoder:\n\n    def __init__(self, device='cpu', model_id=\"facebook/wav2vec2-base-960h\"):\n        # load pretrained model\n        self.processor = Wav2Vec2Processor.from_pretrained(model_id)\n        self.model = Wav2Vec2ForCTC.from_pretrained(model_id)\n\n        self.model = self.model.to(device)\n\n        self.video_rate = 30\n\n    def extract_audio_feat(self,\n                           audio_path,\n                           return_all_layers=False,\n                           dtype=torch.float32):\n        audio_input, sample_rate = librosa.load(audio_path, sr=16000)\n\n        input_values = self.processor(\n            audio_input, sampling_rate=sample_rate,\n            return_tensors=\"pt\").input_values\n\n        # inference: run the model and keep its hidden states\n        # (the CTC logits are not used)\n        res = self.model(\n            input_values.to(self.model.device), output_hidden_states=True)\n        if return_all_layers:\n            feat = torch.cat(res.hidden_states)\n        else:\n            feat = res.hidden_states[-1]\n        # resample the features from 50 fps to self.video_rate\n        feat = linear_interpolation(\n            feat, input_fps=50, output_fps=self.video_rate)\n\n        z = feat.to(dtype)  # Encoding for the motion\n        return z\n\n    def get_audio_embed_bucket(self,\n                               audio_embed,\n                               stride=2,\n                               batch_frames=12,\n                               m=2):\n        num_layers, audio_frame_num, audio_dim = audio_embed.shape\n\n        return_all_layers = num_layers > 1\n\n        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1\n\n        bucket_num = min_batch_num * batch_frames\n        batch_idx = [stride * i for i in range(bucket_num)]\n        batch_audio_eb = []\n        for bi in batch_idx:\n            if bi < audio_frame_num:\n                audio_sample_stride = 2\n                
chosen_idx = list(\n                    range(bi - m * audio_sample_stride,\n                          bi + (m + 1) * audio_sample_stride,\n                          audio_sample_stride))\n                chosen_idx = [0 if c < 0 else c for c in chosen_idx]\n                chosen_idx = [\n                    audio_frame_num - 1 if c >= audio_frame_num else c\n                    for c in chosen_idx\n                ]\n\n                if return_all_layers:\n                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(\n                        start_dim=-2, end_dim=-1)\n                else:\n                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()\n            else:\n                frame_audio_embed = \\\n                torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \\\n                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)\n            batch_audio_eb.append(frame_audio_embed)\n        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],\n                                   dim=0)\n\n        return batch_audio_eb, min_batch_num\n\n    def get_audio_embed_bucket_fps(self,\n                                   audio_embed,\n                                   fps=16,\n                                   batch_frames=81,\n                                   m=0):\n        num_layers, audio_frame_num, audio_dim = audio_embed.shape\n\n        if num_layers > 1:\n            return_all_layers = True\n        else:\n            return_all_layers = False\n\n        scale = self.video_rate / fps\n\n        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1\n\n        bucket_num = min_batch_num * batch_frames\n        padd_audio_num = math.ceil(min_batch_num * batch_frames / fps *\n                                   self.video_rate) - audio_frame_num\n        batch_idx = get_sample_indices(\n            
original_fps=self.video_rate,\n            total_frames=audio_frame_num + padd_audio_num,\n            target_fps=fps,\n            num_sample=bucket_num,\n            fixed_start=0)\n        batch_audio_eb = []\n        audio_sample_stride = int(self.video_rate / fps)\n        for bi in batch_idx:\n            if bi < audio_frame_num:\n\n                chosen_idx = list(\n                    range(bi - m * audio_sample_stride,\n                          bi + (m + 1) * audio_sample_stride,\n                          audio_sample_stride))\n                chosen_idx = [0 if c < 0 else c for c in chosen_idx]\n                chosen_idx = [\n                    audio_frame_num - 1 if c >= audio_frame_num else c\n                    for c in chosen_idx\n                ]\n\n                if return_all_layers:\n                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(\n                        start_dim=-2, end_dim=-1)\n                else:\n                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()\n            else:\n                frame_audio_embed = \\\n                torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \\\n                    else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)\n            batch_audio_eb.append(frame_audio_embed)\n        batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],\n                                   dim=0)\n\n        return batch_audio_eb, min_batch_num\n"
  },
  {
    "path": "wan/modules/s2v/audio_utils.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nfrom typing import Tuple, Union\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn as nn\nfrom diffusers.models.attention import AdaLayerNorm\n\nfrom ..model import WanAttentionBlock, WanCrossAttention\nfrom .auxi_blocks import MotionEncoder_tc\n\n\nclass CausalAudioEncoder(nn.Module):\n\n    def __init__(self,\n                 dim=5120,\n                 num_layers=25,\n                 out_dim=2048,\n                 video_rate=8,\n                 num_token=4,\n                 need_global=False):\n        super().__init__()\n        self.encoder = MotionEncoder_tc(\n            in_dim=dim,\n            hidden_dim=out_dim,\n            num_heads=num_token,\n            need_global=need_global)\n        weight = torch.ones((1, num_layers, 1, 1)) * 0.01\n\n        self.weights = torch.nn.Parameter(weight)\n        self.act = torch.nn.SiLU()\n\n    def forward(self, features):\n        with amp.autocast(dtype=torch.float32):\n            # features B * num_layers * dim * video_length\n            weights = self.act(self.weights)\n            weights_sum = weights.sum(dim=1, keepdims=True)\n            weighted_feat = ((features * weights) / weights_sum).sum(\n                dim=1)  # b dim f\n            weighted_feat = weighted_feat.permute(0, 2, 1)  # b f dim\n            res = self.encoder(weighted_feat)  # b f n dim\n\n        return res  # b f n dim\n\n\nclass AudioCrossAttention(WanCrossAttention):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n\nclass AudioInjector_WAN(nn.Module):\n\n    def __init__(self,\n                 all_modules,\n                 all_modules_names,\n                 dim=2048,\n                 num_heads=32,\n                 inject_layer=[0, 27],\n                 root_net=None,\n                 enable_adain=False,\n                 adain_dim=2048,\n                 
need_adain_ont=False):\n        super().__init__()\n        num_injector_layers = len(inject_layer)\n        self.injected_block_id = {}\n        audio_injector_id = 0\n        for mod_name, mod in zip(all_modules_names, all_modules):\n            if isinstance(mod, WanAttentionBlock):\n                for inject_id in inject_layer:\n                    if f'transformer_blocks.{inject_id}' in mod_name:\n                        self.injected_block_id[inject_id] = audio_injector_id\n                        audio_injector_id += 1\n\n        self.injector = nn.ModuleList([\n            AudioCrossAttention(\n                dim=dim,\n                num_heads=num_heads,\n                qk_norm=True,\n            ) for _ in range(audio_injector_id)\n        ])\n        self.injector_pre_norm_feat = nn.ModuleList([\n            nn.LayerNorm(\n                dim,\n                elementwise_affine=False,\n                eps=1e-6,\n            ) for _ in range(audio_injector_id)\n        ])\n        self.injector_pre_norm_vec = nn.ModuleList([\n            nn.LayerNorm(\n                dim,\n                elementwise_affine=False,\n                eps=1e-6,\n            ) for _ in range(audio_injector_id)\n        ])\n        if enable_adain:\n            self.injector_adain_layers = nn.ModuleList([\n                AdaLayerNorm(\n                    output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)\n                for _ in range(audio_injector_id)\n            ])\n            if need_adain_ont:\n                self.injector_adain_output_layers = nn.ModuleList(\n                    [nn.Linear(dim, dim) for _ in range(audio_injector_id)])\n"
  },
  {
    "path": "wan/modules/s2v/auxi_blocks.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport importlib.metadata\nimport math\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models import ModelMixin\nfrom diffusers.utils import is_torch_version, logging\nfrom einops import rearrange\n\ntry:\n    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func\nexcept ImportError:\n    flash_attn_func = None\n\nMEMORY_LAYOUT = {\n    \"flash\": (\n        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),\n        lambda x: x,\n    ),\n    \"torch\": (\n        lambda x: x.transpose(1, 2),\n        lambda x: x.transpose(1, 2),\n    ),\n    \"vanilla\": (\n        lambda x: x.transpose(1, 2),\n        lambda x: x.transpose(1, 2),\n    ),\n}\n\n\ndef attention(\n    q,\n    k,\n    v,\n    mode=\"flash\",\n    drop_rate=0,\n    attn_mask=None,\n    causal=False,\n    max_seqlen_q=None,\n    batch_size=1,\n):\n    \"\"\"\n    Perform QKV self attention.\n\n    Args:\n        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.\n        k (torch.Tensor): Key tensor with shape [b, s1, a, d]\n        v (torch.Tensor): Value tensor with shape [b, s1, a, d]\n        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.\n        drop_rate (float): Dropout rate in attention map. (default: 0)\n        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).\n            (default: None)\n        causal (bool): Whether to use causal attention. (default: False)\n        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,\n            used to index into q.\n        cu_seqlens_kv (torch.Tensor): dtype torch.int32. 
The cumulative sequence lengths of the sequences in the batch,\n            used to index into kv.\n        max_seqlen_q (int): The maximum sequence length in the batch of q.\n        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.\n\n    Returns:\n        torch.Tensor: Output tensor after self attention with shape [b, s, ad]\n    \"\"\"\n    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]\n\n    if mode == \"torch\":\n        if attn_mask is not None and attn_mask.dtype != torch.bool:\n            attn_mask = attn_mask.to(q.dtype)\n        x = F.scaled_dot_product_attention(\n            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)\n    elif mode == \"flash\":\n        x = flash_attn_func(\n            q,\n            k,\n            v,\n        )\n        # x with shape [(bxs), a, d]\n        x = x.view(batch_size, max_seqlen_q, x.shape[-2],\n                   x.shape[-1])  # reshape x to [b, s, a, d]\n    elif mode == \"vanilla\":\n        scale_factor = 1 / math.sqrt(q.size(-1))\n\n        b, a, s, _ = q.shape\n        s1 = k.size(2)\n        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)\n        if causal:\n            # Only applied to self attention\n            assert (\n                attn_mask\n                is None), \"Causal mask and attn_mask cannot be used together\"\n            temp_mask = torch.ones(\n                b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)\n            attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n            # Tensor.to() is not in-place, so keep the returned tensor\n            attn_bias = attn_bias.to(q.dtype)\n\n        if attn_mask is not None:\n            if attn_mask.dtype == torch.bool:\n                attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n            else:\n                attn_bias += attn_mask\n\n        # TODO: Maybe force q and k to be float32 to avoid numerical overflow\n        attn = (q @ k.transpose(-2, -1)) * scale_factor\n        attn += 
attn_bias\n        attn = attn.softmax(dim=-1)\n        attn = torch.dropout(attn, p=drop_rate, train=True)\n        x = attn @ v\n    else:\n        raise NotImplementedError(f\"Unsupported attention mode: {mode}\")\n\n    x = post_attn_layout(x)\n    b, s, a, d = x.shape\n    out = x.reshape(b, s, -1)\n    return out\n\n\nclass CausalConv1d(nn.Module):\n\n    def __init__(self,\n                 chan_in,\n                 chan_out,\n                 kernel_size=3,\n                 stride=1,\n                 dilation=1,\n                 pad_mode='replicate',\n                 **kwargs):\n        super().__init__()\n\n        self.pad_mode = pad_mode\n        # pad only the left side of the time axis, by kernel_size - 1 frames\n        padding = (kernel_size - 1, 0)  # T\n        self.time_causal_padding = padding\n\n        self.conv = nn.Conv1d(\n            chan_in,\n            chan_out,\n            kernel_size,\n            stride=stride,\n            dilation=dilation,\n            **kwargs)\n\n    def forward(self, x):\n        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)\n        return self.conv(x)\n\n\nclass MotionEncoder_tc(nn.Module):\n\n    def __init__(self,\n                 in_dim: int,\n                 hidden_dim: int,\n                 num_heads: int,\n                 need_global=True,\n                 dtype=None,\n                 device=None):\n        factory_kwargs = {\"dtype\": dtype, \"device\": device}\n        super().__init__()\n\n        self.num_heads = num_heads\n        self.need_global = need_global\n        self.conv1_local = CausalConv1d(\n            in_dim, hidden_dim // 4 * num_heads, 3, stride=1)\n        if need_global:\n            self.conv1_global = CausalConv1d(\n                in_dim, hidden_dim // 4, 3, stride=1)\n        self.norm1 = nn.LayerNorm(\n            hidden_dim // 4,\n            elementwise_affine=False,\n            eps=1e-6,\n            **factory_kwargs)\n        self.act = nn.SiLU()\n        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, 
stride=2)\n        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)\n\n        if need_global:\n            self.final_linear = nn.Linear(hidden_dim, hidden_dim,\n                                          **factory_kwargs)\n\n        self.norm2 = nn.LayerNorm(\n            hidden_dim // 2,\n            elementwise_affine=False,\n            eps=1e-6,\n            **factory_kwargs)\n\n        self.norm3 = nn.LayerNorm(\n            hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)\n\n        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))\n\n    def forward(self, x):\n        x = rearrange(x, 'b t c -> b c t')\n        x_ori = x.clone()\n        b, c, t = x.shape\n        x = self.conv1_local(x)\n        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)\n        x = self.norm1(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = self.conv2(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm2(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = self.conv3(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm3(x)\n        x = self.act(x)\n        x = rearrange(x, '(b n) t c -> b t n c', b=b)\n        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)\n        x = torch.cat([x, padding], dim=-2)\n        x_local = x.clone()\n\n        if not self.need_global:\n            return x_local\n\n        x = self.conv1_global(x_ori)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm1(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = self.conv2(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm2(x)\n        x = self.act(x)\n        x = rearrange(x, 'b t c -> b c t')\n        x = 
self.conv3(x)\n        x = rearrange(x, 'b c t -> b t c')\n        x = self.norm3(x)\n        x = self.act(x)\n        x = self.final_linear(x)\n        x = rearrange(x, '(b n) t c -> b t n c', b=b)\n\n        return x, x_local\n"
  },
  {
    "path": "wan/modules/s2v/model_s2v.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nimport types\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn as nn\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom einops import rearrange\n\nfrom ...distributed.sequence_parallel import (\n    distributed_attention,\n    gather_forward,\n    get_rank,\n    get_world_size,\n)\nfrom ..model import (\n    Head,\n    WanAttentionBlock,\n    WanLayerNorm,\n    WanModel,\n    WanSelfAttention,\n    flash_attention,\n    rope_params,\n    sinusoidal_embedding_1d,\n)\nfrom .audio_utils import AudioInjector_WAN, CausalAudioEncoder\nfrom .motioner import FramePackMotioner, MotionerTransformers\nfrom .s2v_utils import rope_precompute\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef torch_dfs(model: nn.Module, parent_name='root'):\n    module_names, modules = [], []\n    current_name = parent_name if parent_name else 'root'\n    module_names.append(current_name)\n    modules.append(model)\n\n    for name, child in model.named_children():\n        if parent_name:\n            child_name = f'{parent_name}.{name}'\n        else:\n            child_name = name\n        child_modules, child_names = torch_dfs(child, child_name)\n        module_names += child_names\n        modules += child_modules\n    return modules, module_names\n\n\n@amp.autocast(enabled=False)\ndef rope_apply(x, grid_sizes, freqs, start=None):\n    n, c = x.size(2), x.size(3) // 2\n    # loop over samples\n    output = []\n    for i, _ in enumerate(x):\n        s = x.size(1)\n        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(\n            s, n, -1, 2))\n        freqs_i = freqs[i, :s]\n        
# apply rotary embedding\n        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)\n        x_i = torch.cat([x_i, x[i, s:]])\n        # append to collection\n        output.append(x_i)\n    return torch.stack(output).float()\n\n\n@amp.autocast(enabled=False)\ndef rope_apply_usp(x, grid_sizes, freqs):\n    s, n, c = x.size(1), x.size(2), x.size(3) // 2\n    # loop over samples\n    output = []\n    for i, _ in enumerate(x):\n        s = x.size(1)\n        # precompute multipliers\n        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(\n            s, n, -1, 2))\n        freqs_i = freqs[i]\n        freqs_i_rank = freqs_i\n        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)\n        x_i = torch.cat([x_i, x[i, s:]])\n        # append to collection\n        output.append(x_i)\n    return torch.stack(output).float()\n\n\ndef sp_attn_forward_s2v(self,\n                        x,\n                        seq_lens,\n                        grid_sizes,\n                        freqs,\n                        dtype=torch.bfloat16):\n    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n    half_dtypes = (torch.float16, torch.bfloat16)\n\n    def half(x):\n        return x if x.dtype in half_dtypes else x.to(dtype)\n\n    # query, key, value function\n    def qkv_fn(x):\n        q = self.norm_q(self.q(x)).view(b, s, n, d)\n        k = self.norm_k(self.k(x)).view(b, s, n, d)\n        v = self.v(x).view(b, s, n, d)\n        return q, k, v\n\n    q, k, v = qkv_fn(x)\n    q = rope_apply_usp(q, grid_sizes, freqs)\n    k = rope_apply_usp(k, grid_sizes, freqs)\n\n    x = distributed_attention(\n        half(q),\n        half(k),\n        half(v),\n        seq_lens,\n        window_size=self.window_size,\n    )\n\n    # output\n    x = x.flatten(2)\n    x = self.o(x)\n    return x\n\n\nclass Head_S2V(Head):\n\n    def forward(self, x, e):\n        \"\"\"\n        Args:\n            x(Tensor): Shape [B, L1, C]\n            e(Tensor): Shape [B, L1, 
C]\n        \"\"\"\n        assert e.dtype == torch.float32\n        with amp.autocast(dtype=torch.float32):\n            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)\n            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))\n        return x\n\n\nclass WanS2VSelfAttention(WanSelfAttention):\n\n    def forward(self, x, seq_lens, grid_sizes, freqs):\n        \"\"\"\n        Args:\n            x(Tensor): Shape [B, L, num_heads, C / num_heads]\n            seq_lens(Tensor): Shape [B]\n            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)\n            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]\n        \"\"\"\n        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n\n        # query, key, value function\n        def qkv_fn(x):\n            q = self.norm_q(self.q(x)).view(b, s, n, d)\n            k = self.norm_k(self.k(x)).view(b, s, n, d)\n            v = self.v(x).view(b, s, n, d)\n            return q, k, v\n\n        q, k, v = qkv_fn(x)\n\n        x = flash_attention(\n            q=rope_apply(q, grid_sizes, freqs),\n            k=rope_apply(k, grid_sizes, freqs),\n            v=v,\n            k_lens=seq_lens,\n            window_size=self.window_size)\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\nclass WanS2VAttentionBlock(WanAttentionBlock):\n\n    def __init__(self,\n                 dim,\n                 ffn_dim,\n                 num_heads,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 cross_attn_norm=False,\n                 eps=1e-6):\n        super().__init__(dim, ffn_dim, num_heads, window_size, qk_norm,\n                         cross_attn_norm, eps)\n        self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size,\n                                             qk_norm, eps)\n\n    def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens):\n        assert e[0].dtype == 
torch.float32\n        # e[1] is a 0-dim tensor when zero_timestep is set and a plain int otherwise\n        seg_idx = e[1].item() if torch.is_tensor(e[1]) else int(e[1])\n        seg_idx = min(max(0, seg_idx), x.size(1))\n        seg_idx = [0, seg_idx, x.size(1)]\n        e = e[0]\n        modulation = self.modulation.unsqueeze(2)\n        with amp.autocast(dtype=torch.float32):\n            e = (modulation + e).chunk(6, dim=1)\n        assert e[0].dtype == torch.float32\n\n        e = [element.squeeze(1) for element in e]\n        norm_x = self.norm1(x).float()\n        parts = []\n        for i in range(2):\n            parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *\n                         (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])\n        norm_x = torch.cat(parts, dim=1)\n        # self-attention\n        y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)\n        with amp.autocast(dtype=torch.float32):\n            z = []\n            for i in range(2):\n                z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])\n            y = torch.cat(z, dim=1)\n            x = x + y\n        # cross-attention & ffn function\n        def cross_attn_ffn(x, context, context_lens, e):\n            x = x + self.cross_attn(self.norm3(x), context, context_lens)\n            norm2_x = self.norm2(x).float()\n            parts = []\n            for i in range(2):\n                parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *\n                             (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])\n            norm2_x = torch.cat(parts, dim=1)\n            y = self.ffn(norm2_x)\n            with amp.autocast(dtype=torch.float32):\n                z = []\n                for i in range(2):\n                    z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])\n                y = torch.cat(z, dim=1)\n                x = x + y\n            return x\n\n        x = cross_attn_ffn(x, context, context_lens, e)\n        return x\n\n\nclass WanModel_S2V(ModelMixin, ConfigMixin):\n    ignore_for_config = [\n        'args', 'kwargs', 'patch_size', 'cross_attn_norm', 
'qk_norm',\n        'text_dim', 'window_size'\n    ]\n    _no_split_modules = ['WanS2VAttentionBlock']\n\n    @register_to_config\n    def __init__(\n            self,\n            cond_dim=0,\n            audio_dim=5120,\n            num_audio_token=4,\n            enable_adain=False,\n            adain_mode=\"attn_norm\",\n            audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],\n            zero_init=False,\n            zero_timestep=False,\n            enable_motioner=True,\n            add_last_motion=True,\n            enable_tsm=False,\n            trainable_token_pos_emb=False,\n            motion_token_num=1024,\n            enable_framepack=False,  # Mutually exclusive with enable_motioner\n            framepack_drop_mode=\"drop\",\n            model_type='s2v',\n            patch_size=(1, 2, 2),\n            text_len=512,\n            in_dim=16,\n            dim=2048,\n            ffn_dim=8192,\n            freq_dim=256,\n            text_dim=4096,\n            out_dim=16,\n            num_heads=16,\n            num_layers=32,\n            window_size=(-1, -1),\n            qk_norm=True,\n            cross_attn_norm=True,\n            eps=1e-6,\n            *args,\n            **kwargs):\n        super().__init__()\n\n        assert model_type == 's2v'\n        self.model_type = model_type\n\n        self.patch_size = patch_size\n        self.text_len = text_len\n        self.in_dim = in_dim\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.freq_dim = freq_dim\n        self.text_dim = text_dim\n        self.out_dim = out_dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        self.eps = eps\n\n        # embeddings\n        self.patch_embedding = nn.Conv3d(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size)\n        self.text_embedding = nn.Sequential(\n           
 nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),\n            nn.Linear(dim, dim))\n\n        self.time_embedding = nn.Sequential(\n            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))\n\n        # blocks\n        self.blocks = nn.ModuleList([\n            WanS2VAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,\n                                 cross_attn_norm, eps)\n            for _ in range(num_layers)\n        ])\n\n        # head\n        self.head = Head_S2V(dim, out_dim, patch_size, eps)\n\n        # buffers (don't use register_buffer otherwise dtype will be changed in to())\n        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0\n        d = dim // num_heads\n        self.freqs = torch.cat([\n            rope_params(1024, d - 4 * (d // 6)),\n            rope_params(1024, 2 * (d // 6)),\n            rope_params(1024, 2 * (d // 6))\n        ],\n                               dim=1)\n\n        # initialize weights\n        self.init_weights()\n\n        self.use_context_parallel = False  # will modify in _configure_model func\n\n        if cond_dim > 0:\n            self.cond_encoder = nn.Conv3d(\n                cond_dim,\n                self.dim,\n                kernel_size=self.patch_size,\n                stride=self.patch_size)\n        self.enbale_adain = enable_adain\n        self.casual_audio_encoder = CausalAudioEncoder(\n            dim=audio_dim,\n            out_dim=self.dim,\n            num_token=num_audio_token,\n            need_global=enable_adain)\n        all_modules, all_modules_names = torch_dfs(\n            self.blocks, parent_name=\"root.transformer_blocks\")\n        self.audio_injector = AudioInjector_WAN(\n            all_modules,\n            all_modules_names,\n            dim=self.dim,\n            num_heads=self.num_heads,\n            inject_layer=audio_inject_layers,\n            
root_net=self,\n            enable_adain=enable_adain,\n            adain_dim=self.dim,\n            need_adain_ont=adain_mode != \"attn_norm\",\n        )\n        self.adain_mode = adain_mode\n\n        self.trainable_cond_mask = nn.Embedding(3, self.dim)\n\n        if zero_init:\n            self.zero_init_weights()\n\n        self.zero_timestep = zero_timestep  # Whether to assign 0 value timestep to ref/motion\n\n        # init motioner\n        if enable_motioner and enable_framepack:\n            raise ValueError(\n                \"enable_motioner and enable_framepack are mutually exclusive, please set one of them to False\"\n            )\n        self.enable_motioner = enable_motioner\n        self.add_last_motion = add_last_motion\n        if enable_motioner:\n            motioner_dim = 2048\n            self.motioner = MotionerTransformers(\n                patch_size=(2, 4, 4),\n                dim=motioner_dim,\n                ffn_dim=motioner_dim,\n                freq_dim=256,\n                out_dim=16,\n                num_heads=16,\n                num_layers=13,\n                window_size=(-1, -1),\n                qk_norm=True,\n                cross_attn_norm=False,\n                eps=1e-6,\n                motion_token_num=motion_token_num,\n                enable_tsm=enable_tsm,\n                motion_stride=4,\n                expand_ratio=2,\n                trainable_token_pos_emb=trainable_token_pos_emb,\n            )\n            self.zip_motion_out = torch.nn.Sequential(\n                WanLayerNorm(motioner_dim),\n                zero_module(nn.Linear(motioner_dim, self.dim)))\n\n            self.trainable_token_pos_emb = trainable_token_pos_emb\n            if trainable_token_pos_emb:\n                d = self.dim // self.num_heads\n                x = torch.zeros([1, motion_token_num, self.num_heads, d])\n                x[..., ::2] = 1\n\n                gride_sizes = [[\n                    torch.tensor([0, 0, 
0]).unsqueeze(0).repeat(1, 1),\n                    torch.tensor([\n                        1, self.motioner.motion_side_len,\n                        self.motioner.motion_side_len\n                    ]).unsqueeze(0).repeat(1, 1),\n                    torch.tensor([\n                        1, self.motioner.motion_side_len,\n                        self.motioner.motion_side_len\n                    ]).unsqueeze(0).repeat(1, 1),\n                ]]\n                token_freqs = rope_apply(x, gride_sizes, self.freqs)\n                token_freqs = token_freqs[0, :,\n                                          0].reshape(motion_token_num, -1, 2)\n                token_freqs = token_freqs * 0.01\n                self.token_freqs = torch.nn.Parameter(token_freqs)\n\n        self.enable_framepack = enable_framepack\n        if enable_framepack:\n            self.frame_packer = FramePackMotioner(\n                inner_dim=self.dim,\n                num_heads=self.num_heads,\n                zip_frame_buckets=[1, 2, 16],\n                drop_mode=framepack_drop_mode)\n\n    def zero_init_weights(self):\n        with torch.no_grad():\n            self.trainable_cond_mask = zero_module(self.trainable_cond_mask)\n            if hasattr(self, \"cond_encoder\"):\n                self.cond_encoder = zero_module(self.cond_encoder)\n\n            for i in range(self.audio_injector.injector.__len__()):\n                self.audio_injector.injector[i].o = zero_module(\n                    self.audio_injector.injector[i].o)\n                if self.enbale_adain:\n                    self.audio_injector.injector_adain_layers[\n                        i].linear = zero_module(\n                            self.audio_injector.injector_adain_layers[i].linear)\n\n    def process_motion(self, motion_latents, drop_motion_frames=False):\n        if drop_motion_frames or motion_latents[0].shape[1] == 0:\n            return [], []\n        self.lat_motion_frames = motion_latents[0].shape[1]\n 
       mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]\n        batch_size = len(mot)\n\n        mot_remb = []\n        flattern_mot = []\n        for bs in range(batch_size):\n            height, width = mot[bs].shape[3], mot[bs].shape[4]\n            flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()\n            motion_grid_sizes = [[\n                torch.tensor([-self.lat_motion_frames, 0,\n                              0]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([self.lat_motion_frames, height,\n                              width]).unsqueeze(0).repeat(1, 1)\n            ]]\n            motion_rope_emb = rope_precompute(\n                flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,\n                                       self.dim // self.num_heads),\n                motion_grid_sizes,\n                self.freqs,\n                start=None)\n            mot_remb.append(motion_rope_emb)\n            flattern_mot.append(flat_mot)\n        return flattern_mot, mot_remb\n\n    def process_motion_frame_pack(self,\n                                  motion_latents,\n                                  drop_motion_frames=False,\n                                  add_last_motion=2):\n        flattern_mot, mot_remb = self.frame_packer(motion_latents,\n                                                   add_last_motion)\n        if drop_motion_frames:\n            return [m[:, :0] for m in flattern_mot\n                   ], [m[:, :0] for m in mot_remb]\n        else:\n            return flattern_mot, mot_remb\n\n    def process_motion_transformer_motioner(self,\n                                            motion_latents,\n                                            drop_motion_frames=False,\n                                            add_last_motion=True):\n        batch_size, height, width = len(\n            motion_latents), 
motion_latents[0].shape[2] // self.patch_size[\n                1], motion_latents[0].shape[3] // self.patch_size[2]\n\n        freqs = self.freqs\n        device = self.patch_embedding.weight.device\n        if freqs.device != device:\n            freqs = freqs.to(device)\n        if self.trainable_token_pos_emb:\n            with amp.autocast(dtype=torch.float64):\n                token_freqs = self.token_freqs.to(torch.float64)\n                token_freqs = token_freqs / token_freqs.norm(\n                    dim=-1, keepdim=True)\n                freqs = [freqs, torch.view_as_complex(token_freqs)]\n\n        if not drop_motion_frames and add_last_motion:\n            last_motion_latent = [u[:, -1:] for u in motion_latents]\n            last_mot = [\n                self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent\n            ]\n            last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]\n            last_mot = torch.cat(last_mot)\n            gride_sizes = [[\n                torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),\n                torch.tensor([0, height,\n                              width]).unsqueeze(0).repeat(batch_size, 1),\n                torch.tensor([1, height,\n                              width]).unsqueeze(0).repeat(batch_size, 1)\n            ]]\n        else:\n            last_mot = torch.zeros([batch_size, 0, self.dim],\n                                   device=motion_latents[0].device,\n                                   dtype=motion_latents[0].dtype)\n            gride_sizes = []\n\n        zip_motion = self.motioner(motion_latents)\n        zip_motion = self.zip_motion_out(zip_motion)\n        if drop_motion_frames:\n            zip_motion = zip_motion * 0.0\n        zip_motion_grid_sizes = [[\n            torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),\n            torch.tensor([\n                0, self.motioner.motion_side_len, self.motioner.motion_side_len\n            
]).unsqueeze(0).repeat(batch_size, 1),\n            torch.tensor(\n                [1 if not self.trainable_token_pos_emb else -1, height,\n                 width]).unsqueeze(0).repeat(batch_size, 1),\n        ]]\n\n        mot = torch.cat([last_mot, zip_motion], dim=1)\n        gride_sizes = gride_sizes + zip_motion_grid_sizes\n\n        motion_rope_emb = rope_precompute(\n            mot.detach().view(batch_size, mot.shape[1], self.num_heads,\n                              self.dim // self.num_heads),\n            gride_sizes,\n            freqs,\n            start=None)\n        return [m.unsqueeze(0) for m in mot\n               ], [r.unsqueeze(0) for r in motion_rope_emb]\n\n    def inject_motion(self,\n                      x,\n                      seq_lens,\n                      rope_embs,\n                      mask_input,\n                      motion_latents,\n                      drop_motion_frames=False,\n                      add_last_motion=True):\n        # inject the motion frames token to the hidden states\n        if self.enable_motioner:\n            mot, mot_remb = self.process_motion_transformer_motioner(\n                motion_latents,\n                drop_motion_frames=drop_motion_frames,\n                add_last_motion=add_last_motion)\n        elif self.enable_framepack:\n            mot, mot_remb = self.process_motion_frame_pack(\n                motion_latents,\n                drop_motion_frames=drop_motion_frames,\n                add_last_motion=add_last_motion)\n        else:\n            mot, mot_remb = self.process_motion(\n                motion_latents, drop_motion_frames=drop_motion_frames)\n\n        if len(mot) > 0:\n            x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]\n            seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],\n                                               dtype=torch.long)\n            rope_embs = [\n                torch.cat([u, m], dim=1) for u, m in zip(rope_embs, 
mot_remb)\n            ]\n            mask_input = [\n                torch.cat([\n                    m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],\n                                      device=m.device,\n                                      dtype=m.dtype)\n                ],\n                          dim=1) for m, u in zip(mask_input, x)\n            ]\n        return x, seq_lens, rope_embs, mask_input\n\n    def after_transformer_block(self, block_idx, hidden_states):\n        if block_idx in self.audio_injector.injected_block_id.keys():\n            audio_attn_id = self.audio_injector.injected_block_id[block_idx]\n            audio_emb = self.merged_audio_emb  # b f n c\n            num_frames = audio_emb.shape[1]\n\n            if self.use_context_parallel:\n                hidden_states = gather_forward(hidden_states, dim=1)\n\n            input_hidden_states = hidden_states[:, :self.\n                                                original_seq_len].clone(\n                                                )  # b (f h w) c\n            input_hidden_states = rearrange(\n                input_hidden_states, \"b (t n) c -> (b t) n c\", t=num_frames)\n\n            if self.enbale_adain and self.adain_mode == \"attn_norm\":\n                audio_emb_global = self.audio_emb_global\n                audio_emb_global = rearrange(audio_emb_global,\n                                             \"b t n c -> (b t) n c\")\n                adain_hidden_states = self.audio_injector.injector_adain_layers[\n                    audio_attn_id](\n                        input_hidden_states, temb=audio_emb_global[:, 0])\n                attn_hidden_states = adain_hidden_states\n            else:\n                attn_hidden_states = self.audio_injector.injector_pre_norm_feat[\n                    audio_attn_id](\n                        input_hidden_states)\n            audio_emb = rearrange(\n                audio_emb, \"b t n c -> (b t) n c\", t=num_frames)\n            
attn_audio_emb = audio_emb\n            residual_out = self.audio_injector.injector[audio_attn_id](\n                x=attn_hidden_states,\n                context=attn_audio_emb,\n                context_lens=torch.ones(\n                    attn_hidden_states.shape[0],\n                    dtype=torch.long,\n                    device=attn_hidden_states.device) * attn_audio_emb.shape[1])\n            residual_out = rearrange(\n                residual_out, \"(b t) n c -> b (t n) c\", t=num_frames)\n            hidden_states[:, :self.\n                          original_seq_len] = hidden_states[:, :self.\n                                                            original_seq_len] + residual_out\n\n            if self.use_context_parallel:\n                hidden_states = torch.chunk(\n                    hidden_states, get_world_size(), dim=1)[get_rank()]\n\n        return hidden_states\n\n    def forward(\n            self,\n            x,\n            t,\n            context,\n            seq_len,\n            ref_latents,\n            motion_latents,\n            cond_states,\n            audio_input=None,\n            motion_frames=[17, 5],\n            add_last_motion=2,\n            drop_motion_frames=False,\n            *extra_args,\n            **extra_kwargs):\n        \"\"\"\n        x:                  A list of videos each with shape [C, T, H, W].\n        t:                  [B].\n        context:            A list of text embeddings each with shape [L, C].\n        seq_len:            A list of video token lens, no need for this model.\n        ref_latents         A list of reference image for each video with shape [C, 1, H, W].\n        motion_latents      A list of  motion frames for each video with shape [C, T_m, H, W].\n        cond_states         A list of condition frames (i.e. 
pose) each with shape [C, T, H, W].\n        audio_input         The input audio embedding [B, num_wav2vec_layer, C_a, T_a].\n        motion_frames       The number of motion frames and the number of motion latent frames after VAE encoding, e.g. [17, 5]\n        add_last_motion     For the motioner, add_last_motion > 0 means that the most recent frame (i.e., the last frame) is added.\n                            For frame packing, the behavior depends on the value of add_last_motion:\n                            add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included.\n                            add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included.\n                            add_last_motion = 2: All motion-related latents are used.\n        drop_motion_frames  Bool, whether to drop the motion frame information\n        \"\"\"\n        add_last_motion = self.add_last_motion * add_last_motion\n        audio_input = torch.cat([\n            audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input\n        ],\n                                dim=-1)\n        audio_emb_res = self.casual_audio_encoder(audio_input)\n        if self.enbale_adain:\n            audio_emb_global, audio_emb = audio_emb_res\n            self.audio_emb_global = audio_emb_global[:,\n                                                     motion_frames[1]:].clone()\n        else:\n            audio_emb = audio_emb_res\n        self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :]\n\n        device = self.patch_embedding.weight.device\n\n        # embeddings\n        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]\n        # cond states\n        cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]\n        x = [x_ + pose for x_, pose in zip(x, cond)]\n\n        grid_sizes = torch.stack(\n            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])\n        x = [u.flatten(2).transpose(1, 2) for u in x]\n      
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)\n\n        original_grid_sizes = deepcopy(grid_sizes)\n        grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]\n\n        # ref and motion\n        self.lat_motion_frames = motion_latents[0].shape[1]\n\n        ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]\n        batch_size = len(ref)\n        height, width = ref[0].shape[3], ref[0].shape[4]\n        ref_grid_sizes = [[\n            torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size,\n                                                         1),  # the start index\n            torch.tensor([31, height,\n                          width]).unsqueeze(0).repeat(batch_size,\n                                                      1),  # the end index\n            torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),\n        ]  # the range\n                         ]\n\n        ref = [r.flatten(2).transpose(1, 2) for r in ref]  # r: 1 c f h w\n        self.original_seq_len = seq_lens[0]\n\n        seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref],\n                                           dtype=torch.long)\n\n        grid_sizes = grid_sizes + ref_grid_sizes\n\n        x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]\n\n        # Initialize masks to indicate noisy latent, ref latent, and motion latent.\n        # However, at this point, only the first two (noisy and ref latents) are marked;\n        # the marking of motion latent will be implemented inside `inject_motion`.\n        mask_input = [\n            torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)\n            for u in x\n        ]\n        for i in range(len(mask_input)):\n            mask_input[i][:, self.original_seq_len:] = 1\n\n        # compute the rope embeddings for the input\n        x = torch.cat(x)\n        b, s, n, d = x.size(0), x.size(\n            1), self.num_heads, self.dim // 
self.num_heads\n        self.pre_compute_freqs = rope_precompute(\n            x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)\n\n        x = [u.unsqueeze(0) for u in x]\n        self.pre_compute_freqs = [\n            u.unsqueeze(0) for u in self.pre_compute_freqs\n        ]\n\n        x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(\n            x,\n            seq_lens,\n            self.pre_compute_freqs,\n            mask_input,\n            motion_latents,\n            drop_motion_frames=drop_motion_frames,\n            add_last_motion=add_last_motion)\n\n        x = torch.cat(x, dim=0)\n        self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)\n        mask_input = torch.cat(mask_input, dim=0)\n\n        x = x + self.trainable_cond_mask(mask_input).to(x.dtype)\n\n        # time embeddings\n        if self.zero_timestep:\n            t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])\n        with amp.autocast(dtype=torch.float32):\n            e = self.time_embedding(\n                sinusoidal_embedding_1d(self.freq_dim, t).float())\n            e0 = self.time_projection(e).unflatten(1, (6, self.dim))\n            assert e.dtype == torch.float32 and e0.dtype == torch.float32\n\n        if self.zero_timestep:\n            e = e[:-1]\n            zero_e0 = e0[-1:]\n            e0 = e0[:-1]\n            token_len = x.shape[1]\n            e0 = torch.cat([\n                e0.unsqueeze(2),\n                zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)\n            ],\n                           dim=2)\n            e0 = [e0, self.original_seq_len]\n        else:\n            e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)\n            e0 = [e0, 0]\n\n        # context\n        context_lens = None\n        context = self.text_embedding(\n            torch.stack([\n                torch.cat(\n                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])\n                for u in 
context\n            ]))\n\n        # grad ckpt args\n        def create_custom_forward(module, return_dict=None):\n\n            def custom_forward(*inputs, **kwargs):\n                if return_dict is not None:\n                    return module(*inputs, **kwargs, return_dict=return_dict)\n                else:\n                    return module(*inputs, **kwargs)\n\n            return custom_forward\n\n        if self.use_context_parallel:\n            # sharded tensors for long context attn\n            sp_rank = get_rank()\n            x = torch.chunk(x, get_world_size(), dim=1)\n            sq_size = [u.shape[1] for u in x]\n            sq_start_size = sum(sq_size[:sp_rank])\n            x = x[sp_rank]\n            # Shift the segment boundary e0[1] into this rank's local token range so that\n            # each block picks the right time embedding from e0[0]:\n            # - For tokens before seg_idx: apply e0[0][:, :, 0]\n            # - For tokens from seg_idx onward: apply e0[0][:, :, 1]\n            sp_size = x.shape[1]\n            seg_idx = e0[1] - sq_start_size\n            e0[1] = seg_idx\n\n            self.pre_compute_freqs = torch.chunk(\n                self.pre_compute_freqs, get_world_size(), dim=1)\n            self.pre_compute_freqs = self.pre_compute_freqs[sp_rank]\n\n        # arguments\n        kwargs = dict(\n            e=e0,\n            seq_lens=seq_lens,\n            grid_sizes=grid_sizes,\n            freqs=self.pre_compute_freqs,\n            context=context,\n            context_lens=context_lens)\n        for idx, block in enumerate(self.blocks):\n            x = block(x, **kwargs)\n            x = self.after_transformer_block(idx, x)\n\n        # Context Parallel\n        if self.use_context_parallel:\n            x = gather_forward(x.contiguous(), dim=1)\n        # keep only the noisy-latent tokens (drop the appended ref and motion tokens)\n        x = x[:, :self.original_seq_len]\n        # head\n        x = self.head(x, e)\n        # unpatchify\n        x = self.unpatchify(x, original_grid_sizes)\n        return [u.float() for u in x]\n\n    def unpatchify(self, x, grid_sizes):\n    
    \"\"\"\n        Reconstruct video tensors from patch embeddings.\n\n        Args:\n            x (List[Tensor]):\n                List of patchified features, each with shape [L, C_out * prod(patch_size)]\n            grid_sizes (Tensor):\n                Original spatial-temporal grid dimensions before patching,\n                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)\n\n        Returns:\n            List[Tensor]:\n                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]\n        \"\"\"\n\n        c = self.out_dim\n        out = []\n        for u, v in zip(x, grid_sizes.tolist()):\n            u = u[:math.prod(v)].view(*v, *self.patch_size, c)\n            u = torch.einsum('fhwpqrc->cfphqwr', u)\n            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])\n            out.append(u)\n        return out\n\n    def init_weights(self):\n        r\"\"\"\n        Initialize model parameters using Xavier initialization.\n        \"\"\"\n\n        # basic init\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.xavier_uniform_(m.weight)\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n\n        # init embeddings\n        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))\n        for m in self.text_embedding.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=.02)\n        for m in self.time_embedding.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, std=.02)\n\n        # init output layer\n        nn.init.zeros_(self.head.head.weight)\n"
  },
  {
    "path": "wan/modules/s2v/motioner.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nfrom typing import Any, Dict, List, Literal, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn as nn\nfrom diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin\nfrom diffusers.utils import BaseOutput, is_torch_version\nfrom einops import rearrange, repeat\n\nfrom ..model import flash_attention\nfrom .s2v_utils import rope_precompute\n\n\ndef sinusoidal_embedding_1d(dim, position):\n    # preprocess\n    assert dim % 2 == 0\n    half = dim // 2\n    position = position.type(torch.float64)\n\n    # calculation\n    sinusoid = torch.outer(\n        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))\n    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)\n    return x\n\n\n@amp.autocast(enabled=False)\ndef rope_params(max_seq_len, dim, theta=10000):\n    assert dim % 2 == 0\n    freqs = torch.outer(\n        torch.arange(max_seq_len),\n        1.0 / torch.pow(theta,\n                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))\n    freqs = torch.polar(torch.ones_like(freqs), freqs)\n    return freqs\n\n\n@amp.autocast(enabled=False)\ndef rope_apply(x, grid_sizes, freqs, start=None):\n    n, c = x.size(2), x.size(3) // 2\n\n    # split freqs\n    if type(freqs) is list:\n        trainable_freqs = freqs[1]\n        freqs = freqs[0]\n    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)\n\n    # loop over samples\n    output = []\n    output = x.clone()\n    seq_bucket = [0]\n    if not type(grid_sizes) is list:\n        grid_sizes = [grid_sizes]\n    for g in grid_sizes:\n        if not type(g) is list:\n            g = [torch.zeros_like(g), g]\n        batch_size = g[0].shape[0]\n        for i in range(batch_size):\n            if start is None:\n                f_o, h_o, w_o = g[0][i]\n            else:\n                f_o, h_o, w_o = 
start[i]\n\n            f, h, w = g[1][i]\n            t_f, t_h, t_w = g[2][i]\n            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o\n            seq_len = int(seq_f * seq_h * seq_w)\n            if seq_len > 0:\n                if t_f > 0:\n                    factor_f, factor_h, factor_w = (t_f / seq_f).item(), (\n                        t_h / seq_h).item(), (t_w / seq_w).item()\n\n                    if f_o >= 0:\n                        f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,\n                                            seq_f).astype(int).tolist()\n                    else:\n                        f_sam = np.linspace(-f_o.item(),\n                                            (-t_f - f_o).item() + 1,\n                                            seq_f).astype(int).tolist()\n                    h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,\n                                        seq_h).astype(int).tolist()\n                    w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,\n                                        seq_w).astype(int).tolist()\n\n                    assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0\n                    freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][\n                        f_sam].conj()\n                    freqs_0 = freqs_0.view(seq_f, 1, 1, -1)\n\n                    freqs_i = torch.cat([\n                        freqs_0.expand(seq_f, seq_h, seq_w, -1),\n                        freqs[1][h_sam].view(1, seq_h, 1, -1).expand(\n                            seq_f, seq_h, seq_w, -1),\n                        freqs[2][w_sam].view(1, 1, seq_w, -1).expand(\n                            seq_f, seq_h, seq_w, -1),\n                    ],\n                                        dim=-1).reshape(seq_len, 1, -1)\n                elif t_f < 0:\n                    freqs_i = trainable_freqs.unsqueeze(1)\n                # apply rotary embedding\n                # precompute multipliers\n            
    x_i = torch.view_as_complex(\n                    x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to(\n                        torch.float64).reshape(seq_len, n, -1, 2))\n                x_i = torch.view_as_real(x_i * freqs_i).flatten(2)\n                output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i\n        seq_bucket.append(seq_bucket[-1] + seq_len)\n    return output.float()\n\n\nclass RMSNorm(nn.Module):\n\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        return self._norm(x.float()).type_as(x) * self.weight\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)\n\n\nclass LayerNorm(nn.LayerNorm):\n\n    def __init__(self, dim, eps=1e-6, elementwise_affine=False):\n        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)\n\n    def forward(self, x):\n        return super().forward(x.float()).type_as(x)\n\n\nclass SelfAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 num_heads,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 eps=1e-6):\n        assert dim % num_heads == 0\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.eps = eps\n\n        # layers\n        self.q = nn.Linear(dim, dim)\n        self.k = nn.Linear(dim, dim)\n        self.v = nn.Linear(dim, dim)\n        self.o = nn.Linear(dim, dim)\n        self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n        self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()\n\n    def forward(self, x, seq_lens, grid_sizes, freqs):\n        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n\n        # 
query, key, value function\n        def qkv_fn(x):\n            q = self.norm_q(self.q(x)).view(b, s, n, d)\n            k = self.norm_k(self.k(x)).view(b, s, n, d)\n            v = self.v(x).view(b, s, n, d)\n            return q, k, v\n\n        q, k, v = qkv_fn(x)\n\n        x = flash_attention(\n            q=rope_apply(q, grid_sizes, freqs),\n            k=rope_apply(k, grid_sizes, freqs),\n            v=v,\n            k_lens=seq_lens,\n            window_size=self.window_size)\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\nclass SwinSelfAttention(SelfAttention):\n\n    def forward(self, x, seq_lens, grid_sizes, freqs):\n        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n        assert b == 1, 'Only support batch_size 1'\n\n        # query, key, value function\n        def qkv_fn(x):\n            q = self.norm_q(self.q(x)).view(b, s, n, d)\n            k = self.norm_k(self.k(x)).view(b, s, n, d)\n            v = self.v(x).view(b, s, n, d)\n            return q, k, v\n\n        q, k, v = qkv_fn(x)\n\n        q = rope_apply(q, grid_sizes, freqs)\n        k = rope_apply(k, grid_sizes, freqs)\n        T, H, W = grid_sizes[0].tolist()\n\n        q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)\n        k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)\n        v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)\n\n        ref_q = q[-1:]\n        q = q[:-1]\n\n        ref_k = repeat(\n            k[-1:], \"1 s n d -> t s n d\", t=k.shape[0] - 1)  # t hw n d\n        k = k[:-1]\n        k = torch.cat([k[:1], k, k[-1:]])\n        k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1)  # (bt) (3hw) n d\n\n        ref_v = repeat(v[-1:], \"1 s n d -> t s n d\", t=v.shape[0] - 1)\n        v = v[:-1]\n        v = torch.cat([v[:1], v, v[-1:]])\n        v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)\n\n        # q: b (t h w) n d\n        # k: b (t h w) n 
d\n        out = flash_attention(\n            q=q,\n            k=k,\n            v=v,\n            # k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long),\n            window_size=self.window_size)\n        out = torch.cat([out, ref_v[:1]], axis=0)\n        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)\n        x = out\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\n#Fix the reference frame RoPE to 1,H,W.\n#Set the current frame RoPE to 1.\n#Set the previous frame RoPE to 0.\nclass CasualSelfAttention(SelfAttention):\n\n    def forward(self, x, seq_lens, grid_sizes, freqs):\n        shifting = 3\n        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim\n        assert b == 1, 'Only support batch_size 1'\n\n        # query, key, value function\n        def qkv_fn(x):\n            q = self.norm_q(self.q(x)).view(b, s, n, d)\n            k = self.norm_k(self.k(x)).view(b, s, n, d)\n            v = self.v(x).view(b, s, n, d)\n            return q, k, v\n\n        q, k, v = qkv_fn(x)\n\n        T, H, W = grid_sizes[0].tolist()\n\n        q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)\n        k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)\n        v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)\n\n        ref_q = q[-1:]\n        q = q[:-1]\n\n        grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)\n        start = [[shifting, 0, 0]] * q.shape[0]\n        q = rope_apply(q, grid_sizes, freqs, start=start)\n\n        ref_k = k[-1:]\n        grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)\n        # start = [[shifting, H, W]]\n\n        start = [[shifting + 10, 0, 0]]\n        ref_k = rope_apply(ref_k, grid_sizes, freqs, start)\n        ref_k = repeat(\n            ref_k, \"1 s n d -> t s n d\", t=k.shape[0] - 1)  # t hw n d\n\n        k = k[:-1]\n        k = 
torch.cat([*([k[:1]] * shifting), k])\n        cat_k = []\n        for i in range(shifting):\n            cat_k.append(k[i:i - shifting])\n        cat_k.append(k[shifting:])\n        k = torch.cat(cat_k, dim=1)  # (bt) (3hw) n d\n\n        grid_sizes = torch.tensor(\n            [[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)\n        k = rope_apply(k, grid_sizes, freqs)\n        k = torch.cat([k, ref_k], dim=1)\n\n        ref_v = repeat(v[-1:], \"1 s n d -> t s n d\", t=q.shape[0])  # t hw n d\n        v = v[:-1]\n        v = torch.cat([*([v[:1]] * shifting), v])\n        cat_v = []\n        for i in range(shifting):\n            cat_v.append(v[i:i - shifting])\n        cat_v.append(v[shifting:])\n        v = torch.cat(cat_v, dim=1)  # (bt) (3hw) n d\n        v = torch.cat([v, ref_v], dim=1)\n\n        # q: b (t h w) n d\n        # k: b (t h w) n d\n        outs = []\n        for i in range(q.shape[0]):\n            out = flash_attention(\n                q=q[i:i + 1],\n                k=k[i:i + 1],\n                v=v[i:i + 1],\n                window_size=self.window_size)\n            outs.append(out)\n        out = torch.cat(outs, dim=0)\n        out = torch.cat([out, ref_v[:1]], axis=0)\n        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)\n        x = out\n\n        # output\n        x = x.flatten(2)\n        x = self.o(x)\n        return x\n\n\nclass MotionerAttentionBlock(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 ffn_dim,\n                 num_heads,\n                 window_size=(-1, -1),\n                 qk_norm=True,\n                 cross_attn_norm=False,\n                 eps=1e-6,\n                 self_attn_block=\"SelfAttention\"):\n        super().__init__()\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        
self.eps = eps\n\n        # layers\n        self.norm1 = LayerNorm(dim, eps)\n        if self_attn_block == \"SelfAttention\":\n            self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm,\n                                           eps)\n        elif self_attn_block == \"SwinSelfAttention\":\n            self.self_attn = SwinSelfAttention(dim, num_heads, window_size,\n                                               qk_norm, eps)\n        elif self_attn_block == \"CasualSelfAttention\":\n            self.self_attn = CasualSelfAttention(dim, num_heads, window_size,\n                                                 qk_norm, eps)\n\n        self.norm2 = LayerNorm(dim, eps)\n        self.ffn = nn.Sequential(\n            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),\n            nn.Linear(ffn_dim, dim))\n\n    def forward(\n        self,\n        x,\n        seq_lens,\n        grid_sizes,\n        freqs,\n    ):\n        # self-attention\n        y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)\n        x = x + y\n        y = self.ffn(self.norm2(x).float())\n        x = x + y\n        return x\n\n\nclass Head(nn.Module):\n\n    def __init__(self, dim, out_dim, patch_size, eps=1e-6):\n        super().__init__()\n        self.dim = dim\n        self.out_dim = out_dim\n        self.patch_size = patch_size\n        self.eps = eps\n\n        # layers\n        out_dim = math.prod(patch_size) * out_dim\n        self.norm = LayerNorm(dim, eps)\n        self.head = nn.Linear(dim, out_dim)\n\n    def forward(self, x):\n        x = self.head(self.norm(x))\n        return x\n\n\nclass MotionerTransformers(nn.Module, PeftAdapterMixin):\n\n    def __init__(\n        self,\n        patch_size=(1, 2, 2),\n        in_dim=16,\n        dim=2048,\n        ffn_dim=8192,\n        freq_dim=256,\n        out_dim=16,\n        num_heads=16,\n        num_layers=32,\n        window_size=(-1, -1),\n        qk_norm=True,\n        
cross_attn_norm=False,\n        eps=1e-6,\n        self_attn_block=\"SelfAttention\",\n        motion_token_num=1024,\n        enable_tsm=False,\n        motion_stride=4,\n        expand_ratio=2,\n        trainable_token_pos_emb=False,\n    ):\n        super().__init__()\n        self.patch_size = patch_size\n        self.in_dim = in_dim\n        self.dim = dim\n        self.ffn_dim = ffn_dim\n        self.freq_dim = freq_dim\n        self.out_dim = out_dim\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.window_size = window_size\n        self.qk_norm = qk_norm\n        self.cross_attn_norm = cross_attn_norm\n        self.eps = eps\n\n        self.enable_tsm = enable_tsm\n        self.motion_stride = motion_stride\n        self.expand_ratio = expand_ratio\n        self.sample_c = self.patch_size[0]\n\n        # embeddings\n        self.patch_embedding = nn.Conv3d(\n            in_dim, dim, kernel_size=patch_size, stride=patch_size)\n\n        # blocks\n        self.blocks = nn.ModuleList([\n            MotionerAttentionBlock(\n                dim,\n                ffn_dim,\n                num_heads,\n                window_size,\n                qk_norm,\n                cross_attn_norm,\n                eps,\n                self_attn_block=self_attn_block) for _ in range(num_layers)\n        ])\n\n        # buffers (don't use register_buffer otherwise dtype will be changed in to())\n        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0\n        d = dim // num_heads\n        self.freqs = torch.cat([\n            rope_params(1024, d - 4 * (d // 6)),\n            rope_params(1024, 2 * (d // 6)),\n            rope_params(1024, 2 * (d // 6))\n        ],\n                               dim=1)\n\n        self.gradient_checkpointing = False\n\n        self.motion_side_len = int(math.sqrt(motion_token_num))\n        assert self.motion_side_len**2 == motion_token_num\n        self.token = nn.Parameter(\n            
torch.zeros(1, motion_token_num, dim).contiguous())\n\n        self.trainable_token_pos_emb = trainable_token_pos_emb\n        if trainable_token_pos_emb:\n            x = torch.zeros([1, motion_token_num, num_heads, d])\n            x[..., ::2] = 1\n\n            grid_sizes = [[\n                torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([1, self.motion_side_len,\n                              self.motion_side_len]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([1, self.motion_side_len,\n                              self.motion_side_len]).unsqueeze(0).repeat(1, 1),\n            ]]\n            token_freqs = rope_apply(x, grid_sizes, self.freqs)\n            token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)\n            token_freqs = token_freqs * 0.01\n            self.token_freqs = torch.nn.Parameter(token_freqs)\n\n    def after_patch_embedding(self, x):\n        return x\n\n    def forward(\n        self,\n        x,\n    ):\n        \"\"\"\n        x:              A list of videos each with shape [C, T, H, W].\n\n        Returns the motion tokens with shape [B, motion_token_num, dim].\n        \"\"\"\n        # params\n        motion_frames = x[0].shape[1]\n        device = self.patch_embedding.weight.device\n        freqs = self.freqs\n        if freqs.device != device:\n            freqs = freqs.to(device)\n\n        if self.trainable_token_pos_emb:\n            with amp.autocast(dtype=torch.float64):\n                token_freqs = self.token_freqs.to(torch.float64)\n                token_freqs = token_freqs / token_freqs.norm(\n                    dim=-1, keepdim=True)\n                freqs = [freqs, torch.view_as_complex(token_freqs)]\n\n        if self.enable_tsm:\n            sample_idx = [\n                sample_indices(\n                    u.shape[1],\n                    stride=self.motion_stride,\n                    expand_ratio=self.expand_ratio,\n 
                   c=self.sample_c) for u in x\n            ]\n            x = [\n                torch.flip(torch.flip(u, [1])[:, idx], [1])\n                for idx, u in zip(sample_idx, x)\n            ]\n\n        # embeddings\n        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]\n        x = self.after_patch_embedding(x)\n\n        seq_f, seq_h, seq_w = x[0].shape[-3:]\n        batch_size = len(x)\n        if not self.enable_tsm:\n            grid_sizes = torch.stack(\n                [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])\n            grid_sizes = [[\n                torch.zeros_like(grid_sizes), grid_sizes, grid_sizes\n            ]]\n            seq_f = 0\n        else:\n            grid_sizes = []\n            for idx in sample_idx[0][::-1][::self.sample_c]:\n                tsm_frame_grid_sizes = [[\n                    torch.tensor([idx, 0,\n                                  0]).unsqueeze(0).repeat(batch_size, 1),\n                    torch.tensor([idx + 1, seq_h,\n                                  seq_w]).unsqueeze(0).repeat(batch_size, 1),\n                    torch.tensor([1, seq_h,\n                                  seq_w]).unsqueeze(0).repeat(batch_size, 1),\n                ]]\n                grid_sizes += tsm_frame_grid_sizes\n            seq_f = sample_idx[0][-1] + 1\n\n        x = [u.flatten(2).transpose(1, 2) for u in x]\n        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)\n        x = torch.cat([u for u in x])\n\n        batch_size = len(x)\n\n        token_grid_sizes = [[\n            torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),\n            torch.tensor(\n                [seq_f + 1, self.motion_side_len,\n                 self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),\n            torch.tensor(\n                [1 if not self.trainable_token_pos_emb else -1, seq_h,\n                 seq_w]).unsqueeze(0).repeat(batch_size, 1),\n        ]  # The third row is the range that the RoPE embedding is meant to cover
\n                           ]\n\n        grid_sizes = grid_sizes + token_grid_sizes\n        token_unpatch_grid_sizes = torch.stack([\n            torch.tensor([1, 32, 32], dtype=torch.long)\n            for b in range(batch_size)\n        ])\n        token_len = self.token.shape[1]\n        token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()\n        seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],\n                                           dtype=torch.long)\n        x = torch.cat([x, token], dim=1)\n        # arguments\n        kwargs = dict(\n            seq_lens=seq_lens,\n            grid_sizes=grid_sizes,\n            freqs=freqs,\n        )\n\n        # grad ckpt args\n        def create_custom_forward(module, return_dict=None):\n\n            def custom_forward(*inputs, **kwargs):\n                if return_dict is not None:\n                    return module(*inputs, **kwargs, return_dict=return_dict)\n                else:\n                    return module(*inputs, **kwargs)\n\n            return custom_forward\n\n        ckpt_kwargs: Dict[str, Any] = ({\n            \"use_reentrant\": False\n        } if is_torch_version(\">=\", \"1.11.0\") else {})\n\n        for idx, block in enumerate(self.blocks):\n            if self.training and self.gradient_checkpointing:\n                x = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    x,\n                    **kwargs,\n                    **ckpt_kwargs,\n                )\n            else:\n                x = block(x, **kwargs)\n        # head\n        out = x[:, -token_len:]\n        return out\n\n    def unpatchify(self, x, grid_sizes):\n        c = self.out_dim\n        out = []\n        for u, v in zip(x, grid_sizes.tolist()):\n            u = u[:math.prod(v)].view(*v, *self.patch_size, c)\n            u = torch.einsum('fhwpqrc->cfphqwr', u)\n            u = u.reshape(c, *[i * j for i, j in zip(v, 
self.patch_size)])\n            out.append(u)\n        return out\n\n    def init_weights(self):\n        # basic init\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                nn.init.xavier_uniform_(m.weight)\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n\n        # init embeddings\n        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))\n\n\nclass FramePackMotioner(nn.Module):\n\n    def __init__(\n            self,\n            inner_dim=1024,\n            num_heads=16,  # Used to indicate the number of heads in the backbone network; unrelated to this module's design\n            zip_frame_buckets=[\n                1, 2, 16\n            ],  # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames\n            drop_mode=\"drop\",  # If not \"drop\", it will use \"padd\", meaning padding instead of deletion\n            *args,\n            **kwargs):\n        super().__init__(*args, **kwargs)\n        self.proj = nn.Conv3d(\n            16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))\n        self.proj_2x = nn.Conv3d(\n            16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))\n        self.proj_4x = nn.Conv3d(\n            16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))\n        self.zip_frame_buckets = torch.tensor(\n            zip_frame_buckets, dtype=torch.long)\n\n        self.inner_dim = inner_dim\n        self.num_heads = num_heads\n\n        assert (inner_dim %\n                num_heads) == 0 and (inner_dim // num_heads) % 2 == 0\n        d = inner_dim // num_heads\n        self.freqs = torch.cat([\n            rope_params(1024, d - 4 * (d // 6)),\n            rope_params(1024, 2 * (d // 6)),\n            rope_params(1024, 2 * (d // 6))\n        ],\n                               dim=1)\n        self.drop_mode = drop_mode\n\n    def forward(self, motion_latents, 
add_last_motion=2):\n        motion_frames = motion_latents[0].shape[1]\n        mot = []\n        mot_remb = []\n        for m in motion_latents:\n            lat_height, lat_width = m.shape[2], m.shape[3]\n            padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height,\n                                   lat_width).to(\n                                       device=m.device, dtype=m.dtype)\n            overlap_frame = min(padd_lat.shape[1], m.shape[1])\n            if overlap_frame > 0:\n                padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]\n\n            if add_last_motion < 2 and self.drop_mode != \"drop\":\n                zero_end_frame = self.zip_frame_buckets[:len(\n                    self.zip_frame_buckets) - add_last_motion - 1].sum()\n                padd_lat[:, -zero_end_frame:] = 0\n\n            padd_lat = padd_lat.unsqueeze(0)\n            clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum(\n            ):, :, :].split(\n                list(self.zip_frame_buckets)[::-1], dim=2)  # 16, 2, 1\n\n            # patchify\n            clean_latents_post = self.proj(clean_latents_post).flatten(\n                2).transpose(1, 2)\n            clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(\n                2).transpose(1, 2)\n            clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(\n                2).transpose(1, 2)\n\n            if add_last_motion < 2 and self.drop_mode == \"drop\":\n                # add_last_motion < 2 always holds in this branch, so the\n                # nearest-frame tokens are always dropped here\n                clean_latents_post = clean_latents_post[:, :0]\n                clean_latents_2x = clean_latents_2x[:, :\n                                                    0] if add_last_motion < 1 else 
clean_latents_2x\n\n            motion_lat = torch.cat(\n                [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)\n\n            # rope\n            start_time_id = -(self.zip_frame_buckets[:1].sum())\n            end_time_id = start_time_id + self.zip_frame_buckets[0]\n            grid_sizes = [] if add_last_motion < 2 and self.drop_mode == \"drop\" else \\\n                        [\n                            [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),\n                            torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),\n                            torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]\n                        ]\n\n            start_time_id = -(self.zip_frame_buckets[:2].sum())\n            end_time_id = start_time_id + self.zip_frame_buckets[1] // 2\n            grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == \"drop\" else \\\n            [\n                [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]\n            ]\n\n            start_time_id = -(self.zip_frame_buckets[:3].sum())\n            end_time_id = start_time_id + self.zip_frame_buckets[2] // 4\n            grid_sizes_4x = [[\n                torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([end_time_id, lat_height // 8,\n                              lat_width // 8]).unsqueeze(0).repeat(1, 1),\n                torch.tensor([\n                    self.zip_frame_buckets[2], lat_height // 2, lat_width // 2\n                ]).unsqueeze(0).repeat(1, 1),\n            ]]\n\n            grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x\n\n        
    motion_rope_emb = rope_precompute(\n                motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads,\n                                         self.inner_dim // self.num_heads),\n                grid_sizes,\n                self.freqs,\n                start=None)\n\n            mot.append(motion_lat)\n            mot_remb.append(motion_rope_emb)\n        return mot, mot_remb\n\n\ndef sample_indices(N, stride, expand_ratio, c):\n    indices = []\n    current_start = 0\n\n    while current_start < N:\n        bucket_width = int(stride * (expand_ratio**(len(indices) / stride)))\n\n        interval = int(bucket_width / stride * c)\n        current_end = min(N, current_start + bucket_width)\n        bucket_samples = []\n        for i in range(current_end - 1, current_start - 1, -interval):\n            for near in range(c):\n                bucket_samples.append(i - near)\n\n        indices += bucket_samples[::-1]\n        current_start += bucket_width\n\n    return indices\n\n\nif __name__ == '__main__':\n    device = \"cuda\"\n    # move the module to the same device as the inputs; otherwise the\n    # Conv3d weights stay on the CPU and the forward pass fails\n    model = FramePackMotioner(inner_dim=1024).to(device)\n    batch_size = 2\n    num_frame, height, width = (28, 32, 32)\n    single_input = torch.ones([16, num_frame, height, width], device=device)\n    for i in range(num_frame):\n        single_input[:, num_frame - 1 - i] *= i\n    x = [single_input] * batch_size\n    model.forward(x)\n"
  },
  {
    "path": "wan/modules/s2v/s2v_utils.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport numpy as np\nimport torch\n\n\ndef rope_precompute(x, grid_sizes, freqs, start=None):\n    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2\n\n    # split freqs\n    if type(freqs) is list:\n        trainable_freqs = freqs[1]\n        freqs = freqs[0]\n    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)\n\n    # loop over samples\n    output = torch.view_as_complex(x.detach().reshape(b, s, n, -1,\n                                                      2).to(torch.float64))\n    seq_bucket = [0]\n    if not type(grid_sizes) is list:\n        grid_sizes = [grid_sizes]\n    for g in grid_sizes:\n        if not type(g) is list:\n            g = [torch.zeros_like(g), g]\n        batch_size = g[0].shape[0]\n        for i in range(batch_size):\n            if start is None:\n                f_o, h_o, w_o = g[0][i]\n            else:\n                f_o, h_o, w_o = start[i]\n\n            f, h, w = g[1][i]\n            t_f, t_h, t_w = g[2][i]\n            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o\n            seq_len = int(seq_f * seq_h * seq_w)\n            if seq_len > 0:\n                if t_f > 0:\n                    factor_f, factor_h, factor_w = (t_f / seq_f).item(), (\n                        t_h / seq_h).item(), (t_w / seq_w).item()\n                    # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())\n                    if f_o >= 0:\n                        f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,\n                                            seq_f).astype(int).tolist()\n                    else:\n                        f_sam = np.linspace(-f_o.item(),\n                                            (-t_f - f_o).item() + 1,\n                                            seq_f).astype(int).tolist()\n                    h_sam = np.linspace(h_o.item(), (t_h + 
h_o).item() - 1,\n                                        seq_h).astype(int).tolist()\n                    w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,\n                                        seq_w).astype(int).tolist()\n\n                    assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0\n                    freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][\n                        f_sam].conj()\n                    freqs_0 = freqs_0.view(seq_f, 1, 1, -1)\n\n                    freqs_i = torch.cat([\n                        freqs_0.expand(seq_f, seq_h, seq_w, -1),\n                        freqs[1][h_sam].view(1, seq_h, 1, -1).expand(\n                            seq_f, seq_h, seq_w, -1),\n                        freqs[2][w_sam].view(1, 1, seq_w, -1).expand(\n                            seq_f, seq_h, seq_w, -1),\n                    ],\n                                        dim=-1).reshape(seq_len, 1, -1)\n                elif t_f < 0:\n                    freqs_i = trainable_freqs.unsqueeze(1)\n                # apply rotary embedding\n                output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i\n        seq_bucket.append(seq_bucket[-1] + seq_len)\n    return output\n"
  },
  {
    "path": "wan/modules/t5.py",
    "content": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .tokenizers import HuggingfaceTokenizer\n\n__all__ = [\n    'T5Model',\n    'T5Encoder',\n    'T5Decoder',\n    'T5EncoderModel',\n]\n\n\ndef fp16_clamp(x):\n    if x.dtype == torch.float16 and torch.isinf(x).any():\n        clamp = torch.finfo(x.dtype).max - 1000\n        x = torch.clamp(x, min=-clamp, max=clamp)\n    return x\n\n\ndef init_weights(m):\n    if isinstance(m, T5LayerNorm):\n        nn.init.ones_(m.weight)\n    elif isinstance(m, T5Model):\n        nn.init.normal_(m.token_embedding.weight, std=1.0)\n    elif isinstance(m, T5FeedForward):\n        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)\n        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)\n        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)\n    elif isinstance(m, T5Attention):\n        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)\n        nn.init.normal_(m.k.weight, std=m.dim**-0.5)\n        nn.init.normal_(m.v.weight, std=m.dim**-0.5)\n        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)\n    elif isinstance(m, T5RelativeEmbedding):\n        nn.init.normal_(\n            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)\n\n\nclass GELU(nn.Module):\n\n    def forward(self, x):\n        return 0.5 * x * (1.0 + torch.tanh(\n            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))\n\n\nclass T5LayerNorm(nn.Module):\n\n    def __init__(self, dim, eps=1e-6):\n        super(T5LayerNorm, self).__init__()\n        self.dim = dim\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +\n                            self.eps)\n        if self.weight.dtype 
in [torch.float16, torch.bfloat16]:\n            x = x.type_as(self.weight)\n        return self.weight * x\n\n\nclass T5Attention(nn.Module):\n\n    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):\n        assert dim_attn % num_heads == 0\n        super(T5Attention, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.num_heads = num_heads\n        self.head_dim = dim_attn // num_heads\n\n        # layers\n        self.q = nn.Linear(dim, dim_attn, bias=False)\n        self.k = nn.Linear(dim, dim_attn, bias=False)\n        self.v = nn.Linear(dim, dim_attn, bias=False)\n        self.o = nn.Linear(dim_attn, dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x, context=None, mask=None, pos_bias=None):\n        \"\"\"\n        x:          [B, L1, C].\n        context:    [B, L2, C] or None.\n        mask:       [B, L2] or [B, L1, L2] or None.\n        \"\"\"\n        # check inputs\n        context = x if context is None else context\n        b, n, c = x.size(0), self.num_heads, self.head_dim\n\n        # compute query, key, value\n        q = self.q(x).view(b, -1, n, c)\n        k = self.k(context).view(b, -1, n, c)\n        v = self.v(context).view(b, -1, n, c)\n\n        # attention bias\n        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))\n        if pos_bias is not None:\n            attn_bias += pos_bias\n        if mask is not None:\n            assert mask.ndim in [2, 3]\n            mask = mask.view(b, 1, 1,\n                             -1) if mask.ndim == 2 else mask.unsqueeze(1)\n            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)\n\n        # compute attention (T5 does not use scaling)\n        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias\n        attn = F.softmax(attn.float(), dim=-1).type_as(attn)\n        x = torch.einsum('bnij,bjnc->binc', attn, v)\n\n        # output\n        x = x.reshape(b, -1, n * c)\n        x = self.o(x)\n     
   x = self.dropout(x)\n        return x\n\n\nclass T5FeedForward(nn.Module):\n\n    def __init__(self, dim, dim_ffn, dropout=0.1):\n        super(T5FeedForward, self).__init__()\n        self.dim = dim\n        self.dim_ffn = dim_ffn\n\n        # layers\n        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())\n        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)\n        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = self.fc1(x) * self.gate(x)\n        x = self.dropout(x)\n        x = self.fc2(x)\n        x = self.dropout(x)\n        return x\n\n\nclass T5SelfAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 dim_attn,\n                 dim_ffn,\n                 num_heads,\n                 num_buckets,\n                 shared_pos=True,\n                 dropout=0.1):\n        super(T5SelfAttention, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.num_buckets = num_buckets\n        self.shared_pos = shared_pos\n\n        # layers\n        self.norm1 = T5LayerNorm(dim)\n        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)\n        self.norm2 = T5LayerNorm(dim)\n        self.ffn = T5FeedForward(dim, dim_ffn, dropout)\n        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(\n            num_buckets, num_heads, bidirectional=True)\n\n    def forward(self, x, mask=None, pos_bias=None):\n        e = pos_bias if self.shared_pos else self.pos_embedding(\n            x.size(1), x.size(1))\n        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))\n        x = fp16_clamp(x + self.ffn(self.norm2(x)))\n        return x\n\n\nclass T5CrossAttention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 dim_attn,\n                 dim_ffn,\n            
     num_heads,\n                 num_buckets,\n                 shared_pos=True,\n                 dropout=0.1):\n        super(T5CrossAttention, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.num_buckets = num_buckets\n        self.shared_pos = shared_pos\n\n        # layers\n        self.norm1 = T5LayerNorm(dim)\n        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)\n        self.norm2 = T5LayerNorm(dim)\n        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)\n        self.norm3 = T5LayerNorm(dim)\n        self.ffn = T5FeedForward(dim, dim_ffn, dropout)\n        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(\n            num_buckets, num_heads, bidirectional=False)\n\n    def forward(self,\n                x,\n                mask=None,\n                encoder_states=None,\n                encoder_mask=None,\n                pos_bias=None):\n        e = pos_bias if self.shared_pos else self.pos_embedding(\n            x.size(1), x.size(1))\n        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))\n        x = fp16_clamp(x + self.cross_attn(\n            self.norm2(x), context=encoder_states, mask=encoder_mask))\n        x = fp16_clamp(x + self.ffn(self.norm3(x)))\n        return x\n\n\nclass T5RelativeEmbedding(nn.Module):\n\n    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):\n        super(T5RelativeEmbedding, self).__init__()\n        self.num_buckets = num_buckets\n        self.num_heads = num_heads\n        self.bidirectional = bidirectional\n        self.max_dist = max_dist\n\n        # layers\n        self.embedding = nn.Embedding(num_buckets, num_heads)\n\n    def forward(self, lq, lk):\n        device = self.embedding.weight.device\n        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \\\n        #     
torch.arange(lq).unsqueeze(1).to(device)\n        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \\\n            torch.arange(lq, device=device).unsqueeze(1)\n        rel_pos = self._relative_position_bucket(rel_pos)\n        rel_pos_embeds = self.embedding(rel_pos)\n        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(\n            0)  # [1, N, Lq, Lk]\n        return rel_pos_embeds.contiguous()\n\n    def _relative_position_bucket(self, rel_pos):\n        # preprocess\n        if self.bidirectional:\n            num_buckets = self.num_buckets // 2\n            rel_buckets = (rel_pos > 0).long() * num_buckets\n            rel_pos = torch.abs(rel_pos)\n        else:\n            num_buckets = self.num_buckets\n            rel_buckets = 0\n            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))\n\n        # embeddings for small and large positions\n        max_exact = num_buckets // 2\n        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /\n                                     math.log(self.max_dist / max_exact) *\n                                     (num_buckets - max_exact)).long()\n        rel_pos_large = torch.min(\n            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))\n        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)\n        return rel_buckets\n\n\nclass T5Encoder(nn.Module):\n\n    def __init__(self,\n                 vocab,\n                 dim,\n                 dim_attn,\n                 dim_ffn,\n                 num_heads,\n                 num_layers,\n                 num_buckets,\n                 shared_pos=True,\n                 dropout=0.1):\n        super(T5Encoder, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.num_buckets = num_buckets\n        self.shared_pos = shared_pos\n\n        # 
layers\n        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \\\n            else nn.Embedding(vocab, dim)\n        self.pos_embedding = T5RelativeEmbedding(\n            num_buckets, num_heads, bidirectional=True) if shared_pos else None\n        self.dropout = nn.Dropout(dropout)\n        self.blocks = nn.ModuleList([\n            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,\n                            shared_pos, dropout) for _ in range(num_layers)\n        ])\n        self.norm = T5LayerNorm(dim)\n\n        # initialize weights\n        self.apply(init_weights)\n\n    def forward(self, ids, mask=None):\n        x = self.token_embedding(ids)\n        x = self.dropout(x)\n        e = self.pos_embedding(x.size(1),\n                               x.size(1)) if self.shared_pos else None\n        for block in self.blocks:\n            x = block(x, mask, pos_bias=e)\n        x = self.norm(x)\n        x = self.dropout(x)\n        return x\n\n\nclass T5Decoder(nn.Module):\n\n    def __init__(self,\n                 vocab,\n                 dim,\n                 dim_attn,\n                 dim_ffn,\n                 num_heads,\n                 num_layers,\n                 num_buckets,\n                 shared_pos=True,\n                 dropout=0.1):\n        super(T5Decoder, self).__init__()\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.num_buckets = num_buckets\n        self.shared_pos = shared_pos\n\n        # layers\n        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \\\n            else nn.Embedding(vocab, dim)\n        self.pos_embedding = T5RelativeEmbedding(\n            num_buckets, num_heads, bidirectional=False) if shared_pos else None\n        self.dropout = nn.Dropout(dropout)\n        self.blocks = nn.ModuleList([\n            T5CrossAttention(dim, dim_attn, 
dim_ffn, num_heads, num_buckets,\n                             shared_pos, dropout) for _ in range(num_layers)\n        ])\n        self.norm = T5LayerNorm(dim)\n\n        # initialize weights\n        self.apply(init_weights)\n\n    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):\n        b, s = ids.size()\n\n        # causal mask\n        if mask is None:\n            mask = torch.tril(torch.ones(1, s, s).to(ids.device))\n        elif mask.ndim == 2:\n            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))\n\n        # layers\n        x = self.token_embedding(ids)\n        x = self.dropout(x)\n        e = self.pos_embedding(x.size(1),\n                               x.size(1)) if self.shared_pos else None\n        for block in self.blocks:\n            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)\n        x = self.norm(x)\n        x = self.dropout(x)\n        return x\n\n\nclass T5Model(nn.Module):\n\n    def __init__(self,\n                 vocab_size,\n                 dim,\n                 dim_attn,\n                 dim_ffn,\n                 num_heads,\n                 encoder_layers,\n                 decoder_layers,\n                 num_buckets,\n                 shared_pos=True,\n                 dropout=0.1):\n        super(T5Model, self).__init__()\n        self.vocab_size = vocab_size\n        self.dim = dim\n        self.dim_attn = dim_attn\n        self.dim_ffn = dim_ffn\n        self.num_heads = num_heads\n        self.encoder_layers = encoder_layers\n        self.decoder_layers = decoder_layers\n        self.num_buckets = num_buckets\n\n        # layers\n        self.token_embedding = nn.Embedding(vocab_size, dim)\n        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,\n                                 num_heads, encoder_layers, num_buckets,\n                                 shared_pos, dropout)\n        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, 
dim_ffn,\n                                 num_heads, decoder_layers, num_buckets,\n                                 shared_pos, dropout)\n        self.head = nn.Linear(dim, vocab_size, bias=False)\n\n        # initialize weights\n        self.apply(init_weights)\n\n    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):\n        x = self.encoder(encoder_ids, encoder_mask)\n        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)\n        x = self.head(x)\n        return x\n\n\ndef _t5(name,\n        encoder_only=False,\n        decoder_only=False,\n        return_tokenizer=False,\n        tokenizer_kwargs={},\n        dtype=torch.float32,\n        device='cpu',\n        **kwargs):\n    # sanity check\n    assert not (encoder_only and decoder_only)\n\n    # params\n    if encoder_only:\n        model_cls = T5Encoder\n        kwargs['vocab'] = kwargs.pop('vocab_size')\n        kwargs['num_layers'] = kwargs.pop('encoder_layers')\n        _ = kwargs.pop('decoder_layers')\n    elif decoder_only:\n        model_cls = T5Decoder\n        kwargs['vocab'] = kwargs.pop('vocab_size')\n        kwargs['num_layers'] = kwargs.pop('decoder_layers')\n        _ = kwargs.pop('encoder_layers')\n    else:\n        model_cls = T5Model\n\n    # init model\n    with torch.device(device):\n        model = model_cls(**kwargs)\n\n    # set device\n    model = model.to(dtype=dtype, device=device)\n\n    # init tokenizer\n    if return_tokenizer:\n        from .tokenizers import HuggingfaceTokenizer\n        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)\n        return model, tokenizer\n    else:\n        return model\n\n\ndef umt5_xxl(**kwargs):\n    cfg = dict(\n        vocab_size=256384,\n        dim=4096,\n        dim_attn=4096,\n        dim_ffn=10240,\n        num_heads=64,\n        encoder_layers=24,\n        decoder_layers=24,\n        num_buckets=32,\n        shared_pos=False,\n        dropout=0.1)\n    cfg.update(**kwargs)\n    
return _t5('umt5-xxl', **cfg)\n\n\nclass T5EncoderModel:\n\n    def __init__(\n        self,\n        text_len,\n        dtype=torch.bfloat16,\n        device=torch.cuda.current_device(),\n        checkpoint_path=None,\n        tokenizer_path=None,\n        shard_fn=None,\n    ):\n        self.text_len = text_len\n        self.dtype = dtype\n        self.device = device\n        self.checkpoint_path = checkpoint_path\n        self.tokenizer_path = tokenizer_path\n\n        # init model\n        model = umt5_xxl(\n            encoder_only=True,\n            return_tokenizer=False,\n            dtype=dtype,\n            device=device).eval().requires_grad_(False)\n        logging.info(f'loading {checkpoint_path}')\n        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n        self.model = model\n        if shard_fn is not None:\n            self.model = shard_fn(self.model, sync_module_states=False)\n        else:\n            self.model.to(self.device)\n        # init tokenizer\n        self.tokenizer = HuggingfaceTokenizer(\n            name=tokenizer_path, seq_len=text_len, clean='whitespace')\n\n    def __call__(self, texts, device):\n        ids, mask = self.tokenizer(\n            texts, return_mask=True, add_special_tokens=True)\n        ids = ids.to(device)\n        mask = mask.to(device)\n        seq_lens = mask.gt(0).sum(dim=1).long()\n        context = self.model(ids, mask)\n        return [u[:v] for u, v in zip(context, seq_lens)]\n"
  },
  {
    "path": "wan/modules/tokenizers.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport regex as re\nfrom transformers import AutoTokenizer\n\n__all__ = ['HuggingfaceTokenizer']\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r'\\s+', ' ', text)\n    text = text.strip()\n    return text\n\n\ndef canonicalize(text, keep_punctuation_exact_string=None):\n    text = text.replace('_', ' ')\n    if keep_punctuation_exact_string:\n        text = keep_punctuation_exact_string.join(\n            part.translate(str.maketrans('', '', string.punctuation))\n            for part in text.split(keep_punctuation_exact_string))\n    else:\n        text = text.translate(str.maketrans('', '', string.punctuation))\n    text = text.lower()\n    text = re.sub(r'\\s+', ' ', text)\n    return text.strip()\n\n\nclass HuggingfaceTokenizer:\n\n    def __init__(self, name, seq_len=None, clean=None, **kwargs):\n        assert clean in (None, 'whitespace', 'lower', 'canonicalize')\n        self.name = name\n        self.seq_len = seq_len\n        self.clean = clean\n\n        # init tokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)\n        self.vocab_size = self.tokenizer.vocab_size\n\n    def __call__(self, sequence, **kwargs):\n        return_mask = kwargs.pop('return_mask', False)\n\n        # arguments\n        _kwargs = {'return_tensors': 'pt'}\n        if self.seq_len is not None:\n            _kwargs.update({\n                'padding': 'max_length',\n                'truncation': True,\n                'max_length': self.seq_len\n            })\n        _kwargs.update(**kwargs)\n\n        # tokenization\n        if isinstance(sequence, str):\n            sequence = [sequence]\n        if self.clean:\n            sequence = [self._clean(u) for u in sequence]\n        ids 
= self.tokenizer(sequence, **_kwargs)\n\n        # output\n        if return_mask:\n            return ids.input_ids, ids.attention_mask\n        else:\n            return ids.input_ids\n\n    def _clean(self, text):\n        if self.clean == 'whitespace':\n            text = whitespace_clean(basic_clean(text))\n        elif self.clean == 'lower':\n            text = whitespace_clean(basic_clean(text)).lower()\n        elif self.clean == 'canonicalize':\n            text = canonicalize(basic_clean(text))\n        return text\n"
  },
  {
    "path": "wan/modules/vae2_1.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n__all__ = [\n    'Wan2_1_VAE',\n]\n\nCACHE_T = 2\n\n\nclass CausalConv3d(nn.Conv3d):\n    \"\"\"\n    Causal 3d convolusion.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._padding = (self.padding[2], self.padding[2], self.padding[1],\n                         self.padding[1], 2 * self.padding[0], 0)\n        self.padding = (0, 0, 0)\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n        x = F.pad(x, padding)\n\n        return super().forward(x)\n\n\nclass RMS_norm(nn.Module):\n\n    def __init__(self, dim, channel_first=True, images=True, bias=False):\n        super().__init__()\n        broadcastable_dims = (1, 1, 1) if not images else (1, 1)\n        shape = (dim, *broadcastable_dims) if channel_first else (dim,)\n\n        self.channel_first = channel_first\n        self.scale = dim**0.5\n        self.gamma = nn.Parameter(torch.ones(shape))\n        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.\n\n    def forward(self, x):\n        return F.normalize(\n            x, dim=(1 if self.channel_first else\n                    -1)) * self.scale * self.gamma + self.bias\n\n\nclass Upsample(nn.Upsample):\n\n    def forward(self, x):\n        \"\"\"\n        Fix bfloat16 support for nearest neighbor interpolation.\n        \"\"\"\n        return super().forward(x.float()).type_as(x)\n\n\nclass Resample(nn.Module):\n\n    def __init__(self, dim, mode):\n        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',\n  
                      'downsample3d')\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # layers\n        if mode == 'upsample2d':\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),\n                nn.Conv2d(dim, dim // 2, 3, padding=1))\n        elif mode == 'upsample3d':\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),\n                nn.Conv2d(dim, dim // 2, 3, padding=1))\n            self.time_conv = CausalConv3d(\n                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))\n\n        elif mode == 'downsample2d':\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)),\n                nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n        elif mode == 'downsample3d':\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)),\n                nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n            self.time_conv = CausalConv3d(\n                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))\n\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        b, c, t, h, w = x.size()\n        if self.mode == 'upsample3d':\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = 'Rep'\n                    feat_idx[0] += 1\n                else:\n\n                    cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                    if cache_x.shape[2] < 2 and feat_cache[\n                            idx] is not None and feat_cache[idx] != 'Rep':\n                        # cache last frame of last two chunk\n                        cache_x = torch.cat([\n                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                                
cache_x.device), cache_x\n                        ],\n                                            dim=2)\n                    if cache_x.shape[2] < 2 and feat_cache[\n                            idx] is not None and feat_cache[idx] == 'Rep':\n                        cache_x = torch.cat([\n                            torch.zeros_like(cache_x).to(cache_x.device),\n                            cache_x\n                        ],\n                                            dim=2)\n                    if feat_cache[idx] == 'Rep':\n                        x = self.time_conv(x)\n                    else:\n                        x = self.time_conv(x, feat_cache[idx])\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n\n                    x = x.reshape(b, 2, c, t, h, w)\n                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),\n                                    3)\n                    x = x.reshape(b, c, t * 2, h, w)\n        t = x.shape[2]\n        x = rearrange(x, 'b c t h w -> (b t) c h w')\n        x = self.resample(x)\n        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)\n\n        if self.mode == 'downsample3d':\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = x.clone()\n                    feat_idx[0] += 1\n                else:\n\n                    cache_x = x[:, :, -1:, :, :].clone()\n                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':\n                    #     # cache last frame of last two chunk\n                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)\n\n                    x = self.time_conv(\n                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n        
return x\n\n    def init_weight(self, conv):\n        conv_weight = conv.weight\n        nn.init.zeros_(conv_weight)\n        c1, c2, t, h, w = conv_weight.size()\n        one_matrix = torch.eye(c1, c2)\n        init_matrix = one_matrix\n        nn.init.zeros_(conv_weight)\n        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5\n        conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5\n        conv.weight.data.copy_(conv_weight)\n        nn.init.zeros_(conv.bias.data)\n\n    def init_weight2(self, conv):\n        conv_weight = conv.weight.data\n        nn.init.zeros_(conv_weight)\n        c1, c2, t, h, w = conv_weight.size()\n        init_matrix = torch.eye(c1 // 2, c2)\n        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)\n        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix\n        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix\n        conv.weight.data.copy_(conv_weight)\n        nn.init.zeros_(conv.bias.data)\n\n\nclass ResidualBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, dropout=0.0):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # layers\n        self.residual = nn.Sequential(\n            RMS_norm(in_dim, images=False), nn.SiLU(),\n            CausalConv3d(in_dim, out_dim, 3, padding=1),\n            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),\n            CausalConv3d(out_dim, out_dim, 3, padding=1))\n        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \\\n            if in_dim != out_dim else nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        h = self.shortcut(x)\n        for layer in self.residual:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    
# cache last frame of last two chunk\n                    cache_x = torch.cat([\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device), cache_x\n                    ],\n                                        dim=2)\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    Causal self-attention with a single head.\n    \"\"\"\n\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = RMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n\n        # zero out the last layer params\n        nn.init.zeros_(self.proj.weight)\n\n    def forward(self, x):\n        identity = x\n        b, c, t, h, w = x.size()\n        x = rearrange(x, 'b c t h w -> (b t) c h w')\n        x = self.norm(x)\n        # compute query, key, value\n        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,\n                                         -1).permute(0, 1, 3,\n                                                     2).contiguous().chunk(\n                                                         3, dim=-1)\n\n        # apply attention\n        x = F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n        )\n        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)\n\n        # output\n        x = self.proj(x)\n        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)\n        return x + identity\n\n\nclass Encoder3d(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_downsample=[True, True, False],\n          
       dropout=0.0):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n\n        # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        # init block\n        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)\n\n        # downsample blocks\n        downsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            for _ in range(num_res_blocks):\n                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n                if scale in attn_scales:\n                    downsamples.append(AttentionBlock(out_dim))\n                in_dim = out_dim\n\n            # downsample block\n            if i != len(dim_mult) - 1:\n                mode = 'downsample3d' if temperal_downsample[\n                    i] else 'downsample2d'\n                downsamples.append(Resample(out_dim, mode=mode))\n                scale /= 2.0\n        self.downsamples = nn.Sequential(*downsamples)\n\n        # middle blocks\n        self.middle = nn.Sequential(\n            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),\n            ResidualBlock(out_dim, out_dim, dropout))\n\n        # output blocks\n        self.head = nn.Sequential(\n            RMS_norm(out_dim, images=False), nn.SiLU(),\n            CausalConv3d(out_dim, z_dim, 3, padding=1))\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([\n                    feat_cache[idx][:, :, -1, 
:, :].unsqueeze(2).to(\n                        cache_x.device), cache_x\n                ],\n                                    dim=2)\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        ## downsamples\n        for layer in self.downsamples:\n            if feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## middle\n        for layer in self.middle:\n            if isinstance(layer, ResidualBlock) and feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    # cache last frame of last two chunk\n                    cache_x = torch.cat([\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device), cache_x\n                    ],\n                                        dim=2)\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x\n\n\nclass Decoder3d(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_upsample=[False, True, True],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = 
num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_upsample = temperal_upsample\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n        scale = 1.0 / 2**(len(dim_mult) - 2)\n\n        # init block\n        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.middle = nn.Sequential(\n            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),\n            ResidualBlock(dims[0], dims[0], dropout))\n\n        # upsample blocks\n        upsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            # residual (+attention) blocks\n            if i == 1 or i == 2 or i == 3:\n                in_dim = in_dim // 2\n            for _ in range(num_res_blocks + 1):\n                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n                if scale in attn_scales:\n                    upsamples.append(AttentionBlock(out_dim))\n                in_dim = out_dim\n\n            # upsample block\n            if i != len(dim_mult) - 1:\n                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'\n                upsamples.append(Resample(out_dim, mode=mode))\n                scale *= 2.0\n        self.upsamples = nn.Sequential(*upsamples)\n\n        # output blocks\n        self.head = nn.Sequential(\n            RMS_norm(out_dim, images=False), nn.SiLU(),\n            CausalConv3d(out_dim, 3, 3, padding=1))\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        ## conv1\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                # cache last frame of last two chunk\n                cache_x = torch.cat([\n                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                        cache_x.device), 
cache_x\n                ],\n                                    dim=2)\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        ## middle\n        for layer in self.middle:\n            if isinstance(layer, ResidualBlock) and feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## upsamples\n        for layer in self.upsamples:\n            if feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    # cache last frame of last two chunk\n                    cache_x = torch.cat([\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device), cache_x\n                    ],\n                                        dim=2)\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x\n\n\ndef count_conv3d(model):\n    count = 0\n    for m in model.modules():\n        if isinstance(m, CausalConv3d):\n            count += 1\n    return count\n\n\nclass WanVAE_(nn.Module):\n\n    def __init__(self,\n                 dim=128,\n                 z_dim=4,\n                 dim_mult=[1, 2, 4, 4],\n                 num_res_blocks=2,\n                 attn_scales=[],\n                 temperal_downsample=[True, True, False],\n                 dropout=0.0):\n        super().__init__()\n        self.dim = dim\n        
self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n        self.temperal_upsample = temperal_downsample[::-1]\n\n        # modules\n        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,\n                                 attn_scales, self.temperal_downsample, dropout)\n        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)\n        self.conv2 = CausalConv3d(z_dim, z_dim, 1)\n        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,\n                                 attn_scales, self.temperal_upsample, dropout)\n\n    def forward(self, x, scale=[0, 1]):\n        # encode() requires `scale` and returns only the normalized mean,\n        # so reconstruct from mu (same as WanVAE_.forward in vae2_2.py)\n        mu = self.encode(x, scale)\n        x_recon = self.decode(mu, scale)\n        return x_recon, mu\n\n    def encode(self, x, scale):\n        self.clear_cache()\n        ## cache\n        t = x.shape[2]\n        iter_ = 1 + (t - 1) // 4\n        ## split the encode input x along time into chunks of 1, 4, 4, 4, ... frames\n        for i in range(iter_):\n            self._enc_conv_idx = [0]\n            if i == 0:\n                out = self.encoder(\n                    x[:, :, :1, :, :],\n                    feat_cache=self._enc_feat_map,\n                    feat_idx=self._enc_conv_idx)\n            else:\n                out_ = self.encoder(\n                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],\n                    feat_cache=self._enc_feat_map,\n                    feat_idx=self._enc_conv_idx)\n                out = torch.cat([out, out_], 2)\n        mu, log_var = self.conv1(out).chunk(2, dim=1)\n        if isinstance(scale[0], torch.Tensor):\n            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            mu = (mu - scale[0]) * scale[1]\n        self.clear_cache()\n        return mu\n\n    def decode(self, z, scale):\n        
self.clear_cache()\n        # z: [b,c,t,h,w]\n        if isinstance(scale[0], torch.Tensor):\n            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            z = z / scale[1] + scale[0]\n        iter_ = z.shape[2]\n        x = self.conv2(z)\n        for i in range(iter_):\n            self._conv_idx = [0]\n            if i == 0:\n                out = self.decoder(\n                    x[:, :, i:i + 1, :, :],\n                    feat_cache=self._feat_map,\n                    feat_idx=self._conv_idx)\n            else:\n                out_ = self.decoder(\n                    x[:, :, i:i + 1, :, :],\n                    feat_cache=self._feat_map,\n                    feat_idx=self._conv_idx)\n                out = torch.cat([out, out_], 2)\n        self.clear_cache()\n        return out\n\n    def reparameterize(self, mu, log_var):\n        std = torch.exp(0.5 * log_var)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def sample(self, imgs, deterministic=False):\n        mu, log_var = self.encode(imgs)\n        if deterministic:\n            return mu\n        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))\n        return mu + std * torch.randn_like(std)\n\n    def clear_cache(self):\n        self._conv_num = count_conv3d(self.decoder)\n        self._conv_idx = [0]\n        self._feat_map = [None] * self._conv_num\n        # cache encode\n        self._enc_conv_num = count_conv3d(self.encoder)\n        self._enc_conv_idx = [0]\n        self._enc_feat_map = [None] * self._enc_conv_num\n\n\ndef _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):\n    \"\"\"\n    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.\n    \"\"\"\n    # params\n    cfg = dict(\n        dim=96,\n        z_dim=z_dim,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_downsample=[False, True, True],\n      
  dropout=0.0)\n    cfg.update(**kwargs)\n\n    # init model\n    with torch.device('meta'):\n        model = WanVAE_(**cfg)\n\n    # load checkpoint\n    logging.info(f'loading {pretrained_path}')\n    model.load_state_dict(\n        torch.load(pretrained_path, map_location=device), assign=True)\n\n    return model\n\n\nclass Wan2_1_VAE:\n\n    def __init__(self,\n                 z_dim=16,\n                 vae_pth='cache/vae_step_411000.pth',\n                 dtype=torch.float,\n                 device=\"cuda\"):\n        self.dtype = dtype\n        self.device = device\n\n        mean = [\n            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,\n            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921\n        ]\n        std = [\n            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,\n            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160\n        ]\n        self.mean = torch.tensor(mean, dtype=dtype, device=device)\n        self.std = torch.tensor(std, dtype=dtype, device=device)\n        self.scale = [self.mean, 1.0 / self.std]\n\n        # init model\n        self.model = _video_vae(\n            pretrained_path=vae_pth,\n            z_dim=z_dim,\n        ).eval().requires_grad_(False).to(device)\n\n    def encode(self, videos):\n        \"\"\"\n        videos: A list of videos each with shape [C, T, H, W].\n        \"\"\"\n        with amp.autocast(dtype=self.dtype):\n            return [\n                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)\n                for u in videos\n            ]\n\n    def decode(self, zs):\n        with amp.autocast(dtype=self.dtype):\n            return [\n                self.model.decode(u.unsqueeze(0),\n                                  self.scale).float().clamp_(-1, 1).squeeze(0)\n                for u in zs\n            ]\n"
  },
  {
    "path": "wan/modules/vae2_2.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n__all__ = [\n    \"Wan2_2_VAE\",\n]\n\nCACHE_T = 2\n\n\nclass CausalConv3d(nn.Conv3d):\n    \"\"\"\n    Causal 3d convolusion.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._padding = (\n            self.padding[2],\n            self.padding[2],\n            self.padding[1],\n            self.padding[1],\n            2 * self.padding[0],\n            0,\n        )\n        self.padding = (0, 0, 0)\n\n    def forward(self, x, cache_x=None):\n        padding = list(self._padding)\n        if cache_x is not None and self._padding[4] > 0:\n            cache_x = cache_x.to(x.device)\n            x = torch.cat([cache_x, x], dim=2)\n            padding[4] -= cache_x.shape[2]\n        x = F.pad(x, padding)\n\n        return super().forward(x)\n\n\nclass RMS_norm(nn.Module):\n\n    def __init__(self, dim, channel_first=True, images=True, bias=False):\n        super().__init__()\n        broadcastable_dims = (1, 1, 1) if not images else (1, 1)\n        shape = (dim, *broadcastable_dims) if channel_first else (dim,)\n\n        self.channel_first = channel_first\n        self.scale = dim**0.5\n        self.gamma = nn.Parameter(torch.ones(shape))\n        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0\n\n    def forward(self, x):\n        return (F.normalize(x, dim=(1 if self.channel_first else -1)) *\n                self.scale * self.gamma + self.bias)\n\n\nclass Upsample(nn.Upsample):\n\n    def forward(self, x):\n        \"\"\"\n        Fix bfloat16 support for nearest neighbor interpolation.\n        \"\"\"\n        return super().forward(x.float()).type_as(x)\n\n\nclass Resample(nn.Module):\n\n    def __init__(self, dim, mode):\n        assert mode in (\n  
          \"none\",\n            \"upsample2d\",\n            \"upsample3d\",\n            \"downsample2d\",\n            \"downsample3d\",\n        )\n        super().__init__()\n        self.dim = dim\n        self.mode = mode\n\n        # layers\n        if mode == \"upsample2d\":\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, dim, 3, padding=1),\n            )\n        elif mode == \"upsample3d\":\n            self.resample = nn.Sequential(\n                Upsample(scale_factor=(2.0, 2.0), mode=\"nearest-exact\"),\n                nn.Conv2d(dim, dim, 3, padding=1),\n                # nn.Conv2d(dim, dim//2, 3, padding=1)\n            )\n            self.time_conv = CausalConv3d(\n                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))\n        elif mode == \"downsample2d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)),\n                nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n        elif mode == \"downsample3d\":\n            self.resample = nn.Sequential(\n                nn.ZeroPad2d((0, 1, 0, 1)),\n                nn.Conv2d(dim, dim, 3, stride=(2, 2)))\n            self.time_conv = CausalConv3d(\n                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))\n        else:\n            self.resample = nn.Identity()\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        b, c, t, h, w = x.size()\n        if self.mode == \"upsample3d\":\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = \"Rep\"\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                    if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and\n                            feat_cache[idx] != \"Rep\"):\n                        # 
cache last frame of last two chunk\n                        cache_x = torch.cat(\n                            [\n                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                                    cache_x.device),\n                                cache_x,\n                            ],\n                            dim=2,\n                        )\n                    if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and\n                            feat_cache[idx] == \"Rep\"):\n                        cache_x = torch.cat(\n                            [\n                                torch.zeros_like(cache_x).to(cache_x.device),\n                                cache_x\n                            ],\n                            dim=2,\n                        )\n                    if feat_cache[idx] == \"Rep\":\n                        x = self.time_conv(x)\n                    else:\n                        x = self.time_conv(x, feat_cache[idx])\n                    feat_cache[idx] = cache_x\n                    feat_idx[0] += 1\n                    x = x.reshape(b, 2, c, t, h, w)\n                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),\n                                    3)\n                    x = x.reshape(b, c, t * 2, h, w)\n        t = x.shape[2]\n        x = rearrange(x, \"b c t h w -> (b t) c h w\")\n        x = self.resample(x)\n        x = rearrange(x, \"(b t) c h w -> b c t h w\", t=t)\n\n        if self.mode == \"downsample3d\":\n            if feat_cache is not None:\n                idx = feat_idx[0]\n                if feat_cache[idx] is None:\n                    feat_cache[idx] = x.clone()\n                    feat_idx[0] += 1\n                else:\n                    cache_x = x[:, :, -1:, :, :].clone()\n                    x = self.time_conv(\n                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))\n                    feat_cache[idx] = cache_x\n         
           feat_idx[0] += 1\n        return x\n\n    def init_weight(self, conv):\n        conv_weight = conv.weight.detach().clone()\n        nn.init.zeros_(conv_weight)\n        c1, c2, t, h, w = conv_weight.size()\n        one_matrix = torch.eye(c1, c2)\n        init_matrix = one_matrix\n        nn.init.zeros_(conv_weight)\n        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5\n        conv.weight = nn.Parameter(conv_weight)\n        nn.init.zeros_(conv.bias.data)\n\n    def init_weight2(self, conv):\n        conv_weight = conv.weight.data.detach().clone()\n        nn.init.zeros_(conv_weight)\n        c1, c2, t, h, w = conv_weight.size()\n        init_matrix = torch.eye(c1 // 2, c2)\n        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix\n        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix\n        conv.weight = nn.Parameter(conv_weight)\n        nn.init.zeros_(conv.bias.data)\n\n\nclass ResidualBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, dropout=0.0):\n        super().__init__()\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # layers\n        self.residual = nn.Sequential(\n            RMS_norm(in_dim, images=False),\n            nn.SiLU(),\n            CausalConv3d(in_dim, out_dim, 3, padding=1),\n            RMS_norm(out_dim, images=False),\n            nn.SiLU(),\n            nn.Dropout(dropout),\n            CausalConv3d(out_dim, out_dim, 3, padding=1),\n        )\n        self.shortcut = (\n            CausalConv3d(in_dim, out_dim, 1)\n            if in_dim != out_dim else nn.Identity())\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        h = self.shortcut(x)\n        for layer in self.residual:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    # cache last frame of 
last two chunk\n                    cache_x = torch.cat(\n                        [\n                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                                cache_x.device),\n                            cache_x,\n                        ],\n                        dim=2,\n                    )\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    Causal self-attention with a single head.\n    \"\"\"\n\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n        # layers\n        self.norm = RMS_norm(dim)\n        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)\n        self.proj = nn.Conv2d(dim, dim, 1)\n\n        # zero out the last layer params\n        nn.init.zeros_(self.proj.weight)\n\n    def forward(self, x):\n        identity = x\n        b, c, t, h, w = x.size()\n        x = rearrange(x, \"b c t h w -> (b t) c h w\")\n        x = self.norm(x)\n        # compute query, key, value\n        q, k, v = (\n            self.to_qkv(x).reshape(b * t, 1, c * 3,\n                                   -1).permute(0, 1, 3,\n                                               2).contiguous().chunk(3, dim=-1))\n\n        # apply attention\n        x = F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n        )\n        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)\n\n        # output\n        x = self.proj(x)\n        x = rearrange(x, \"(b t) c h w-> b c t h w\", t=t)\n        return x + identity\n\n\ndef patchify(x, patch_size):\n    if patch_size == 1:\n        return x\n    if x.dim() == 4:\n        x = rearrange(\n            x, \"b c (h q) (w r) -> b (c r q) h w\", q=patch_size, r=patch_size)\n    elif x.dim() == 5:\n        x = rearrange(\n            x,\n            \"b c f 
(h q) (w r) -> b (c r q) f h w\",\n            q=patch_size,\n            r=patch_size,\n        )\n    else:\n        raise ValueError(f\"Invalid input shape: {x.shape}\")\n\n    return x\n\n\ndef unpatchify(x, patch_size):\n    if patch_size == 1:\n        return x\n\n    if x.dim() == 4:\n        x = rearrange(\n            x, \"b (c r q) h w -> b c (h q) (w r)\", q=patch_size, r=patch_size)\n    elif x.dim() == 5:\n        x = rearrange(\n            x,\n            \"b (c r q) f h w -> b c f (h q) (w r)\",\n            q=patch_size,\n            r=patch_size,\n        )\n    return x\n\n\nclass AvgDown3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        factor_t,\n        factor_s=1,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.factor_t = factor_t\n        self.factor_s = factor_s\n        self.factor = self.factor_t * self.factor_s * self.factor_s\n\n        assert in_channels * self.factor % out_channels == 0\n        self.group_size = in_channels * self.factor // out_channels\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t\n        pad = (0, 0, 0, 0, pad_t, 0)\n        x = F.pad(x, pad)\n        B, C, T, H, W = x.shape\n        x = x.view(\n            B,\n            C,\n            T // self.factor_t,\n            self.factor_t,\n            H // self.factor_s,\n            self.factor_s,\n            W // self.factor_s,\n            self.factor_s,\n        )\n        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()\n        x = x.view(\n            B,\n            C * self.factor,\n            T // self.factor_t,\n            H // self.factor_s,\n            W // self.factor_s,\n        )\n        x = x.view(\n            B,\n            self.out_channels,\n            self.group_size,\n            T // self.factor_t,\n          
  H // self.factor_s,\n            W // self.factor_s,\n        )\n        x = x.mean(dim=2)\n        return x\n\n\nclass DupUp3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        factor_t,\n        factor_s=1,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        self.factor_t = factor_t\n        self.factor_s = factor_s\n        self.factor = self.factor_t * self.factor_s * self.factor_s\n\n        assert out_channels * self.factor % in_channels == 0\n        self.repeats = out_channels * self.factor // in_channels\n\n    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:\n        x = x.repeat_interleave(self.repeats, dim=1)\n        x = x.view(\n            x.size(0),\n            self.out_channels,\n            self.factor_t,\n            self.factor_s,\n            self.factor_s,\n            x.size(2),\n            x.size(3),\n            x.size(4),\n        )\n        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()\n        x = x.view(\n            x.size(0),\n            self.out_channels,\n            x.size(2) * self.factor_t,\n            x.size(4) * self.factor_s,\n            x.size(6) * self.factor_s,\n        )\n        if first_chunk:\n            x = x[:, :, self.factor_t - 1:, :, :]\n        return x\n\n\nclass Down_ResidualBlock(nn.Module):\n\n    def __init__(self,\n                 in_dim,\n                 out_dim,\n                 dropout,\n                 mult,\n                 temperal_downsample=False,\n                 down_flag=False):\n        super().__init__()\n\n        # Shortcut path with downsample\n        self.avg_shortcut = AvgDown3D(\n            in_dim,\n            out_dim,\n            factor_t=2 if temperal_downsample else 1,\n            factor_s=2 if down_flag else 1,\n        )\n\n        # Main path with residual blocks and downsample\n        downsamples = 
[]\n        for _ in range(mult):\n            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n            in_dim = out_dim\n\n        # Add the final downsample block\n        if down_flag:\n            mode = \"downsample3d\" if temperal_downsample else \"downsample2d\"\n            downsamples.append(Resample(out_dim, mode=mode))\n\n        self.downsamples = nn.Sequential(*downsamples)\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n        x_copy = x.clone()\n        for module in self.downsamples:\n            x = module(x, feat_cache, feat_idx)\n\n        return x + self.avg_shortcut(x_copy)\n\n\nclass Up_ResidualBlock(nn.Module):\n\n    def __init__(self,\n                 in_dim,\n                 out_dim,\n                 dropout,\n                 mult,\n                 temperal_upsample=False,\n                 up_flag=False):\n        super().__init__()\n        # Shortcut path with upsample\n        if up_flag:\n            self.avg_shortcut = DupUp3D(\n                in_dim,\n                out_dim,\n                factor_t=2 if temperal_upsample else 1,\n                factor_s=2 if up_flag else 1,\n            )\n        else:\n            self.avg_shortcut = None\n\n        # Main path with residual blocks and upsample\n        upsamples = []\n        for _ in range(mult):\n            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))\n            in_dim = out_dim\n\n        # Add the final upsample block\n        if up_flag:\n            mode = \"upsample3d\" if temperal_upsample else \"upsample2d\"\n            upsamples.append(Resample(out_dim, mode=mode))\n\n        self.upsamples = nn.Sequential(*upsamples)\n\n    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):\n        x_main = x.clone()\n        for module in self.upsamples:\n            x_main = module(x_main, feat_cache, feat_idx)\n        if self.avg_shortcut is not None:\n            x_shortcut = self.avg_shortcut(x, 
first_chunk)\n            return x_main + x_shortcut\n        else:\n            return x_main\n\n\nclass Encoder3d(nn.Module):\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_downsample=[True, True, False],\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n\n        # dimensions\n        dims = [dim * u for u in [1] + dim_mult]\n        scale = 1.0\n\n        # init block\n        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)\n\n        # downsample blocks\n        downsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            t_down_flag = (\n                temperal_downsample[i]\n                if i < len(temperal_downsample) else False)\n            downsamples.append(\n                Down_ResidualBlock(\n                    in_dim=in_dim,\n                    out_dim=out_dim,\n                    dropout=dropout,\n                    mult=num_res_blocks,\n                    temperal_downsample=t_down_flag,\n                    down_flag=i != len(dim_mult) - 1,\n                ))\n            scale /= 2.0\n        self.downsamples = nn.Sequential(*downsamples)\n\n        # middle blocks\n        self.middle = nn.Sequential(\n            ResidualBlock(out_dim, out_dim, dropout),\n            AttentionBlock(out_dim),\n            ResidualBlock(out_dim, out_dim, dropout),\n        )\n\n        # # output blocks\n        self.head = nn.Sequential(\n            RMS_norm(out_dim, images=False),\n            nn.SiLU(),\n            CausalConv3d(out_dim, z_dim, 3, padding=1),\n        )\n\n    def forward(self, x, feat_cache=None, feat_idx=[0]):\n\n        if 
feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat(\n                    [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        ## downsamples\n        for layer in self.downsamples:\n            if feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## middle\n        for layer in self.middle:\n            if isinstance(layer, ResidualBlock) and feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    cache_x = torch.cat(\n                        [\n                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                                cache_x.device),\n                            cache_x,\n                        ],\n                        dim=2,\n                    )\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n\n        return x\n\n\nclass Decoder3d(nn.Module):\n\n    def __init__(\n        self,\n        dim=128,\n        z_dim=4,\n        
dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_upsample=[False, True, True],\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_upsample = temperal_upsample\n\n        # dimensions\n        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]\n        scale = 1.0 / 2**(len(dim_mult) - 2)\n        # init block\n        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)\n\n        # middle blocks\n        self.middle = nn.Sequential(\n            ResidualBlock(dims[0], dims[0], dropout),\n            AttentionBlock(dims[0]),\n            ResidualBlock(dims[0], dims[0], dropout),\n        )\n\n        # upsample blocks\n        upsamples = []\n        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):\n            t_up_flag = temperal_upsample[i] if i < len(\n                temperal_upsample) else False\n            upsamples.append(\n                Up_ResidualBlock(\n                    in_dim=in_dim,\n                    out_dim=out_dim,\n                    dropout=dropout,\n                    mult=num_res_blocks + 1,\n                    temperal_upsample=t_up_flag,\n                    up_flag=i != len(dim_mult) - 1,\n                ))\n        self.upsamples = nn.Sequential(*upsamples)\n\n        # output blocks\n        self.head = nn.Sequential(\n            RMS_norm(out_dim, images=False),\n            nn.SiLU(),\n            CausalConv3d(out_dim, 12, 3, padding=1),\n        )\n\n    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):\n        if feat_cache is not None:\n            idx = feat_idx[0]\n            cache_x = x[:, :, -CACHE_T:, :, :].clone()\n            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                cache_x = torch.cat(\n       
             [\n                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                            cache_x.device),\n                        cache_x,\n                    ],\n                    dim=2,\n                )\n            x = self.conv1(x, feat_cache[idx])\n            feat_cache[idx] = cache_x\n            feat_idx[0] += 1\n        else:\n            x = self.conv1(x)\n\n        for layer in self.middle:\n            if isinstance(layer, ResidualBlock) and feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx)\n            else:\n                x = layer(x)\n\n        ## upsamples\n        for layer in self.upsamples:\n            if feat_cache is not None:\n                x = layer(x, feat_cache, feat_idx, first_chunk)\n            else:\n                x = layer(x)\n\n        ## head\n        for layer in self.head:\n            if isinstance(layer, CausalConv3d) and feat_cache is not None:\n                idx = feat_idx[0]\n                cache_x = x[:, :, -CACHE_T:, :, :].clone()\n                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:\n                    cache_x = torch.cat(\n                        [\n                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(\n                                cache_x.device),\n                            cache_x,\n                        ],\n                        dim=2,\n                    )\n                x = layer(x, feat_cache[idx])\n                feat_cache[idx] = cache_x\n                feat_idx[0] += 1\n            else:\n                x = layer(x)\n        return x\n\n\ndef count_conv3d(model):\n    count = 0\n    for m in model.modules():\n        if isinstance(m, CausalConv3d):\n            count += 1\n    return count\n\n\nclass WanVAE_(nn.Module):\n\n    def __init__(\n        self,\n        dim=160,\n        dec_dim=256,\n        z_dim=16,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        
attn_scales=[],\n        temperal_downsample=[True, True, False],\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.z_dim = z_dim\n        self.dim_mult = dim_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_scales = attn_scales\n        self.temperal_downsample = temperal_downsample\n        self.temperal_upsample = temperal_downsample[::-1]\n\n        # modules\n        self.encoder = Encoder3d(\n            dim,\n            z_dim * 2,\n            dim_mult,\n            num_res_blocks,\n            attn_scales,\n            self.temperal_downsample,\n            dropout,\n        )\n        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)\n        self.conv2 = CausalConv3d(z_dim, z_dim, 1)\n        self.decoder = Decoder3d(\n            dec_dim,\n            z_dim,\n            dim_mult,\n            num_res_blocks,\n            attn_scales,\n            self.temperal_upsample,\n            dropout,\n        )\n\n    def forward(self, x, scale=[0, 1]):\n        mu = self.encode(x, scale)\n        x_recon = self.decode(mu, scale)\n        return x_recon, mu\n\n    def encode(self, x, scale):\n        self.clear_cache()\n        x = patchify(x, patch_size=2)\n        t = x.shape[2]\n        iter_ = 1 + (t - 1) // 4\n        for i in range(iter_):\n            self._enc_conv_idx = [0]\n            if i == 0:\n                out = self.encoder(\n                    x[:, :, :1, :, :],\n                    feat_cache=self._enc_feat_map,\n                    feat_idx=self._enc_conv_idx,\n                )\n            else:\n                out_ = self.encoder(\n                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],\n                    feat_cache=self._enc_feat_map,\n                    feat_idx=self._enc_conv_idx,\n                )\n                out = torch.cat([out, out_], 2)\n        mu, log_var = self.conv1(out).chunk(2, dim=1)\n        if isinstance(scale[0], torch.Tensor):\n        
    mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            mu = (mu - scale[0]) * scale[1]\n        self.clear_cache()\n        return mu\n\n    def decode(self, z, scale):\n        self.clear_cache()\n        if isinstance(scale[0], torch.Tensor):\n            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(\n                1, self.z_dim, 1, 1, 1)\n        else:\n            z = z / scale[1] + scale[0]\n        iter_ = z.shape[2]\n        x = self.conv2(z)\n        for i in range(iter_):\n            self._conv_idx = [0]\n            if i == 0:\n                out = self.decoder(\n                    x[:, :, i:i + 1, :, :],\n                    feat_cache=self._feat_map,\n                    feat_idx=self._conv_idx,\n                    first_chunk=True,\n                )\n            else:\n                out_ = self.decoder(\n                    x[:, :, i:i + 1, :, :],\n                    feat_cache=self._feat_map,\n                    feat_idx=self._conv_idx,\n                )\n                out = torch.cat([out, out_], 2)\n        out = unpatchify(out, patch_size=2)\n        self.clear_cache()\n        return out\n\n    def reparameterize(self, mu, log_var):\n        std = torch.exp(0.5 * log_var)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def sample(self, imgs, deterministic=False):\n        mu, log_var = self.encode(imgs)\n        if deterministic:\n            return mu\n        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))\n        return mu + std * torch.randn_like(std)\n\n    def clear_cache(self):\n        self._conv_num = count_conv3d(self.decoder)\n        self._conv_idx = [0]\n        self._feat_map = [None] * self._conv_num\n        # cache encode\n        self._enc_conv_num = count_conv3d(self.encoder)\n        self._enc_conv_idx = [0]\n        self._enc_feat_map = [None] * self._enc_conv_num\n\n\ndef 
_video_vae(pretrained_path=None, z_dim=16, dim=160, device=\"cpu\", **kwargs):\n    # params\n    cfg = dict(\n        dim=dim,\n        z_dim=z_dim,\n        dim_mult=[1, 2, 4, 4],\n        num_res_blocks=2,\n        attn_scales=[],\n        temperal_downsample=[True, True, True],\n        dropout=0.0,\n    )\n    cfg.update(**kwargs)\n\n    # init model\n    with torch.device(\"meta\"):\n        model = WanVAE_(**cfg)\n\n    # load checkpoint\n    logging.info(f\"loading {pretrained_path}\")\n    model.load_state_dict(\n        torch.load(pretrained_path, map_location=device), assign=True)\n\n    return model\n\n\nclass Wan2_2_VAE:\n\n    def __init__(\n        self,\n        z_dim=48,\n        c_dim=160,\n        vae_pth=None,\n        dim_mult=[1, 2, 4, 4],\n        temperal_downsample=[False, True, True],\n        dtype=torch.float,\n        device=\"cuda\",\n    ):\n\n        self.dtype = dtype\n        self.device = device\n\n        mean = torch.tensor(\n            [\n                -0.2289,\n                -0.0052,\n                -0.1323,\n                -0.2339,\n                -0.2799,\n                0.0174,\n                0.1838,\n                0.1557,\n                -0.1382,\n                0.0542,\n                0.2813,\n                0.0891,\n                0.1570,\n                -0.0098,\n                0.0375,\n                -0.1825,\n                -0.2246,\n                -0.1207,\n                -0.0698,\n                0.5109,\n                0.2665,\n                -0.2108,\n                -0.2158,\n                0.2502,\n                -0.2055,\n                -0.0322,\n                0.1109,\n                0.1567,\n                -0.0729,\n                0.0899,\n                -0.2799,\n                -0.1230,\n                -0.0313,\n                -0.1649,\n                0.0117,\n                0.0723,\n                -0.2839,\n                -0.2083,\n                -0.0520,\n          
      0.3748,\n                0.0152,\n                0.1957,\n                0.1433,\n                -0.2944,\n                0.3573,\n                -0.0548,\n                -0.1681,\n                -0.0667,\n            ],\n            dtype=dtype,\n            device=device,\n        )\n        std = torch.tensor(\n            [\n                0.4765,\n                1.0364,\n                0.4514,\n                1.1677,\n                0.5313,\n                0.4990,\n                0.4818,\n                0.5013,\n                0.8158,\n                1.0344,\n                0.5894,\n                1.0901,\n                0.6885,\n                0.6165,\n                0.8454,\n                0.4978,\n                0.5759,\n                0.3523,\n                0.7135,\n                0.6804,\n                0.5833,\n                1.4146,\n                0.8986,\n                0.5659,\n                0.7069,\n                0.5338,\n                0.4889,\n                0.4917,\n                0.4069,\n                0.4999,\n                0.6866,\n                0.4093,\n                0.5709,\n                0.6065,\n                0.6415,\n                0.4944,\n                0.5726,\n                1.2042,\n                0.5458,\n                1.6887,\n                0.3971,\n                1.0600,\n                0.3943,\n                0.5537,\n                0.5444,\n                0.4089,\n                0.7468,\n                0.7744,\n            ],\n            dtype=dtype,\n            device=device,\n        )\n        self.scale = [mean, 1.0 / std]\n\n        # init model\n        self.model = (\n            _video_vae(\n                pretrained_path=vae_pth,\n                z_dim=z_dim,\n                dim=c_dim,\n                dim_mult=dim_mult,\n                temperal_downsample=temperal_downsample,\n            ).eval().requires_grad_(False).to(device))\n\n    def 
encode(self, videos):\n        try:\n            if not isinstance(videos, list):\n                raise TypeError(\"videos should be a list\")\n            with amp.autocast(dtype=self.dtype):\n                return [\n                    self.model.encode(u.unsqueeze(0),\n                                      self.scale).float().squeeze(0)\n                    for u in videos\n                ]\n        except TypeError as e:\n            logging.info(e)\n            return None\n\n    def decode(self, zs):\n        try:\n            if not isinstance(zs, list):\n                raise TypeError(\"zs should be a list\")\n            with amp.autocast(dtype=self.dtype):\n                return [\n                    self.model.decode(u.unsqueeze(0),\n                                      self.scale).float().clamp_(-1,\n                                                                 1).squeeze(0)\n                    for u in zs\n                ]\n        except TypeError as e:\n            logging.info(e)\n            return None\n"
  },
  {
    "path": "wan/speech2video.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport types\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom functools import partial\n\nimport numpy as np\nimport torch\nimport torch.cuda.amp as amp\nimport torch.distributed as dist\nimport torchvision.transforms.functional as TF\nfrom decord import VideoReader\nfrom PIL import Image\nfrom safetensors import safe_open\nfrom torchvision import transforms\nfrom tqdm import tqdm\n\nfrom .distributed.fsdp import shard_model\nfrom .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward\nfrom .distributed.util import get_world_size\nfrom .modules.s2v.audio_encoder import AudioEncoder\nfrom .modules.s2v.model_s2v import WanModel_S2V, sp_attn_forward_s2v\nfrom .modules.t5 import T5EncoderModel\nfrom .modules.vae2_1 import Wan2_1_VAE\nfrom .utils.fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler\n\n\ndef load_safetensors(path):\n    tensors = {}\n    with safe_open(path, framework=\"pt\", device=\"cpu\") as f:\n        for key in f.keys():\n            tensors[key] = f.get_tensor(key)\n    return tensors\n\n\nclass WanS2V:\n\n    def __init__(\n        self,\n        config,\n        checkpoint_dir,\n        device_id=0,\n        rank=0,\n        t5_fsdp=False,\n        dit_fsdp=False,\n        use_sp=False,\n        t5_cpu=False,\n        init_on_cpu=True,\n        convert_model_dtype=False,\n    ):\n        r\"\"\"\n        Initializes the image-to-video generation model components.\n\n        Args:\n            config (EasyDict):\n                Object containing model parameters initialized from config.py\n            checkpoint_dir (`str`):\n                Path to directory containing model checkpoints\n            device_id (`int`,  
*optional*, defaults to 0):\n                Id of target GPU device\n            rank (`int`,  *optional*, defaults to 0):\n                Process rank for distributed training\n            t5_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for T5 model\n            dit_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for DiT model\n            use_sp (`bool`, *optional*, defaults to False):\n                Enable distribution strategy of sequence parallel.\n            t5_cpu (`bool`, *optional*, defaults to False):\n                Whether to place T5 model on CPU. Only works without t5_fsdp.\n            init_on_cpu (`bool`, *optional*, defaults to True):\n                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.\n            convert_model_dtype (`bool`, *optional*, defaults to False):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n        \"\"\"\n        self.device = torch.device(f\"cuda:{device_id}\")\n        self.config = config\n        self.rank = rank\n        self.t5_cpu = t5_cpu\n        self.init_on_cpu = init_on_cpu\n\n        self.num_train_timesteps = config.num_train_timesteps\n        self.param_dtype = config.param_dtype\n\n        if t5_fsdp or dit_fsdp or use_sp:\n            self.init_on_cpu = False\n\n        shard_fn = partial(shard_model, device_id=device_id)\n        self.text_encoder = T5EncoderModel(\n            text_len=config.text_len,\n            dtype=config.t5_dtype,\n            device=torch.device('cpu'),\n            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),\n            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),\n            shard_fn=shard_fn if t5_fsdp else None,\n        )\n\n        self.vae = Wan2_1_VAE(\n            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),\n            
device=self.device)\n\n        logging.info(f\"Creating WanModel from {checkpoint_dir}\")\n        if not dit_fsdp:\n            self.noise_model = WanModel_S2V.from_pretrained(\n                checkpoint_dir,\n                torch_dtype=self.param_dtype,\n                device_map=self.device)\n        else:\n            self.noise_model = WanModel_S2V.from_pretrained(\n                checkpoint_dir, torch_dtype=self.param_dtype)\n\n        self.noise_model = self._configure_model(\n            model=self.noise_model,\n            use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype)\n\n        self.audio_encoder = AudioEncoder(\n            model_id=os.path.join(checkpoint_dir,\n                                  \"wav2vec2-large-xlsr-53-english\"))\n\n        if use_sp:\n            self.sp_size = get_world_size()\n        else:\n            self.sp_size = 1\n\n        self.sample_neg_prompt = config.sample_neg_prompt\n        self.motion_frames = config.transformer.motion_frames\n        self.drop_first_motion = config.drop_first_motion\n        self.fps = config.sample_fps\n        self.audio_sample_m = 0\n\n    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,\n                         convert_model_dtype):\n        \"\"\"\n        Configures a model object. 
This includes setting evaluation modes,\n        applying distributed parallel strategy, and handling device placement.\n\n        Args:\n            model (torch.nn.Module):\n                The model instance to configure.\n            use_sp (`bool`):\n                Enable distribution strategy of sequence parallel.\n            dit_fsdp (`bool`):\n                Enable FSDP sharding for DiT model.\n            shard_fn (callable):\n                The function to apply FSDP sharding.\n            convert_model_dtype (`bool`):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n\n        Returns:\n            torch.nn.Module:\n                The configured model.\n        \"\"\"\n        model.eval().requires_grad_(False)\n        if use_sp:\n            for block in model.blocks:\n                block.self_attn.forward = types.MethodType(\n                    sp_attn_forward_s2v, block.self_attn)\n            model.use_context_parallel = True\n\n        if dist.is_initialized():\n            dist.barrier()\n\n        if dit_fsdp:\n            model = shard_fn(model)\n        else:\n            if convert_model_dtype:\n                model.to(self.param_dtype)\n            if not self.init_on_cpu:\n                model.to(self.device)\n\n        return model\n\n    def get_size_less_than_area(self,\n                                height,\n                                width,\n                                target_area=1024 * 704,\n                                divisor=64):\n        if height * width <= target_area:\n            # If the original image area is already less than or equal to the target,\n            # no resizing is needed—just padding. 
Still need to ensure that the padded area doesn't exceed the target.\n            max_upper_area = target_area\n            min_scale = 0.1\n            max_scale = 1.0\n        else:\n            # Resize to fit within the target area and then pad to multiples of `divisor`\n            max_upper_area = target_area  # Maximum allowed total pixel count after padding\n            d = divisor - 1\n            b = d * (height + width)\n            a = height * width\n            c = d**2 - max_upper_area\n\n            # Calculate scale boundaries using quadratic equation\n            min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (\n                2 * a)  # Scale when maximum padding is applied\n            max_scale = math.sqrt(max_upper_area /\n                                  (height * width))  # Scale without any padding\n\n        # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area\n        # Use binary search-like iteration to find this scale\n        find_it = False\n        for i in range(100):\n            scale = max_scale - (max_scale - min_scale) * i / 100\n            new_height, new_width = int(height * scale), int(width * scale)\n\n            # Pad to make dimensions divisible by 64\n            pad_height = (64 - new_height % 64) % 64\n            pad_width = (64 - new_width % 64) % 64\n            pad_top = pad_height // 2\n            pad_bottom = pad_height - pad_top\n            pad_left = pad_width // 2\n            pad_right = pad_width - pad_left\n\n            padded_height, padded_width = new_height + pad_height, new_width + pad_width\n\n            if padded_height * padded_width <= max_upper_area:\n                find_it = True\n                break\n\n        if find_it:\n            return padded_height, padded_width\n        else:\n            # Fallback: calculate target dimensions based on aspect ratio and divisor alignment\n            aspect_ratio = width / height\n      
      target_width = int(\n                (target_area * aspect_ratio)**0.5 // divisor * divisor)\n            target_height = int(\n                (target_area / aspect_ratio)**0.5 // divisor * divisor)\n\n            # Ensure the result is not larger than the original resolution\n            if target_width >= width or target_height >= height:\n                target_width = int(width // divisor * divisor)\n                target_height = int(height // divisor * divisor)\n\n            return target_height, target_width\n\n    def prepare_default_cond_input(self,\n                                   map_shape=[3, 12, 64, 64],\n                                   motion_frames=5,\n                                   lat_motion_frames=2,\n                                   enable_mano=False,\n                                   enable_kp=False,\n                                   enable_pose=False):\n        default_value = [1.0, -1.0, -1.0]\n        cond_enable = [enable_mano, enable_kp, enable_pose]\n        cond = []\n        for d, c in zip(default_value, cond_enable):\n            if c:\n                map_value = torch.ones(\n                    map_shape, dtype=self.param_dtype, device=self.device) * d\n                cond_lat = torch.cat([\n                    map_value[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1),\n                    map_value\n                ],\n                                     dim=2)\n                cond_lat = torch.stack(\n                    self.vae.encode(cond_lat.to(\n                        self.param_dtype)))[:, :, lat_motion_frames:].to(\n                            self.param_dtype)\n\n                cond.append(cond_lat)\n        if len(cond) >= 1:\n            cond = torch.cat(cond, dim=1)\n        else:\n            cond = None\n        return cond\n\n    def encode_audio(self, audio_path, infer_frames):\n        z = self.audio_encoder.extract_audio_feat(\n            audio_path, return_all_layers=True)\n        
audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps(\n            z, fps=self.fps, batch_frames=infer_frames, m=self.audio_sample_m)\n        audio_embed_bucket = audio_embed_bucket.to(self.device,\n                                                   self.param_dtype)\n        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)\n        if len(audio_embed_bucket.shape) == 3:\n            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)\n        elif len(audio_embed_bucket.shape) == 4:\n            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)\n        return audio_embed_bucket, num_repeat\n\n    def read_last_n_frames(self,\n                           video_path,\n                           n_frames,\n                           target_fps=16,\n                           reverse=False):\n        \"\"\"\n        Read the last `n_frames` from a video at the specified frame rate.\n\n        Parameters:\n            video_path (str): Path to the video file.\n            n_frames (int): Number of frames to read.\n            target_fps (int, optional): Target sampling frame rate. Defaults to 16.\n            reverse (bool, optional): Whether to read frames in reverse order. 
\n                                    If True, reads the first `n_frames` instead of the last ones.\n\n        Returns:\n            np.ndarray: A NumPy array of shape [N, H, W, 3] with N <= n_frames, representing the sampled video frames.\n        \"\"\"\n        vr = VideoReader(video_path)\n        original_fps = vr.get_avg_fps()\n        total_frames = len(vr)\n\n        # Step between sampled frames so the output approximates target_fps\n        interval = max(1, round(original_fps / target_fps))\n\n        required_span = (n_frames - 1) * interval\n\n        start_frame = max(0, total_frames - required_span -\n                          1) if not reverse else 0\n\n        sampled_indices = []\n        for i in range(n_frames):\n            idx = start_frame + i * interval\n            if idx >= total_frames:\n                break\n            sampled_indices.append(idx)\n\n        return vr.get_batch(sampled_indices).asnumpy()\n\n    def load_pose_cond(self, pose_video, num_repeat, infer_frames, size):\n        HEIGHT, WIDTH = size\n        if pose_video is not None:\n            pose_seq = self.read_last_n_frames(\n                pose_video,\n                n_frames=infer_frames * num_repeat,\n                target_fps=self.fps,\n                reverse=True)\n\n            resize_op = transforms.Resize(min(HEIGHT, WIDTH))\n            crop_op = transforms.CenterCrop((HEIGHT, WIDTH))\n\n            # [N, H, W, 3] uint8 -> [N, 3, H, W] float in [-1, 1]\n            cond_tensor = torch.from_numpy(pose_seq)\n            cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0\n            cond_tensor = crop_op(resize_op(cond_tensor)).permute(\n                1, 0, 2, 3).unsqueeze(0)\n\n            # Pad missing frames with -1 so every clip has infer_frames frames\n            padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2]\n            cond_tensor = torch.cat([\n                cond_tensor,\n                -torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH])\n            ],\n                                    dim=2)\n\n            cond_tensors = 
torch.chunk(cond_tensor, num_repeat, dim=2)\n        else:\n            cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])]\n\n        COND = []\n        for r in range(len(cond_tensors)):\n            cond = cond_tensors[r]\n            cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond],\n                             dim=2)\n            cond_lat = torch.stack(\n                self.vae.encode(\n                    cond.to(dtype=self.param_dtype,\n                            device=self.device)))[:, :,\n                                                  1:].cpu()  # for mem save\n            COND.append(cond_lat)\n        return COND\n\n    def get_gen_size(self, size, max_area, ref_image_path, pre_video_path):\n        if not size is None:\n            HEIGHT, WIDTH = size\n        else:\n            if pre_video_path:\n                ref_image = self.read_last_n_frames(\n                    pre_video_path, n_frames=1)[0]\n            else:\n                ref_image = np.array(Image.open(ref_image_path).convert('RGB'))\n            HEIGHT, WIDTH = ref_image.shape[:2]\n        HEIGHT, WIDTH = self.get_size_less_than_area(\n            HEIGHT, WIDTH, target_area=max_area)\n        return (HEIGHT, WIDTH)\n\n    def generate(\n        self,\n        input_prompt,\n        ref_image_path,\n        audio_path,\n        enable_tts,\n        tts_prompt_audio,\n        tts_prompt_text,\n        tts_text,\n        num_repeat=1,\n        pose_video=None,\n        max_area=720 * 1280,\n        infer_frames=80,\n        shift=5.0,\n        sample_solver='unipc',\n        sampling_steps=40,\n        guide_scale=5.0,\n        n_prompt=\"\",\n        seed=-1,\n        offload_model=True,\n        init_first_frame=False,\n    ):\n        r\"\"\"\n        Generates video frames from input image and text prompt using diffusion process.\n\n        Args:\n            input_prompt (`str`):\n                Text prompt for content generation.\n            
ref_image_path (`str`):\n                Path to the reference image.\n            audio_path (`str`):\n                Path to the audio that drives the video. Replaced by the synthesized audio when `enable_tts` is True.\n            enable_tts (`bool`):\n                If True, synthesize the driving audio from `tts_text` instead of using `audio_path`.\n            tts_prompt_audio (`str`):\n                Path to the prompt audio used as the voice reference for TTS.\n            tts_prompt_text (`str` or `None`):\n                Transcript of `tts_prompt_audio`. If None, cross-lingual inference is used instead of zero-shot inference.\n            tts_text (`str`):\n                Text to synthesize.\n            num_repeat (`int`, *optional*, defaults to 1):\n                Number of clips to generate; capped at the number of clips the audio length allows\n            pose_video (`str`, *optional*, defaults to None):\n                If provided, uses a sequence of poses to drive the generated video\n            max_area (`int`, *optional*, defaults to 720*1280):\n                Maximum pixel area for latent space calculation. Controls video resolution scaling\n            infer_frames (`int`, *optional*, defaults to 80):\n                Number of frames to generate per clip. Should be a multiple of 4\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. Affects temporal dynamics\n                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.\n            sample_solver (`str`, *optional*, defaults to 'unipc'):\n                Solver used to sample the video. Supported values: 'unipc', 'dpm++'.\n            sampling_steps (`int`, *optional*, defaults to 40):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation\n            guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):\n                Classifier-free guidance scale. Controls prompt adherence vs. creativity.\n                If tuple, the first guide_scale will be used for low noise model and\n                the second guide_scale will be used for high noise model.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. 
If -1, use random seed\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM\n            init_first_frame (`bool`, *optional*, defaults to False):\n                Whether to use the reference image as the first frame (i.e., standard image-to-video generation)\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames (all generated clips concatenated)\n                - H: Frame height (from max_area)\n                - W: Frame width (from max_area)\n        \"\"\"\n        # preprocess\n        size = self.get_gen_size(\n            size=None,\n            max_area=max_area,\n            ref_image_path=ref_image_path,\n            pre_video_path=None)\n        HEIGHT, WIDTH = size\n        channel = 3\n\n        resize_opreat = transforms.Resize(min(HEIGHT, WIDTH))\n        crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH))\n        tensor_trans = transforms.ToTensor()\n\n        ref_image = None\n        motion_latents = None\n\n        if ref_image is None:\n            ref_image = np.array(Image.open(ref_image_path).convert('RGB'))\n        if motion_latents is None:\n            motion_latents = torch.zeros(\n                [1, channel, self.motion_frames, HEIGHT, WIDTH],\n                dtype=self.param_dtype,\n                device=self.device)\n\n        # extract audio emb\n        if enable_tts is True:\n            audio_path = self.tts(tts_prompt_audio, tts_prompt_text, tts_text)\n        audio_emb, nr = self.encode_audio(audio_path, infer_frames=infer_frames)\n        if num_repeat is None or num_repeat > nr:\n            num_repeat = nr\n\n        lat_motion_frames = (self.motion_frames + 3) // 4\n        model_pic = crop_opreat(resize_opreat(Image.fromarray(ref_image)))\n\n        ref_pixel_values = tensor_trans(model_pic)\n      
  ref_pixel_values = ref_pixel_values.unsqueeze(1).unsqueeze(\n            0) * 2 - 1.0  # b c 1 h w\n        ref_pixel_values = ref_pixel_values.to(\n            dtype=self.vae.dtype, device=self.vae.device)\n        ref_latents = torch.stack(self.vae.encode(ref_pixel_values))\n\n        # encode the motion latents\n        videos_last_frames = motion_latents.detach()\n        drop_first_motion = self.drop_first_motion\n        if init_first_frame:\n            drop_first_motion = False\n            motion_latents[:, :, -6:] = ref_pixel_values\n        motion_latents = torch.stack(self.vae.encode(motion_latents))\n\n        # get pose cond input if need\n        COND = self.load_pose_cond(\n            pose_video=pose_video,\n            num_repeat=num_repeat,\n            infer_frames=infer_frames,\n            size=size)\n\n        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)\n\n        if n_prompt == \"\":\n            n_prompt = self.sample_neg_prompt\n\n        # preprocess\n        if not self.t5_cpu:\n            self.text_encoder.model.to(self.device)\n            context = self.text_encoder([input_prompt], self.device)\n            context_null = self.text_encoder([n_prompt], self.device)\n            if offload_model:\n                self.text_encoder.model.cpu()\n        else:\n            context = self.text_encoder([input_prompt], torch.device('cpu'))\n            context_null = self.text_encoder([n_prompt], torch.device('cpu'))\n            context = [t.to(self.device) for t in context]\n            context_null = [t.to(self.device) for t in context_null]\n\n        out = []\n        # evaluation mode\n        with (\n                torch.amp.autocast('cuda', dtype=self.param_dtype),\n                torch.no_grad(),\n        ):\n            for r in range(num_repeat):\n                seed_g = torch.Generator(device=self.device)\n                seed_g.manual_seed(seed + r)\n\n                lat_target_frames = (infer_frames + 3 + 
self.motion_frames\n                                    ) // 4 - lat_motion_frames\n                target_shape = [lat_target_frames, HEIGHT // 8, WIDTH // 8]\n                noise = [\n                    torch.randn(\n                        16,\n                        target_shape[0],\n                        target_shape[1],\n                        target_shape[2],\n                        dtype=self.param_dtype,\n                        device=self.device,\n                        generator=seed_g)\n                ]\n                max_seq_len = np.prod(target_shape) // 4\n\n                if sample_solver == 'unipc':\n                    sample_scheduler = FlowUniPCMultistepScheduler(\n                        num_train_timesteps=self.num_train_timesteps,\n                        shift=1,\n                        use_dynamic_shifting=False)\n                    sample_scheduler.set_timesteps(\n                        sampling_steps, device=self.device, shift=shift)\n                    timesteps = sample_scheduler.timesteps\n                elif sample_solver == 'dpm++':\n                    sample_scheduler = FlowDPMSolverMultistepScheduler(\n                        num_train_timesteps=self.num_train_timesteps,\n                        shift=1,\n                        use_dynamic_shifting=False)\n                    sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)\n                    timesteps, _ = retrieve_timesteps(\n                        sample_scheduler,\n                        device=self.device,\n                        sigmas=sampling_sigmas)\n                else:\n                    raise NotImplementedError(\"Unsupported solver.\")\n\n                latents = deepcopy(noise)\n                with torch.no_grad():\n                    left_idx = r * infer_frames\n                    right_idx = r * infer_frames + infer_frames\n                    cond_latents = COND[r] if pose_video else COND[0] * 0\n                    
cond_latents = cond_latents.to(\n                        dtype=self.param_dtype, device=self.device)\n                    audio_input = audio_emb[..., left_idx:right_idx]\n                input_motion_latents = motion_latents.clone()\n\n                arg_c = {\n                    'context': context[0:1],\n                    'seq_len': max_seq_len,\n                    'cond_states': cond_latents,\n                    \"motion_latents\": input_motion_latents,\n                    'ref_latents': ref_latents,\n                    \"audio_input\": audio_input,\n                    \"motion_frames\": [self.motion_frames, lat_motion_frames],\n                    \"drop_motion_frames\": drop_first_motion and r == 0,\n                }\n                if guide_scale > 1:\n                    arg_null = {\n                        'context': context_null[0:1],\n                        'seq_len': max_seq_len,\n                        'cond_states': cond_latents,\n                        \"motion_latents\": input_motion_latents,\n                        'ref_latents': ref_latents,\n                        \"audio_input\": 0.0 * audio_input,\n                        \"motion_frames\": [\n                            self.motion_frames, lat_motion_frames\n                        ],\n                        \"drop_motion_frames\": drop_first_motion and r == 0,\n                    }\n                if offload_model or self.init_on_cpu:\n                    self.noise_model.to(self.device)\n                    torch.cuda.empty_cache()\n\n                for i, t in enumerate(tqdm(timesteps)):\n                    latent_model_input = latents[0:1]\n                    timestep = [t]\n\n                    timestep = torch.stack(timestep).to(self.device)\n\n                    noise_pred_cond = self.noise_model(\n                        latent_model_input, t=timestep, **arg_c)\n\n                    if guide_scale > 1:\n                        noise_pred_uncond = 
self.noise_model(\n                            latent_model_input, t=timestep, **arg_null)\n                        noise_pred = [\n                            u + guide_scale * (c - u)\n                            for c, u in zip(noise_pred_cond, noise_pred_uncond)\n                        ]\n                    else:\n                        noise_pred = noise_pred_cond\n\n                    temp_x0 = sample_scheduler.step(\n                        noise_pred[0].unsqueeze(0),\n                        t,\n                        latents[0].unsqueeze(0),\n                        return_dict=False,\n                        generator=seed_g)[0]\n                    latents[0] = temp_x0.squeeze(0)\n\n                if offload_model:\n                    self.noise_model.cpu()\n                    torch.cuda.synchronize()\n                    torch.cuda.empty_cache()\n                latents = torch.stack(latents)\n                if not (drop_first_motion and r == 0):\n                    decode_latents = torch.cat([motion_latents, latents], dim=2)\n                else:\n                    decode_latents = torch.cat([ref_latents, latents], dim=2)\n                image = torch.stack(self.vae.decode(decode_latents))\n                image = image[:, :, -(infer_frames):]\n                if (drop_first_motion and r == 0):\n                    image = image[:, :, 3:]\n\n                overlap_frames_num = min(self.motion_frames, image.shape[2])\n                videos_last_frames = torch.cat([\n                    videos_last_frames[:, :, overlap_frames_num:],\n                    image[:, :, -overlap_frames_num:]\n                ],\n                                               dim=2)\n                videos_last_frames = videos_last_frames.to(\n                    dtype=motion_latents.dtype, device=motion_latents.device)\n                motion_latents = torch.stack(\n                    self.vae.encode(videos_last_frames))\n                
out.append(image.cpu())\n\n        videos = torch.cat(out, dim=2)\n        del noise, latents\n        del sample_scheduler\n        if offload_model:\n            gc.collect()\n            torch.cuda.synchronize()\n        if dist.is_initialized():\n            dist.barrier()\n\n        return videos[0] if self.rank == 0 else None\n\n    def tts(self, tts_prompt_audio, tts_prompt_text, tts_text):\n        if not hasattr(self, 'cosyvoice'):\n            self.load_tts()\n        speech_list = []\n        from cosyvoice.utils.file_utils import load_wav\n        import torchaudio\n        prompt_speech_16k = load_wav(tts_prompt_audio, 16000)\n        if tts_prompt_text is not None:\n            for i in self.cosyvoice.inference_zero_shot(tts_text, tts_prompt_text, prompt_speech_16k):\n                speech_list.append(i['tts_speech'])\n        else:\n            for i in self.cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k):\n                speech_list.append(i['tts_speech'])\n        torchaudio.save('tts.wav', torch.concat(speech_list, dim=1), self.cosyvoice.sample_rate)\n        return 'tts.wav'\n\n    def load_tts(self):\n        if not os.path.exists('CosyVoice'):\n            from wan.utils.utils import download_cosyvoice_repo\n            download_cosyvoice_repo('CosyVoice')\n        if not os.path.exists('CosyVoice2-0.5B'):\n            from wan.utils.utils import download_cosyvoice_model\n            download_cosyvoice_model('CosyVoice2-0.5B', 'CosyVoice2-0.5B')\n        sys.path.append('CosyVoice')\n        sys.path.append('CosyVoice/third_party/Matcha-TTS')\n        from cosyvoice.cli.cosyvoice import CosyVoice2\n        self.cosyvoice = CosyVoice2('CosyVoice2-0.5B')"
  },
  {
    "path": "wan/text2video.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport types\nfrom contextlib import contextmanager\nfrom functools import partial\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.distributed as dist\nfrom tqdm import tqdm\n\nfrom .distributed.fsdp import shard_model\nfrom .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward\nfrom .distributed.util import get_world_size\nfrom .modules.model import WanModel\nfrom .modules.t5 import T5EncoderModel\nfrom .modules.vae2_1 import Wan2_1_VAE\nfrom .utils.fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler\n\n\nclass WanT2V:\n\n    def __init__(\n        self,\n        config,\n        checkpoint_dir,\n        device_id=0,\n        rank=0,\n        t5_fsdp=False,\n        dit_fsdp=False,\n        use_sp=False,\n        t5_cpu=False,\n        init_on_cpu=True,\n        convert_model_dtype=False,\n    ):\n        r\"\"\"\n        Initializes the Wan text-to-video generation model components.\n\n        Args:\n            config (EasyDict):\n                Object containing model parameters initialized from config.py\n            checkpoint_dir (`str`):\n                Path to directory containing model checkpoints\n            device_id (`int`,  *optional*, defaults to 0):\n                Id of target GPU device\n            rank (`int`,  *optional*, defaults to 0):\n                Process rank for distributed training\n            t5_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for T5 model\n            dit_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for DiT model\n            use_sp (`bool`, *optional*, defaults to False):\n                Enable distribution strategy of 
sequence parallel.\n            t5_cpu (`bool`, *optional*, defaults to False):\n                Whether to place T5 model on CPU. Only works without t5_fsdp.\n            init_on_cpu (`bool`, *optional*, defaults to True):\n                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.\n            convert_model_dtype (`bool`, *optional*, defaults to False):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n        \"\"\"\n        self.device = torch.device(f\"cuda:{device_id}\")\n        self.config = config\n        self.rank = rank\n        self.t5_cpu = t5_cpu\n        self.init_on_cpu = init_on_cpu\n\n        self.num_train_timesteps = config.num_train_timesteps\n        self.boundary = config.boundary\n        self.param_dtype = config.param_dtype\n\n        if t5_fsdp or dit_fsdp or use_sp:\n            self.init_on_cpu = False\n\n        shard_fn = partial(shard_model, device_id=device_id)\n        self.text_encoder = T5EncoderModel(\n            text_len=config.text_len,\n            dtype=config.t5_dtype,\n            device=torch.device('cpu'),\n            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),\n            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),\n            shard_fn=shard_fn if t5_fsdp else None)\n\n        self.vae_stride = config.vae_stride\n        self.patch_size = config.patch_size\n        self.vae = Wan2_1_VAE(\n            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),\n            device=self.device)\n\n        logging.info(f\"Creating WanModel from {checkpoint_dir}\")\n        self.low_noise_model = WanModel.from_pretrained(\n            checkpoint_dir, subfolder=config.low_noise_checkpoint)\n        self.low_noise_model = self._configure_model(\n            model=self.low_noise_model,\n            use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            
shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype)\n\n        self.high_noise_model = WanModel.from_pretrained(\n            checkpoint_dir, subfolder=config.high_noise_checkpoint)\n        self.high_noise_model = self._configure_model(\n            model=self.high_noise_model,\n            use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype)\n        if use_sp:\n            self.sp_size = get_world_size()\n        else:\n            self.sp_size = 1\n\n        self.sample_neg_prompt = config.sample_neg_prompt\n\n    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,\n                         convert_model_dtype):\n        \"\"\"\n        Configures a model object. This includes setting evaluation modes,\n        applying distributed parallel strategy, and handling device placement.\n\n        Args:\n            model (torch.nn.Module):\n                The model instance to configure.\n            use_sp (`bool`):\n                Enable distribution strategy of sequence parallel.\n            dit_fsdp (`bool`):\n                Enable FSDP sharding for DiT model.\n            shard_fn (callable):\n                The function to apply FSDP sharding.\n            convert_model_dtype (`bool`):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n\n        Returns:\n            torch.nn.Module:\n                The configured model.\n        \"\"\"\n        model.eval().requires_grad_(False)\n\n        if use_sp:\n            for block in model.blocks:\n                block.self_attn.forward = types.MethodType(\n                    sp_attn_forward, block.self_attn)\n            model.forward = types.MethodType(sp_dit_forward, model)\n\n        if dist.is_initialized():\n            dist.barrier()\n\n        if dit_fsdp:\n            model = shard_fn(model)\n        else:\n            if 
convert_model_dtype:\n                model.to(self.param_dtype)\n            if not self.init_on_cpu:\n                model.to(self.device)\n\n        return model\n\n    def _prepare_model_for_timestep(self, t, boundary, offload_model):\n        r\"\"\"\n        Prepares and returns the required model for the current timestep.\n\n        Args:\n            t (torch.Tensor):\n                current timestep.\n            boundary (`int`):\n                The timestep threshold. If `t` is at or above this value,\n                the `high_noise_model` is considered as the required model.\n            offload_model (`bool`):\n                A flag intended to control the offloading behavior.\n\n        Returns:\n            torch.nn.Module:\n                The active model on the target device for the current timestep.\n        \"\"\"\n        if t.item() >= boundary:\n            required_model_name = 'high_noise_model'\n            offload_model_name = 'low_noise_model'\n        else:\n            required_model_name = 'low_noise_model'\n            offload_model_name = 'high_noise_model'\n        if offload_model or self.init_on_cpu:\n            if next(getattr(\n                    self,\n                    offload_model_name).parameters()).device.type == 'cuda':\n                getattr(self, offload_model_name).to('cpu')\n            if next(getattr(\n                    self,\n                    required_model_name).parameters()).device.type == 'cpu':\n                getattr(self, required_model_name).to(self.device)\n        return getattr(self, required_model_name)\n\n    def generate(self,\n                 input_prompt,\n                 size=(1280, 720),\n                 frame_num=81,\n                 shift=5.0,\n                 sample_solver='unipc',\n                 sampling_steps=50,\n                 guide_scale=5.0,\n                 n_prompt=\"\",\n                 seed=-1,\n                 offload_model=True):\n        r\"\"\"\n      
  Generates video frames from text prompt using diffusion process.\n\n        Args:\n            input_prompt (`str`):\n                Text prompt for content generation\n            size (`tuple[int]`, *optional*, defaults to (1280,720)):\n                Controls video resolution, (width,height).\n            frame_num (`int`, *optional*, defaults to 81):\n                How many frames to sample from a video. The number should be 4n+1\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. Affects temporal dynamics\n            sample_solver (`str`, *optional*, defaults to 'unipc'):\n                Solver used to sample the video.\n            sampling_steps (`int`, *optional*, defaults to 50):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation\n            guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):\n                Classifier-free guidance scale. Controls prompt adherence vs. creativity.\n                If tuple, the first guide_scale will be used for low noise model and\n                the second guide_scale will be used for high noise model.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. If -1, use random seed.\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. 
Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames (81)\n                - H: Frame height (from size)\n                - W: Frame width (from size)\n        \"\"\"\n        # preprocess\n        guide_scale = (guide_scale, guide_scale) if isinstance(\n            guide_scale, float) else guide_scale\n        F = frame_num\n        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,\n                        size[1] // self.vae_stride[1],\n                        size[0] // self.vae_stride[2])\n\n        seq_len = math.ceil((target_shape[2] * target_shape[3]) /\n                            (self.patch_size[1] * self.patch_size[2]) *\n                            target_shape[1] / self.sp_size) * self.sp_size\n\n        if n_prompt == \"\":\n            n_prompt = self.sample_neg_prompt\n        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)\n        seed_g = torch.Generator(device=self.device)\n        seed_g.manual_seed(seed)\n\n        if not self.t5_cpu:\n            self.text_encoder.model.to(self.device)\n            context = self.text_encoder([input_prompt], self.device)\n            context_null = self.text_encoder([n_prompt], self.device)\n            if offload_model:\n                self.text_encoder.model.cpu()\n        else:\n            context = self.text_encoder([input_prompt], torch.device('cpu'))\n            context_null = self.text_encoder([n_prompt], torch.device('cpu'))\n            context = [t.to(self.device) for t in context]\n            context_null = [t.to(self.device) for t in context_null]\n\n        noise = [\n            torch.randn(\n                target_shape[0],\n                target_shape[1],\n                target_shape[2],\n                target_shape[3],\n                dtype=torch.float32,\n                device=self.device,\n                generator=seed_g)\n        ]\n\n        @contextmanager\n        def 
noop_no_sync():\n            yield\n\n        no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',\n                                    noop_no_sync)\n        no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',\n                                     noop_no_sync)\n\n        # evaluation mode\n        with (\n                torch.amp.autocast('cuda', dtype=self.param_dtype),\n                torch.no_grad(),\n                no_sync_low_noise(),\n                no_sync_high_noise(),\n        ):\n            boundary = self.boundary * self.num_train_timesteps\n\n            if sample_solver == 'unipc':\n                sample_scheduler = FlowUniPCMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sample_scheduler.set_timesteps(\n                    sampling_steps, device=self.device, shift=shift)\n                timesteps = sample_scheduler.timesteps\n            elif sample_solver == 'dpm++':\n                sample_scheduler = FlowDPMSolverMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)\n                timesteps, _ = retrieve_timesteps(\n                    sample_scheduler,\n                    device=self.device,\n                    sigmas=sampling_sigmas)\n            else:\n                raise NotImplementedError(\"Unsupported solver.\")\n\n            # sample videos\n            latents = noise\n\n            arg_c = {'context': context, 'seq_len': seq_len}\n            arg_null = {'context': context_null, 'seq_len': seq_len}\n\n            for _, t in enumerate(tqdm(timesteps)):\n                latent_model_input = latents\n                timestep = [t]\n\n                timestep = torch.stack(timestep)\n\n   
             model = self._prepare_model_for_timestep(\n                    t, boundary, offload_model)\n                sample_guide_scale = guide_scale[1] if t.item(\n                ) >= boundary else guide_scale[0]\n\n                noise_pred_cond = model(\n                    latent_model_input, t=timestep, **arg_c)[0]\n                noise_pred_uncond = model(\n                    latent_model_input, t=timestep, **arg_null)[0]\n\n                noise_pred = noise_pred_uncond + sample_guide_scale * (\n                    noise_pred_cond - noise_pred_uncond)\n\n                temp_x0 = sample_scheduler.step(\n                    noise_pred.unsqueeze(0),\n                    t,\n                    latents[0].unsqueeze(0),\n                    return_dict=False,\n                    generator=seed_g)[0]\n                latents = [temp_x0.squeeze(0)]\n\n            x0 = latents\n            if offload_model:\n                self.low_noise_model.cpu()\n                self.high_noise_model.cpu()\n                torch.cuda.empty_cache()\n            if self.rank == 0:\n                videos = self.vae.decode(x0)\n\n        del noise, latents\n        del sample_scheduler\n        if offload_model:\n            gc.collect()\n            torch.cuda.synchronize()\n        if dist.is_initialized():\n            dist.barrier()\n\n        return videos[0] if self.rank == 0 else None\n"
  },
  {
    "path": "wan/textimage2video.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport types\nfrom contextlib import contextmanager\nfrom functools import partial\n\nimport torch\nimport torch.cuda.amp as amp\nimport torch.distributed as dist\nimport torchvision.transforms.functional as TF\nfrom PIL import Image\nfrom tqdm import tqdm\n\nfrom .distributed.fsdp import shard_model\nfrom .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward\nfrom .distributed.util import get_world_size\nfrom .modules.model import WanModel\nfrom .modules.t5 import T5EncoderModel\nfrom .modules.vae2_2 import Wan2_2_VAE\nfrom .utils.fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler\nfrom .utils.utils import best_output_size, masks_like\n\n\nclass WanTI2V:\n\n    def __init__(\n        self,\n        config,\n        checkpoint_dir,\n        device_id=0,\n        rank=0,\n        t5_fsdp=False,\n        dit_fsdp=False,\n        use_sp=False,\n        t5_cpu=False,\n        init_on_cpu=True,\n        convert_model_dtype=False,\n    ):\n        r\"\"\"\n        Initializes the Wan text-to-video generation model components.\n\n        Args:\n            config (EasyDict):\n                Object containing model parameters initialized from config.py\n            checkpoint_dir (`str`):\n                Path to directory containing model checkpoints\n            device_id (`int`,  *optional*, defaults to 0):\n                Id of target GPU device\n            rank (`int`,  *optional*, defaults to 0):\n                Process rank for distributed training\n            t5_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding for T5 model\n            dit_fsdp (`bool`, *optional*, defaults to False):\n                Enable FSDP sharding 
for DiT model\n            use_sp (`bool`, *optional*, defaults to False):\n                Enable distribution strategy of sequence parallel.\n            t5_cpu (`bool`, *optional*, defaults to False):\n                Whether to place T5 model on CPU. Only works without t5_fsdp.\n            init_on_cpu (`bool`, *optional*, defaults to True):\n                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.\n            convert_model_dtype (`bool`, *optional*, defaults to False):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n        \"\"\"\n        self.device = torch.device(f\"cuda:{device_id}\")\n        self.config = config\n        self.rank = rank\n        self.t5_cpu = t5_cpu\n        self.init_on_cpu = init_on_cpu\n\n        self.num_train_timesteps = config.num_train_timesteps\n        self.param_dtype = config.param_dtype\n\n        if t5_fsdp or dit_fsdp or use_sp:\n            self.init_on_cpu = False\n\n        shard_fn = partial(shard_model, device_id=device_id)\n        self.text_encoder = T5EncoderModel(\n            text_len=config.text_len,\n            dtype=config.t5_dtype,\n            device=torch.device('cpu'),\n            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),\n            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),\n            shard_fn=shard_fn if t5_fsdp else None)\n\n        self.vae_stride = config.vae_stride\n        self.patch_size = config.patch_size\n        self.vae = Wan2_2_VAE(\n            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),\n            device=self.device)\n\n        logging.info(f\"Creating WanModel from {checkpoint_dir}\")\n        self.model = WanModel.from_pretrained(checkpoint_dir)\n        self.model = self._configure_model(\n            model=self.model,\n            use_sp=use_sp,\n            dit_fsdp=dit_fsdp,\n            
shard_fn=shard_fn,\n            convert_model_dtype=convert_model_dtype)\n\n        if use_sp:\n            self.sp_size = get_world_size()\n        else:\n            self.sp_size = 1\n\n        self.sample_neg_prompt = config.sample_neg_prompt\n\n    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,\n                         convert_model_dtype):\n        \"\"\"\n        Configures a model object. This includes setting evaluation modes,\n        applying distributed parallel strategy, and handling device placement.\n\n        Args:\n            model (torch.nn.Module):\n                The model instance to configure.\n            use_sp (`bool`):\n                Enable distribution strategy of sequence parallel.\n            dit_fsdp (`bool`):\n                Enable FSDP sharding for DiT model.\n            shard_fn (callable):\n                The function to apply FSDP sharding.\n            convert_model_dtype (`bool`):\n                Convert DiT model parameters dtype to 'config.param_dtype'.\n                Only works without FSDP.\n\n        Returns:\n            torch.nn.Module:\n                The configured model.\n        \"\"\"\n        model.eval().requires_grad_(False)\n\n        if use_sp:\n            for block in model.blocks:\n                block.self_attn.forward = types.MethodType(\n                    sp_attn_forward, block.self_attn)\n            model.forward = types.MethodType(sp_dit_forward, model)\n\n        if dist.is_initialized():\n            dist.barrier()\n\n        if dit_fsdp:\n            model = shard_fn(model)\n        else:\n            if convert_model_dtype:\n                model.to(self.param_dtype)\n            if not self.init_on_cpu:\n                model.to(self.device)\n\n        return model\n\n    def generate(self,\n                 input_prompt,\n                 img=None,\n                 size=(1280, 704),\n                 max_area=704 * 1280,\n                 frame_num=81,\n              
   shift=5.0,\n                 sample_solver='unipc',\n                 sampling_steps=50,\n                 guide_scale=5.0,\n                 n_prompt=\"\",\n                 seed=-1,\n                 offload_model=True):\n        r\"\"\"\n        Generates video frames from text prompt using diffusion process.\n\n        Args:\n            input_prompt (`str`):\n                Text prompt for content generation\n            img (PIL.Image.Image):\n                Input image tensor. Shape: [3, H, W]\n            size (`tuple[int]`, *optional*, defaults to (1280,704)):\n                Controls video resolution, (width,height).\n            max_area (`int`, *optional*, defaults to 704*1280):\n                Maximum pixel area for latent space calculation. Controls video resolution scaling\n            frame_num (`int`, *optional*, defaults to 81):\n                How many frames to sample from a video. The number should be 4n+1\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. Affects temporal dynamics\n            sample_solver (`str`, *optional*, defaults to 'unipc'):\n                Solver used to sample the video.\n            sampling_steps (`int`, *optional*, defaults to 50):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation\n            guide_scale (`float`, *optional*, defaults 5.0):\n                Classifier-free guidance scale. Controls prompt adherence vs. creativity.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. 
If -1, use random seed.\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames (81)\n                - H: Frame height (from size)\n                - W: Frame width (from size)\n        \"\"\"\n        # i2v\n        if img is not None:\n            return self.i2v(\n                input_prompt=input_prompt,\n                img=img,\n                max_area=max_area,\n                frame_num=frame_num,\n                shift=shift,\n                sample_solver=sample_solver,\n                sampling_steps=sampling_steps,\n                guide_scale=guide_scale,\n                n_prompt=n_prompt,\n                seed=seed,\n                offload_model=offload_model)\n        # t2v\n        return self.t2v(\n            input_prompt=input_prompt,\n            size=size,\n            frame_num=frame_num,\n            shift=shift,\n            sample_solver=sample_solver,\n            sampling_steps=sampling_steps,\n            guide_scale=guide_scale,\n            n_prompt=n_prompt,\n            seed=seed,\n            offload_model=offload_model)\n\n    def t2v(self,\n            input_prompt,\n            size=(1280, 704),\n            frame_num=121,\n            shift=5.0,\n            sample_solver='unipc',\n            sampling_steps=50,\n            guide_scale=5.0,\n            n_prompt=\"\",\n            seed=-1,\n            offload_model=True):\n        r\"\"\"\n        Generates video frames from text prompt using diffusion process.\n\n        Args:\n            input_prompt (`str`):\n                Text prompt for content generation\n            size (`tuple[int]`, *optional*, defaults to (1280,704)):\n                Controls video 
resolution, (width,height).\n            frame_num (`int`, *optional*, defaults to 121):\n                How many frames to sample from a video. The number should be 4n+1\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. Affects temporal dynamics\n            sample_solver (`str`, *optional*, defaults to 'unipc'):\n                Solver used to sample the video.\n            sampling_steps (`int`, *optional*, defaults to 50):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation\n            guide_scale (`float`, *optional*, defaults 5.0):\n                Classifier-free guidance scale. Controls prompt adherence vs. creativity.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. If -1, use random seed.\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. 
Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames (frame_num)\n                - H: Frame height (from size)\n                - W: Frame width (from size)\n        \"\"\"\n        # preprocess\n        F = frame_num\n        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,\n                        size[1] // self.vae_stride[1],\n                        size[0] // self.vae_stride[2])\n\n        seq_len = math.ceil((target_shape[2] * target_shape[3]) /\n                            (self.patch_size[1] * self.patch_size[2]) *\n                            target_shape[1] / self.sp_size) * self.sp_size\n\n        if n_prompt == \"\":\n            n_prompt = self.sample_neg_prompt\n        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)\n        seed_g = torch.Generator(device=self.device)\n        seed_g.manual_seed(seed)\n\n        if not self.t5_cpu:\n            self.text_encoder.model.to(self.device)\n            context = self.text_encoder([input_prompt], self.device)\n            context_null = self.text_encoder([n_prompt], self.device)\n            if offload_model:\n                self.text_encoder.model.cpu()\n        else:\n            context = self.text_encoder([input_prompt], torch.device('cpu'))\n            context_null = self.text_encoder([n_prompt], torch.device('cpu'))\n            context = [t.to(self.device) for t in context]\n            context_null = [t.to(self.device) for t in context_null]\n\n        noise = [\n            torch.randn(\n                target_shape[0],\n                target_shape[1],\n                target_shape[2],\n                target_shape[3],\n                dtype=torch.float32,\n                device=self.device,\n                generator=seed_g)\n        ]\n\n        @contextmanager\n        def noop_no_sync():\n            yield\n\n        no_sync = getattr(self.model, 'no_sync', noop_no_sync)\n\n        # evaluation 
mode\n        with (\n                torch.amp.autocast('cuda', dtype=self.param_dtype),\n                torch.no_grad(),\n                no_sync(),\n        ):\n\n            if sample_solver == 'unipc':\n                sample_scheduler = FlowUniPCMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sample_scheduler.set_timesteps(\n                    sampling_steps, device=self.device, shift=shift)\n                timesteps = sample_scheduler.timesteps\n            elif sample_solver == 'dpm++':\n                sample_scheduler = FlowDPMSolverMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)\n                timesteps, _ = retrieve_timesteps(\n                    sample_scheduler,\n                    device=self.device,\n                    sigmas=sampling_sigmas)\n            else:\n                raise NotImplementedError(\"Unsupported solver.\")\n\n            # sample videos\n            latents = noise\n            mask1, mask2 = masks_like(noise, zero=False)\n\n            arg_c = {'context': context, 'seq_len': seq_len}\n            arg_null = {'context': context_null, 'seq_len': seq_len}\n\n            if offload_model or self.init_on_cpu:\n                self.model.to(self.device)\n                torch.cuda.empty_cache()\n\n            for _, t in enumerate(tqdm(timesteps)):\n                latent_model_input = latents\n                timestep = [t]\n\n                timestep = torch.stack(timestep)\n\n                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()\n                temp_ts = torch.cat([\n                    temp_ts,\n                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep\n      
          ])\n                timestep = temp_ts.unsqueeze(0)\n\n                noise_pred_cond = self.model(\n                    latent_model_input, t=timestep, **arg_c)[0]\n                noise_pred_uncond = self.model(\n                    latent_model_input, t=timestep, **arg_null)[0]\n\n                noise_pred = noise_pred_uncond + guide_scale * (\n                    noise_pred_cond - noise_pred_uncond)\n\n                temp_x0 = sample_scheduler.step(\n                    noise_pred.unsqueeze(0),\n                    t,\n                    latents[0].unsqueeze(0),\n                    return_dict=False,\n                    generator=seed_g)[0]\n                latents = [temp_x0.squeeze(0)]\n            x0 = latents\n            if offload_model:\n                self.model.cpu()\n                torch.cuda.synchronize()\n                torch.cuda.empty_cache()\n            if self.rank == 0:\n                videos = self.vae.decode(x0)\n\n        del noise, latents\n        del sample_scheduler\n        if offload_model:\n            gc.collect()\n            torch.cuda.synchronize()\n        if dist.is_initialized():\n            dist.barrier()\n\n        return videos[0] if self.rank == 0 else None\n\n    def i2v(self,\n            input_prompt,\n            img,\n            max_area=704 * 1280,\n            frame_num=121,\n            shift=5.0,\n            sample_solver='unipc',\n            sampling_steps=40,\n            guide_scale=5.0,\n            n_prompt=\"\",\n            seed=-1,\n            offload_model=True):\n        r\"\"\"\n        Generates video frames from input image and text prompt using diffusion process.\n\n        Args:\n            input_prompt (`str`):\n                Text prompt for content generation.\n            img (PIL.Image.Image):\n                Input image tensor. 
Shape: [3, H, W]\n            max_area (`int`, *optional*, defaults to 704*1280):\n                Maximum pixel area for latent space calculation. Controls video resolution scaling\n            frame_num (`int`, *optional*, defaults to 121):\n                How many frames to sample from a video. The number should be 4n+1\n            shift (`float`, *optional*, defaults to 5.0):\n                Noise schedule shift parameter. Affects temporal dynamics\n                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.\n            sample_solver (`str`, *optional*, defaults to 'unipc'):\n                Solver used to sample the video.\n            sampling_steps (`int`, *optional*, defaults to 40):\n                Number of diffusion sampling steps. Higher values improve quality but slow generation\n            guide_scale (`float`, *optional*, defaults to 5.0):\n                Classifier-free guidance scale. Controls prompt adherence vs. creativity.\n            n_prompt (`str`, *optional*, defaults to \"\"):\n                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`\n            seed (`int`, *optional*, defaults to -1):\n                Random seed for noise generation. If -1, use random seed\n            offload_model (`bool`, *optional*, defaults to True):\n                If True, offloads models to CPU during generation to save VRAM\n\n        Returns:\n            torch.Tensor:\n                Generated video frames tensor. 
Dimensions: (C, N, H, W) where:\n                - C: Color channels (3 for RGB)\n                - N: Number of frames (`frame_num`, 121 by default)\n                - H: Frame height (from max_area)\n                - W: Frame width (from max_area)\n        \"\"\"\n        # preprocess\n        ih, iw = img.height, img.width\n        dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[\n            2] * self.vae_stride[2]\n        ow, oh = best_output_size(iw, ih, dw, dh, max_area)\n\n        scale = max(ow / iw, oh / ih)\n        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)\n\n        # center-crop\n        x1 = (img.width - ow) // 2\n        y1 = (img.height - oh) // 2\n        img = img.crop((x1, y1, x1 + ow, y1 + oh))\n        assert img.width == ow and img.height == oh\n\n        # to tensor\n        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)\n\n        F = frame_num\n        seq_len = ((F - 1) // self.vae_stride[0] + 1) * (\n            oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (\n                self.patch_size[1] * self.patch_size[2])\n        seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size\n\n        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)\n        seed_g = torch.Generator(device=self.device)\n        seed_g.manual_seed(seed)\n        noise = torch.randn(\n            self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,\n            oh // self.vae_stride[1],\n            ow // self.vae_stride[2],\n            dtype=torch.float32,\n            generator=seed_g,\n            device=self.device)\n\n        if n_prompt == \"\":\n            n_prompt = self.sample_neg_prompt\n\n        # preprocess\n        if not self.t5_cpu:\n            self.text_encoder.model.to(self.device)\n            context = self.text_encoder([input_prompt], self.device)\n            context_null = self.text_encoder([n_prompt], self.device)\n            if offload_model:\n          
      self.text_encoder.model.cpu()\n        else:\n            context = self.text_encoder([input_prompt], torch.device('cpu'))\n            context_null = self.text_encoder([n_prompt], torch.device('cpu'))\n            context = [t.to(self.device) for t in context]\n            context_null = [t.to(self.device) for t in context_null]\n\n        z = self.vae.encode([img])\n\n        @contextmanager\n        def noop_no_sync():\n            yield\n\n        no_sync = getattr(self.model, 'no_sync', noop_no_sync)\n\n        # evaluation mode\n        with (\n                torch.amp.autocast('cuda', dtype=self.param_dtype),\n                torch.no_grad(),\n                no_sync(),\n        ):\n\n            if sample_solver == 'unipc':\n                sample_scheduler = FlowUniPCMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sample_scheduler.set_timesteps(\n                    sampling_steps, device=self.device, shift=shift)\n                timesteps = sample_scheduler.timesteps\n            elif sample_solver == 'dpm++':\n                sample_scheduler = FlowDPMSolverMultistepScheduler(\n                    num_train_timesteps=self.num_train_timesteps,\n                    shift=1,\n                    use_dynamic_shifting=False)\n                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)\n                timesteps, _ = retrieve_timesteps(\n                    sample_scheduler,\n                    device=self.device,\n                    sigmas=sampling_sigmas)\n            else:\n                raise NotImplementedError(\"Unsupported solver.\")\n\n            # sample videos\n            latent = noise\n            mask1, mask2 = masks_like([noise], zero=True)\n            latent = (1. 
- mask2[0]) * z[0] + mask2[0] * latent\n\n            arg_c = {\n                'context': [context[0]],\n                'seq_len': seq_len,\n            }\n\n            arg_null = {\n                'context': context_null,\n                'seq_len': seq_len,\n            }\n\n            if offload_model or self.init_on_cpu:\n                self.model.to(self.device)\n                torch.cuda.empty_cache()\n\n            for _, t in enumerate(tqdm(timesteps)):\n                latent_model_input = [latent.to(self.device)]\n                timestep = [t]\n\n                timestep = torch.stack(timestep).to(self.device)\n\n                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()\n                temp_ts = torch.cat([\n                    temp_ts,\n                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep\n                ])\n                timestep = temp_ts.unsqueeze(0)\n\n                noise_pred_cond = self.model(\n                    latent_model_input, t=timestep, **arg_c)[0]\n                if offload_model:\n                    torch.cuda.empty_cache()\n                noise_pred_uncond = self.model(\n                    latent_model_input, t=timestep, **arg_null)[0]\n                if offload_model:\n                    torch.cuda.empty_cache()\n                noise_pred = noise_pred_uncond + guide_scale * (\n                    noise_pred_cond - noise_pred_uncond)\n\n                temp_x0 = sample_scheduler.step(\n                    noise_pred.unsqueeze(0),\n                    t,\n                    latent.unsqueeze(0),\n                    return_dict=False,\n                    generator=seed_g)[0]\n                latent = temp_x0.squeeze(0)\n                latent = (1. 
- mask2[0]) * z[0] + mask2[0] * latent\n\n                x0 = [latent]\n                del latent_model_input, timestep\n\n            if offload_model:\n                self.model.cpu()\n                torch.cuda.synchronize()\n                torch.cuda.empty_cache()\n\n            if self.rank == 0:\n                videos = self.vae.decode(x0)\n\n        del noise, latent, x0\n        del sample_scheduler\n        if offload_model:\n            gc.collect()\n            torch.cuda.synchronize()\n        if dist.is_initialized():\n            dist.barrier()\n\n        return videos[0] if self.rank == 0 else None\n"
  },
  {
    "path": "wan/utils/__init__.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .fm_solvers import (\n    FlowDPMSolverMultistepScheduler,\n    get_sampling_sigmas,\n    retrieve_timesteps,\n)\nfrom .fm_solvers_unipc import FlowUniPCMultistepScheduler\n\n__all__ = [\n    'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',\n    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'\n]\n"
  },
  {
    "path": "wan/utils/fm_solvers.py",
    "content": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py\n# Convert dpm solver for flow matching\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n\nimport inspect\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import (\n    KarrasDiffusionSchedulers,\n    SchedulerMixin,\n    SchedulerOutput,\n)\nfrom diffusers.utils import deprecate, is_scipy_available\nfrom diffusers.utils.torch_utils import randn_tensor\n\nif is_scipy_available():\n    pass\n\n\ndef get_sampling_sigmas(sampling_steps, shift):\n    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]\n    sigma = (shift * sigma / (1 + (shift - 1) * sigma))\n\n    return sigma\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps=None,\n    device=None,\n    timesteps=None,\n    sigmas=None,\n    **kwargs,\n):\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\n            \"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\"\n        )\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(\n            inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(\n            inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model. This determines the resolution of the diffusion process.\n        solver_order (`int`, defaults to 2):\n            The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided\n            sampling, and `solver_order=3` for unconditional sampling. 
This affects the number of model outputs stored\n            and used in multistep updates.\n        prediction_type (`str`, defaults to \"flow_prediction\"):\n            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts\n            the flow of the diffusion process.\n        shift (`float`, *optional*, defaults to 1.0):\n            A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling\n            process.\n        use_dynamic_shifting (`bool`, defaults to `False`):\n            Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is\n            applied on the fly.\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. This method adjusts the predicted sample to prevent\n            saturation and improve photorealism.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and\n            `algorithm_type=\"dpmsolver++\"`.\n        algorithm_type (`str`, defaults to `dpmsolver++`):\n            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The\n            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)\n            paper, and the `dpmsolver++` type implements the algorithms in the\n            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. 
It is recommended to use `dpmsolver++` or\n            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.\n        solver_type (`str`, defaults to `midpoint`):\n            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the\n            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.\n        lower_order_final (`bool`, defaults to `True`):\n            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can\n            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.\n        euler_at_final (`bool`, defaults to `False`):\n            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail\n            richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference\n            steps, but sometimes may result in blurring.\n        final_sigmas_type (`str`, *optional*, defaults to \"zero\"):\n            The final `sigma` value for the noise schedule during the sampling process. If `\"sigma_min\"`, the final\n            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.\n        lambda_min_clipped (`float`, defaults to `-inf`):\n            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the\n            cosine (`squaredcos_cap_v2`) noise schedule.\n        variance_type (`str`, *optional*):\n            Set to \"learned\" or \"learned_range\" for diffusion models that predict variance. 
If set, the model's output\n            contains the predicted Gaussian variance.\n    \"\"\"\n\n    _compatibles = [e.name for e in KarrasDiffusionSchedulers]\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        solver_order: int = 2,\n        prediction_type: str = \"flow_prediction\",\n        shift: Optional[float] = 1.0,\n        use_dynamic_shifting=False,\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        sample_max_value: float = 1.0,\n        algorithm_type: str = \"dpmsolver++\",\n        solver_type: str = \"midpoint\",\n        lower_order_final: bool = True,\n        euler_at_final: bool = False,\n        final_sigmas_type: Optional[str] = \"zero\",  # \"zero\", \"sigma_min\"\n        lambda_min_clipped: float = -float(\"inf\"),\n        variance_type: Optional[str] = None,\n        invert_sigmas: bool = False,\n    ):\n        if algorithm_type in [\"dpmsolver\", \"sde-dpmsolver\"]:\n            deprecation_message = f\"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead\"\n            deprecate(\"algorithm_types dpmsolver and sde-dpmsolver\", \"1.0.0\",\n                      deprecation_message)\n\n        # settings for DPM-Solver\n        if algorithm_type not in [\n                \"dpmsolver\", \"dpmsolver++\", \"sde-dpmsolver\", \"sde-dpmsolver++\"\n        ]:\n            if algorithm_type == \"deis\":\n                self.register_to_config(algorithm_type=\"dpmsolver++\")\n            else:\n                raise NotImplementedError(\n                    f\"{algorithm_type} is not implemented for {self.__class__}\")\n\n        if solver_type not in [\"midpoint\", \"heun\"]:\n            if solver_type in [\"logrho\", \"bh1\", \"bh2\"]:\n                self.register_to_config(solver_type=\"midpoint\")\n            else:\n                raise NotImplementedError(\n                    f\"{solver_type} is not implemented for {self.__class__}\")\n\n        if algorithm_type not in [\"dpmsolver++\", \"sde-dpmsolver++\"\n                                 ] and final_sigmas_type == \"zero\":\n            raise ValueError(\n                f\"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. 
Please choose `sigma_min` instead.\"\n            )\n\n        # setable values\n        self.num_inference_steps = None\n        alphas = np.linspace(1, 1 / num_train_timesteps,\n                             num_train_timesteps)[::-1].copy()\n        sigmas = 1.0 - alphas\n        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)\n\n        if not use_dynamic_shifting:\n            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution\n            sigmas = shift * sigmas / (1 +\n                                       (shift - 1) * sigmas)  # pyright: ignore\n\n        self.sigmas = sigmas\n        self.timesteps = sigmas * num_train_timesteps\n\n        self.model_outputs = [None] * solver_order\n        self.lower_order_nums = 0\n        self._step_index = None\n        self._begin_index = None\n\n        # self.sigmas = self.sigmas.to(\n        #     \"cpu\")  # to avoid too much CPU/GPU communication\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n\n    @property\n    def step_index(self):\n        \"\"\"\n        The index counter for current timestep. It will increase 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        \"\"\"\n        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"\n        Sets the begin index for the scheduler. 
This function should be run from pipeline before the inference.\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps\n    def set_timesteps(\n        self,\n        num_inference_steps: Union[int, None] = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n        mu: Optional[Union[float, None]] = None,\n        shift: Optional[Union[float, None]] = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n        Args:\n            num_inference_steps (`int`):\n                Total number of the spacing of the time steps.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        \"\"\"\n\n        if self.config.use_dynamic_shifting and mu is None:\n            raise ValueError(\n                \" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`\"\n            )\n\n        if sigmas is None:\n            sigmas = np.linspace(self.sigma_max, self.sigma_min,\n                                 num_inference_steps +\n                                 1).copy()[:-1]  # pyright: ignore\n\n        if self.config.use_dynamic_shifting:\n            sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore\n        else:\n            if shift is None:\n                shift = self.config.shift\n            sigmas = shift * sigmas / (1 +\n                                       (shift - 1) * sigmas)  # pyright: ignore\n\n        if self.config.final_sigmas_type == \"sigma_min\":\n            sigma_last = ((1 - self.alphas_cumprod[0]) /\n                          self.alphas_cumprod[0])**0.5\n        elif self.config.final_sigmas_type == \"zero\":\n            sigma_last = 0\n        else:\n            raise ValueError(\n                f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n            )\n\n        timesteps = sigmas * self.config.num_train_timesteps\n        sigmas = np.concatenate([sigmas, [sigma_last]\n                                ]).astype(np.float32)  # pyright: ignore\n\n        self.sigmas = torch.from_numpy(sigmas)\n        self.timesteps = torch.from_numpy(timesteps).to(\n            device=device, dtype=torch.int64)\n\n        self.num_inference_steps = len(timesteps)\n\n        self.model_outputs = [\n            None,\n        ] * self.config.solver_order\n        self.lower_order_nums = 0\n\n        self._step_index = None\n        self._begin_index = None\n        # self.sigmas = self.sigmas.to(\n        #     \"cpu\")  # to avoid too much CPU/GPU communication\n\n    # Copied from 
diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n        https://arxiv.org/abs/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, *remaining_dims = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = sample.float(\n            )  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(\n            abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n        s = s.unsqueeze(\n            1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = torch.clamp(\n            sample, -s, s\n        ) / s  # \"we threshold xt0 to the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, *remaining_dims)\n        sample = sample.to(dtype)\n\n        return sample\n\n    # Copied from 
diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t\n    def _sigma_to_t(self, sigma):\n        return sigma * self.config.num_train_timesteps\n\n    def _sigma_to_alpha_sigma_t(self, sigma):\n        return 1 - sigma, sigma\n\n    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps\n    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output\n    def convert_model_output(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is\n        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an\n        integral of the data prediction model.\n        <Tip>\n        The algorithm and model type are decoupled. 
You can use either DPMSolver or DPMSolver++ for both noise\n        prediction and data prediction models.\n        </Tip>\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n        Returns:\n            `torch.Tensor`:\n                The converted model output.\n        \"\"\"\n        timestep = args[0] if len(args) > 0 else kwargs.pop(\"timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\n                    \"missing `sample` as a required keyword argument\")\n        if timestep is not None:\n            deprecate(\n                \"timesteps\",\n                \"1.0.0\",\n                \"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        # DPM-Solver++ needs to solve an integral of the data prediction model.\n        if self.config.algorithm_type in [\"dpmsolver++\", \"sde-dpmsolver++\"]:\n            if self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n                x0_pred = sample - sigma_t * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be `flow_prediction`\"\n                    \" for the FlowDPMSolverMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                x0_pred = self._threshold_sample(x0_pred)\n\n            return x0_pred\n\n        # DPM-Solver needs to solve an integral of the noise prediction model.\n        elif self.config.algorithm_type in [\"dpmsolver\", 
\"sde-dpmsolver\"]:\n            if self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n                epsilon = sample - (1 - sigma_t) * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,\"\n                    \" `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                sigma_t = self.sigmas[self.step_index]\n                x0_pred = sample - sigma_t * model_output\n                x0_pred = self._threshold_sample(x0_pred)\n                epsilon = model_output + x0_pred\n\n            return epsilon\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update\n    def dpm_solver_first_order_update(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        noise: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the first-order DPMSolver (equivalent to DDIM).\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n        Returns:\n            `torch.Tensor`:\n                The sample tensor at the previous timestep.\n        \"\"\"\n        timestep = args[0] if len(args) > 0 else kwargs.pop(\"timestep\", None)\n        prev_timestep = args[1] if len(args) > 1 else kwargs.pop(\n            \"prev_timestep\", None)\n        if sample is None:\n            if len(args) > 2:\n                sample = args[2]\n            else:\n                raise ValueError(\n                    \" missing `sample` 
as a required keyward argument\")\n        if timestep is not None:\n            deprecate(\n                \"timesteps\",\n                \"1.0.0\",\n                \"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        if prev_timestep is not None:\n            deprecate(\n                \"prev_timestep\",\n                \"1.0.0\",\n                \"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[\n            self.step_index]  # pyright: ignore\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)\n\n        h = lambda_t - lambda_s\n        if self.config.algorithm_type == \"dpmsolver++\":\n            x_t = (sigma_t /\n                   sigma_s) * sample - (alpha_t *\n                                        (torch.exp(-h) - 1.0)) * model_output\n        elif self.config.algorithm_type == \"dpmsolver\":\n            x_t = (alpha_t /\n                   alpha_s) * sample - (sigma_t *\n                                        (torch.exp(h) - 1.0)) * model_output\n        elif self.config.algorithm_type == \"sde-dpmsolver++\":\n            assert noise is not None\n            x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +\n                   (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +\n                   sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)\n        elif self.config.algorithm_type == \"sde-dpmsolver\":\n            assert noise is not None\n            x_t = ((alpha_t / alpha_s) * sample - 2.0 *\n                   (sigma_t * 
(torch.exp(h) - 1.0)) * model_output +\n                   sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)\n        return x_t  # pyright: ignore\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update\n    def multistep_dpm_solver_second_order_update(\n        self,\n        model_output_list: List[torch.Tensor],\n        *args,\n        sample: torch.Tensor = None,\n        noise: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the second-order multistep DPMSolver.\n        Args:\n            model_output_list (`List[torch.Tensor]`):\n                The direct outputs from learned diffusion model at current and latter timesteps.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n        Returns:\n            `torch.Tensor`:\n                The sample tensor at the previous timestep.\n        \"\"\"\n        timestep_list = args[0] if len(args) > 0 else kwargs.pop(\n            \"timestep_list\", None)\n        prev_timestep = args[1] if len(args) > 1 else kwargs.pop(\n            \"prev_timestep\", None)\n        if sample is None:\n            if len(args) > 2:\n                sample = args[2]\n            else:\n                raise ValueError(\n                    \" missing `sample` as a required keyword argument\")\n        if timestep_list is not None:\n            deprecate(\n                \"timestep_list\",\n                \"1.0.0\",\n                \"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        if prev_timestep is not None:\n            deprecate(\n                \"prev_timestep\",\n                \"1.0.0\",\n                \"Passing `prev_timestep` is deprecated and has no effect as model output 
conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        sigma_t, sigma_s0, sigma_s1 = (\n            self.sigmas[self.step_index + 1],  # pyright: ignore\n            self.sigmas[self.step_index],\n            self.sigmas[self.step_index - 1],  # pyright: ignore\n        )\n\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)\n\n        m0, m1 = model_output_list[-1], model_output_list[-2]\n\n        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1\n        r0 = h_0 / h\n        D0, D1 = m0, (1.0 / r0) * (m0 - m1)\n        if self.config.algorithm_type == \"dpmsolver++\":\n            # See https://arxiv.org/abs/2211.01095 for detailed derivations\n            if self.config.solver_type == \"midpoint\":\n                x_t = ((sigma_t / sigma_s0) * sample -\n                       (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *\n                       (alpha_t * (torch.exp(-h) - 1.0)) * D1)\n            elif self.config.solver_type == \"heun\":\n                x_t = ((sigma_t / sigma_s0) * sample -\n                       (alpha_t * (torch.exp(-h) - 1.0)) * D0 +\n                       (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)\n        elif self.config.algorithm_type == \"dpmsolver\":\n            # See https://arxiv.org/abs/2206.00927 for detailed derivations\n            if self.config.solver_type == \"midpoint\":\n                x_t = ((alpha_t / alpha_s0) * sample -\n                       (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *\n                       (sigma_t * (torch.exp(h) - 1.0)) * D1)\n            elif self.config.solver_type == \"heun\":\n                x_t 
= ((alpha_t / alpha_s0) * sample -\n                       (sigma_t * (torch.exp(h) - 1.0)) * D0 -\n                       (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)\n        elif self.config.algorithm_type == \"sde-dpmsolver++\":\n            assert noise is not None\n            if self.config.solver_type == \"midpoint\":\n                x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +\n                       (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *\n                       (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +\n                       sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)\n            elif self.config.solver_type == \"heun\":\n                x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +\n                       (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +\n                       (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /\n                                   (-2.0 * h) + 1.0)) * D1 +\n                       sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)\n        elif self.config.algorithm_type == \"sde-dpmsolver\":\n            assert noise is not None\n            if self.config.solver_type == \"midpoint\":\n                x_t = ((alpha_t / alpha_s0) * sample - 2.0 *\n                       (sigma_t * (torch.exp(h) - 1.0)) * D0 -\n                       (sigma_t * (torch.exp(h) - 1.0)) * D1 +\n                       sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)\n            elif self.config.solver_type == \"heun\":\n                x_t = ((alpha_t / alpha_s0) * sample - 2.0 *\n                       (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *\n                       (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +\n                       sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)\n        return x_t  # pyright: ignore\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update\n    def 
multistep_dpm_solver_third_order_update(\n        self,\n        model_output_list: List[torch.Tensor],\n        *args,\n        sample: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the third-order multistep DPMSolver.\n        Args:\n            model_output_list (`List[torch.Tensor]`):\n                The direct outputs from learned diffusion model at current and latter timesteps.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n        Returns:\n            `torch.Tensor`:\n                The sample tensor at the previous timestep.\n        \"\"\"\n\n        timestep_list = args[0] if len(args) > 0 else kwargs.pop(\n            \"timestep_list\", None)\n        prev_timestep = args[1] if len(args) > 1 else kwargs.pop(\n            \"prev_timestep\", None)\n        if sample is None:\n            if len(args) > 2:\n                sample = args[2]\n            else:\n                raise ValueError(\n                    \" missing `sample` as a required keyword argument\")\n        if timestep_list is not None:\n            deprecate(\n                \"timestep_list\",\n                \"1.0.0\",\n                \"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        if prev_timestep is not None:\n            deprecate(\n                \"prev_timestep\",\n                \"1.0.0\",\n                \"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (\n            self.sigmas[self.step_index + 1],  # pyright: ignore\n            self.sigmas[self.step_index],\n            self.sigmas[self.step_index - 1],  # pyright: ignore\n            self.sigmas[self.step_index - 
2],  # pyright: ignore\n        )\n\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)\n        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)\n        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)\n\n        m0, m1, m2 = model_output_list[-1], model_output_list[\n            -2], model_output_list[-3]\n\n        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2\n        r0, r1 = h_0 / h, h_1 / h\n        D0 = m0\n        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)\n        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)\n        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)\n        if self.config.algorithm_type == \"dpmsolver++\":\n            # See https://arxiv.org/abs/2206.00927 for detailed derivations\n            x_t = ((sigma_t / sigma_s0) * sample -\n                   (alpha_t * (torch.exp(-h) - 1.0)) * D0 +\n                   (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -\n                   (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)\n        elif self.config.algorithm_type == \"dpmsolver\":\n            # See https://arxiv.org/abs/2206.00927 for detailed derivations\n            x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *\n                                                    (torch.exp(h) - 1.0)) * D0 -\n                   (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -\n                   (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)\n        return x_t  # pyright: ignore\n\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = 
self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. for image-to-image)\n        pos = 1 if len(indices) > 1 else 0\n\n        return indices[pos].item()\n\n    def _init_step_index(self, timestep):\n        \"\"\"\n        Initialize the step_index counter for the scheduler.\n        \"\"\"\n\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: Union[int, torch.Tensor],\n        sample: torch.Tensor,\n        generator=None,\n        variance_noise: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[SchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the sample with\n        the multistep DPMSolver.\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            variance_noise (`torch.Tensor`):\n                Alternative to generating noise with `generator` by directly providing the noise for the variance\n                itself. Useful for methods such as [`LEdits++`].\n            return_dict (`bool`):\n                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.\n        Returns:\n            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        # Improve numerical stability for small number of steps\n        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (\n            self.config.euler_at_final or\n            (self.config.lower_order_final and len(self.timesteps) < 15) or\n            self.config.final_sigmas_type == \"zero\")\n        lower_order_second = ((self.step_index == len(self.timesteps) - 2) and\n                              self.config.lower_order_final and\n                              
len(self.timesteps) < 15)\n\n        model_output = self.convert_model_output(model_output, sample=sample)\n        for i in range(self.config.solver_order - 1):\n            self.model_outputs[i] = self.model_outputs[i + 1]\n        self.model_outputs[-1] = model_output\n\n        # Upcast to avoid precision issues when computing prev_sample\n        sample = sample.to(torch.float32)\n        if self.config.algorithm_type in [\"sde-dpmsolver\", \"sde-dpmsolver++\"\n                                         ] and variance_noise is None:\n            noise = randn_tensor(\n                model_output.shape,\n                generator=generator,\n                device=model_output.device,\n                dtype=torch.float32)\n        elif self.config.algorithm_type in [\"sde-dpmsolver\", \"sde-dpmsolver++\"]:\n            noise = variance_noise.to(\n                device=model_output.device,\n                dtype=torch.float32)  # pyright: ignore\n        else:\n            noise = None\n\n        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:\n            prev_sample = self.dpm_solver_first_order_update(\n                model_output, sample=sample, noise=noise)\n        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:\n            prev_sample = self.multistep_dpm_solver_second_order_update(\n                self.model_outputs, sample=sample, noise=noise)\n        else:\n            prev_sample = self.multistep_dpm_solver_third_order_update(\n                self.model_outputs, sample=sample)\n\n        if self.lower_order_nums < self.config.solver_order:\n            self.lower_order_nums += 1\n\n        # Cast sample back to expected dtype\n        prev_sample = prev_sample.to(model_output.dtype)\n\n        # upon completion increase step index by one\n        self._step_index += 1  # pyright: ignore\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return 
SchedulerOutput(prev_sample=prev_sample)\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input\n    def scale_model_input(self, sample: torch.Tensor, *args,\n                          **kwargs) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure sigmas and timesteps have the same device and dtype as original_samples\n        sigmas = self.sigmas.to(\n            device=original_samples.device, dtype=original_samples.dtype)\n        if original_samples.device.type == \"mps\" and torch.is_floating_point(\n                timesteps):\n            # mps does not support float64\n            schedule_timesteps = self.timesteps.to(\n                original_samples.device, dtype=torch.float32)\n            timesteps = timesteps.to(\n                original_samples.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(original_samples.device)\n            timesteps = timesteps.to(original_samples.device)\n\n        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index\n        if self.begin_index is None:\n            step_indices = [\n                self.index_for_timestep(t, schedule_timesteps)\n                for t in timesteps\n            ]\n        elif self.step_index is 
not None:\n            # add_noise is called after first denoising step (for inpainting)\n            step_indices = [self.step_index] * timesteps.shape[0]\n        else:\n            # add noise is called before first denoising step to create initial latent(img2img)\n            step_indices = [self.begin_index] * timesteps.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < len(original_samples.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n        noisy_samples = alpha_t * original_samples + sigma_t * noise\n        return noisy_samples\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n"
  },
  {
    "path": "wan/utils/fm_solvers_unipc.py",
    "content": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py\n# Convert unipc for flow matching\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import (\n    KarrasDiffusionSchedulers,\n    SchedulerMixin,\n    SchedulerOutput,\n)\nfrom diffusers.utils import deprecate, is_scipy_available\n\nif is_scipy_available():\n    import scipy.stats\n\n\nclass FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        solver_order (`int`, default `2`):\n            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`\n            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for\n            unconditional sampling.\n        prediction_type (`str`, defaults to \"flow_prediction\"):\n            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts\n            the flow of the diffusion process.\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. 
This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.\n        predict_x0 (`bool`, defaults to `True`):\n            Whether to use the updating algorithm on the predicted x0.\n        solver_type (`str`, default `bh2`):\n            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`\n            otherwise.\n        lower_order_final (`bool`, default `True`):\n            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can\n            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.\n        disable_corrector (`list`, default `[]`):\n            Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`\n            and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is\n            usually disabled during the first few steps.\n        solver_p (`SchedulerMixin`, default `None`):\n            Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.\n        use_karras_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. 
If `True`,\n            the sigmas are determined according to a sequence of noise levels {σi}.\n        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):\n            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.\n        timestep_spacing (`str`, defaults to `\"linspace\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        final_sigmas_type (`str`, defaults to `\"zero\"`):\n            The final `sigma` value for the noise schedule during the sampling process. If `\"sigma_min\"`, the final\n            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.\n    \"\"\"\n\n    _compatibles = [e.name for e in KarrasDiffusionSchedulers]\n    order = 1\n\n    @register_to_config\n    def __init__(\n            self,\n            num_train_timesteps: int = 1000,\n            solver_order: int = 2,\n            prediction_type: str = \"flow_prediction\",\n            shift: Optional[float] = 1.0,\n            use_dynamic_shifting=False,\n            thresholding: bool = False,\n            dynamic_thresholding_ratio: float = 0.995,\n            sample_max_value: float = 1.0,\n            predict_x0: bool = True,\n            solver_type: str = \"bh2\",\n            lower_order_final: bool = True,\n            disable_corrector: List[int] = [],\n            solver_p: SchedulerMixin = None,\n            timestep_spacing: str = \"linspace\",\n            steps_offset: int = 0,\n            final_sigmas_type: Optional[str] = \"zero\",  # \"zero\", \"sigma_min\"\n    ):\n\n        if solver_type not in [\"bh1\", \"bh2\"]:\n            if solver_type in 
[\"midpoint\", \"heun\", \"logrho\"]:\n                self.register_to_config(solver_type=\"bh2\")\n            else:\n                raise NotImplementedError(\n                    f\"{solver_type} is not implemented for {self.__class__}\")\n\n        self.predict_x0 = predict_x0\n        # settable values\n        self.num_inference_steps = None\n        alphas = np.linspace(1, 1 / num_train_timesteps,\n                             num_train_timesteps)[::-1].copy()\n        sigmas = 1.0 - alphas\n        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)\n\n        if not use_dynamic_shifting:\n            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution\n            sigmas = shift * sigmas / (1 +\n                                       (shift - 1) * sigmas)  # pyright: ignore\n\n        self.sigmas = sigmas\n        self.timesteps = sigmas * num_train_timesteps\n\n        self.model_outputs = [None] * solver_order\n        self.timestep_list = [None] * solver_order\n        self.lower_order_nums = 0\n        self.disable_corrector = disable_corrector\n        self.solver_p = solver_p\n        self.last_sample = None\n        self._step_index = None\n        self._begin_index = None\n\n        self.sigmas = self.sigmas.to(\n            \"cpu\")  # to avoid too much CPU/GPU communication\n        self.sigma_min = self.sigmas[-1].item()\n        self.sigma_max = self.sigmas[0].item()\n\n    @property\n    def step_index(self):\n        \"\"\"\n        The index counter for the current timestep. It increases by 1 after each scheduler step.\n        \"\"\"\n        return self._step_index\n\n    @property\n    def begin_index(self):\n        \"\"\"\n        The index for the first timestep. 
It should be set from the pipeline with the `set_begin_index` method.\n        \"\"\"\n        return self._begin_index\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index\n    def set_begin_index(self, begin_index: int = 0):\n        \"\"\"\n        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.\n\n        Args:\n            begin_index (`int`):\n                The begin index for the scheduler.\n        \"\"\"\n        self._begin_index = begin_index\n\n    # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps\n    def set_timesteps(\n        self,\n        num_inference_steps: Union[int, None] = None,\n        device: Union[str, torch.device] = None,\n        sigmas: Optional[List[float]] = None,\n        mu: Optional[Union[float, None]] = None,\n        shift: Optional[Union[float, None]] = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n        Args:\n            num_inference_steps (`int`):\n                The number of inference steps, i.e. how many timesteps the schedule is split into.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved. 
If `None`, the timesteps are not moved.\n        \"\"\"\n\n        if self.config.use_dynamic_shifting and mu is None:\n            raise ValueError(\n                \" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`\"\n            )\n\n        if sigmas is None:\n            sigmas = np.linspace(self.sigma_max, self.sigma_min,\n                                 num_inference_steps +\n                                 1).copy()[:-1]  # pyright: ignore\n\n        if self.config.use_dynamic_shifting:\n            sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore\n        else:\n            if shift is None:\n                shift = self.config.shift\n            sigmas = shift * sigmas / (1 +\n                                       (shift - 1) * sigmas)  # pyright: ignore\n\n        if self.config.final_sigmas_type == \"sigma_min\":\n            sigma_last = ((1 - self.alphas_cumprod[0]) /\n                          self.alphas_cumprod[0])**0.5\n        elif self.config.final_sigmas_type == \"zero\":\n            sigma_last = 0\n        else:\n            raise ValueError(\n                f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}\"\n            )\n\n        timesteps = sigmas * self.config.num_train_timesteps\n        sigmas = np.concatenate([sigmas, [sigma_last]\n                                ]).astype(np.float32)  # pyright: ignore\n\n        self.sigmas = torch.from_numpy(sigmas)\n        self.timesteps = torch.from_numpy(timesteps).to(\n            device=device, dtype=torch.int64)\n\n        self.num_inference_steps = len(timesteps)\n\n        self.model_outputs = [\n            None,\n        ] * self.config.solver_order\n        self.lower_order_nums = 0\n        self.last_sample = None\n        if self.solver_p:\n            self.solver_p.set_timesteps(self.num_inference_steps, device=device)\n\n        # add an index counter for schedulers that allow 
duplicated timesteps\n        self._step_index = None\n        self._begin_index = None\n        self.sigmas = self.sigmas.to(\n            \"cpu\")  # to avoid too much CPU/GPU communication\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n\n        https://arxiv.org/abs/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, *remaining_dims = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = sample.float(\n            )  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(\n            abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n        s = s.unsqueeze(\n            1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = torch.clamp(\n            sample, -s, s\n        ) / s  # \"we threshold xt0 to 
the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, *remaining_dims)\n        sample = sample.to(dtype)\n\n        return sample\n\n    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t\n    def _sigma_to_t(self, sigma):\n        return sigma * self.config.num_train_timesteps\n\n    def _sigma_to_alpha_sigma_t(self, sigma):\n        return 1 - sigma, sigma\n\n    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps\n    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):\n        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)\n\n    def convert_model_output(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Convert the model output to the corresponding type the UniPC algorithm needs.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n\n        Returns:\n            `torch.Tensor`:\n                The converted model output.\n        \"\"\"\n        timestep = args[0] if len(args) > 0 else kwargs.pop(\"timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\n                    \"missing `sample` as a required keyword argument\")\n        if timestep is not None:\n            deprecate(\n                \"timesteps\",\n                \"1.0.0\",\n                \"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal 
counter `self.step_index`\",\n            )\n\n        sigma = self.sigmas[self.step_index]\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n\n        if self.predict_x0:\n            if self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n                x0_pred = sample - sigma_t * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,\"\n                    \" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                x0_pred = self._threshold_sample(x0_pred)\n\n            return x0_pred\n        else:\n            if self.config.prediction_type == \"flow_prediction\":\n                sigma_t = self.sigmas[self.step_index]\n                epsilon = sample - (1 - sigma_t) * model_output\n            else:\n                raise ValueError(\n                    f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,\"\n                    \" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler.\"\n                )\n\n            if self.config.thresholding:\n                sigma_t = self.sigmas[self.step_index]\n                x0_pred = sample - sigma_t * model_output\n                x0_pred = self._threshold_sample(x0_pred)\n                epsilon = model_output + x0_pred\n\n            return epsilon\n\n    def multistep_uni_p_bh_update(\n        self,\n        model_output: torch.Tensor,\n        *args,\n        sample: torch.Tensor = None,\n        order: int = None,  # pyright: ignore\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the UniP (B(h) version). 
Alternatively, `self.solver_p` is used if it is specified.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model at the current timestep.\n            prev_timestep (`int`):\n                The previous discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            order (`int`):\n                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).\n\n        Returns:\n            `torch.Tensor`:\n                The sample tensor at the previous timestep.\n        \"\"\"\n        prev_timestep = args[0] if len(args) > 0 else kwargs.pop(\n            \"prev_timestep\", None)\n        if sample is None:\n            if len(args) > 1:\n                sample = args[1]\n            else:\n                raise ValueError(\n                    \"missing `sample` as a required keyword argument\")\n        if order is None:\n            if len(args) > 2:\n                order = args[2]\n            else:\n                raise ValueError(\n                    \"missing `order` as a required keyword argument\")\n        if prev_timestep is not None:\n            deprecate(\n                \"prev_timestep\",\n                \"1.0.0\",\n                \"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n        model_output_list = self.model_outputs\n\n        s0 = self.timestep_list[-1]\n        m0 = model_output_list[-1]\n        x = sample\n\n        if self.solver_p:\n            x_t = self.solver_p.step(model_output, s0, x).prev_sample\n            return x_t\n\n        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[\n            self.step_index]  # pyright: ignore\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n\n        h = lambda_t - lambda_s0\n        device = sample.device\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            si = self.step_index - i  # pyright: ignore\n            mi = model_output_list[-(i + 1)]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            D1s.append((mi - m0) / rk)  # pyright: ignore\n\n        rks.append(1.0)\n        rks = torch.tensor(rks, device=device)\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)  # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=device)\n\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1)  # (B, K)\n            # for order 2, we use a simplified version\n            if order == 2:\n                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)\n            else:\n                rhos_p = torch.linalg.solve(R[:-1, :-1],\n                                            b[:-1]).to(device).to(x.dtype)\n        else:\n            D1s = None\n\n        if self.predict_x0:\n            x_t_ = sigma_t / 
sigma_s0 * x - alpha_t * h_phi_1 * m0\n            if D1s is not None:\n                pred_res = torch.einsum(\"k,bkc...->bc...\", rhos_p,\n                                        D1s)  # pyright: ignore\n            else:\n                pred_res = 0\n            x_t = x_t_ - alpha_t * B_h * pred_res\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            if D1s is not None:\n                pred_res = torch.einsum(\"k,bkc...->bc...\", rhos_p,\n                                        D1s)  # pyright: ignore\n            else:\n                pred_res = 0\n            x_t = x_t_ - sigma_t * B_h * pred_res\n\n        x_t = x_t.to(x.dtype)\n        return x_t\n\n    def multistep_uni_c_bh_update(\n        self,\n        this_model_output: torch.Tensor,\n        *args,\n        last_sample: torch.Tensor = None,\n        this_sample: torch.Tensor = None,\n        order: int = None,  # pyright: ignore\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        One step for the UniC (B(h) version).\n\n        Args:\n            this_model_output (`torch.Tensor`):\n                The model outputs at `x_t`.\n            this_timestep (`int`):\n                The current timestep `t`.\n            last_sample (`torch.Tensor`):\n                The generated sample before the last predictor `x_{t-1}`.\n            this_sample (`torch.Tensor`):\n                The generated sample after the last predictor `x_{t}`.\n            order (`int`):\n                The `p` of UniC-p at this step. 
The effective order of accuracy should be `order + 1`.\n\n        Returns:\n            `torch.Tensor`:\n                The corrected sample tensor at the current timestep.\n        \"\"\"\n        this_timestep = args[0] if len(args) > 0 else kwargs.pop(\n            \"this_timestep\", None)\n        if last_sample is None:\n            if len(args) > 1:\n                last_sample = args[1]\n            else:\n                raise ValueError(\n                    \"missing `last_sample` as a required keyword argument\")\n        if this_sample is None:\n            if len(args) > 2:\n                this_sample = args[2]\n            else:\n                raise ValueError(\n                    \"missing `this_sample` as a required keyword argument\")\n        if order is None:\n            if len(args) > 3:\n                order = args[3]\n            else:\n                raise ValueError(\n                    \"missing `order` as a required keyword argument\")\n        if this_timestep is not None:\n            deprecate(\n                \"this_timestep\",\n                \"1.0.0\",\n                \"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`\",\n            )\n\n        model_output_list = self.model_outputs\n\n        m0 = model_output_list[-1]\n        x = last_sample\n        x_t = this_sample\n        model_t = this_model_output\n\n        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[\n            self.step_index - 1]  # pyright: ignore\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)\n        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)\n\n        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)\n        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)\n\n        h = lambda_t - lambda_s0\n        device = this_sample.device\n\n        rks = []\n        D1s = []\n        for i in range(1, 
order):\n            si = self.step_index - (i + 1)  # pyright: ignore\n            mi = model_output_list[-(i + 1)]\n            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])\n            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)\n            rk = (lambda_si - lambda_s0) / h\n            rks.append(rk)\n            D1s.append((mi - m0) / rk)  # pyright: ignore\n\n        rks.append(1.0)\n        rks = torch.tensor(rks, device=device)\n\n        R = []\n        b = []\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)  # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.config.solver_type == \"bh1\":\n            B_h = hh\n        elif self.config.solver_type == \"bh2\":\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= i + 1\n            h_phi_k = h_phi_k / hh - 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=device)\n\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1)\n        else:\n            D1s = None\n\n        # for order 1, we use a simplified version\n        if order == 1:\n            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)\n        else:\n            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)\n\n        if self.predict_x0:\n            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0\n            if D1s is not None:\n                corr_res = torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n            else:\n                corr_res = 0\n            D1_t = model_t - m0\n            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        else:\n            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0\n            if 
D1s is not None:\n                corr_res = torch.einsum(\"k,bkc...->bc...\", rhos_c[:-1], D1s)\n            else:\n                corr_res = 0\n            D1_t = model_t - m0\n            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)\n        x_t = x_t.to(x.dtype)\n        return x_t\n\n    def index_for_timestep(self, timestep, schedule_timesteps=None):\n        if schedule_timesteps is None:\n            schedule_timesteps = self.timesteps\n\n        indices = (schedule_timesteps == timestep).nonzero()\n\n        # The sigma index that is taken for the **very** first `step`\n        # is always the second index (or the last index if there is only 1)\n        # This way we can ensure we don't accidentally skip a sigma in\n        # case we start in the middle of the denoising schedule (e.g. for image-to-image)\n        pos = 1 if len(indices) > 1 else 0\n\n        return indices[pos].item()\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index\n    def _init_step_index(self, timestep):\n        \"\"\"\n        Initialize the step_index counter for the scheduler.\n        \"\"\"\n\n        if self.begin_index is None:\n            if isinstance(timestep, torch.Tensor):\n                timestep = timestep.to(self.timesteps.device)\n            self._step_index = self.index_for_timestep(timestep)\n        else:\n            self._step_index = self._begin_index\n\n    def step(self,\n             model_output: torch.Tensor,\n             timestep: Union[int, torch.Tensor],\n             sample: torch.Tensor,\n             return_dict: bool = True,\n             generator=None) -> Union[SchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the sample with\n        the multistep UniPC.\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            return_dict (`bool`):\n                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.\n\n        Returns:\n            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        if self.step_index is None:\n            self._init_step_index(timestep)\n\n        use_corrector = (\n            self.step_index > 0 and\n            self.step_index - 1 not in self.disable_corrector and\n            self.last_sample is not None  # pyright: ignore\n        )\n\n        model_output_convert = self.convert_model_output(\n            model_output, sample=sample)\n        if use_corrector:\n            sample = self.multistep_uni_c_bh_update(\n                this_model_output=model_output_convert,\n                last_sample=self.last_sample,\n                this_sample=sample,\n                order=self.this_order,\n            )\n\n        for i in range(self.config.solver_order - 1):\n            self.model_outputs[i] = self.model_outputs[i + 1]\n            self.timestep_list[i] = self.timestep_list[i + 1]\n\n        self.model_outputs[-1] = model_output_convert\n        
self.timestep_list[-1] = timestep  # pyright: ignore\n\n        if self.config.lower_order_final:\n            this_order = min(self.config.solver_order,\n                             len(self.timesteps) -\n                             self.step_index)  # pyright: ignore\n        else:\n            this_order = self.config.solver_order\n\n        self.this_order = min(this_order,\n                              self.lower_order_nums + 1)  # warmup for multistep\n        assert self.this_order > 0\n\n        self.last_sample = sample\n        prev_sample = self.multistep_uni_p_bh_update(\n            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used\n            sample=sample,\n            order=self.this_order,\n        )\n\n        if self.lower_order_nums < self.config.solver_order:\n            self.lower_order_nums += 1\n\n        # upon completion increase step index by one\n        self._step_index += 1  # pyright: ignore\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return SchedulerOutput(prev_sample=prev_sample)\n\n    def scale_model_input(self, sample: torch.Tensor, *args,\n                          **kwargs) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure sigmas and timesteps have the same device and dtype as original_samples\n        sigmas = 
self.sigmas.to(\n            device=original_samples.device, dtype=original_samples.dtype)\n        if original_samples.device.type == \"mps\" and torch.is_floating_point(\n                timesteps):\n            # mps does not support float64\n            schedule_timesteps = self.timesteps.to(\n                original_samples.device, dtype=torch.float32)\n            timesteps = timesteps.to(\n                original_samples.device, dtype=torch.float32)\n        else:\n            schedule_timesteps = self.timesteps.to(original_samples.device)\n            timesteps = timesteps.to(original_samples.device)\n\n        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index\n        if self.begin_index is None:\n            step_indices = [\n                self.index_for_timestep(t, schedule_timesteps)\n                for t in timesteps\n            ]\n        elif self.step_index is not None:\n            # add_noise is called after first denoising step (for inpainting)\n            step_indices = [self.step_index] * timesteps.shape[0]\n        else:\n            # add noise is called before first denoising step to create initial latent(img2img)\n            step_indices = [self.begin_index] * timesteps.shape[0]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < len(original_samples.shape):\n            sigma = sigma.unsqueeze(-1)\n\n        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)\n        noisy_samples = alpha_t * original_samples + sigma_t * noise\n        return noisy_samples\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n"
  },
  {
    "path": "wan/utils/prompt_extend.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport tempfile\nfrom dataclasses import dataclass\nfrom http import HTTPStatus\nfrom typing import Optional, Union\n\nimport dashscope\nimport torch\nfrom PIL import Image\n\ntry:\n    from flash_attn import flash_attn_varlen_func\n    FLASH_VER = 2\nexcept ModuleNotFoundError:\n    flash_attn_varlen_func = None  # in compatible with CPU machines\n    FLASH_VER = None\n\nfrom .system_prompt import *\n\nDEFAULT_SYS_PROMPTS = {\n    \"t2v-A14B\": {\n        \"zh\": T2V_A14B_ZH_SYS_PROMPT,\n        \"en\": T2V_A14B_EN_SYS_PROMPT,\n    },\n    \"i2v-A14B\": {\n        \"zh\": I2V_A14B_ZH_SYS_PROMPT,\n        \"en\": I2V_A14B_EN_SYS_PROMPT,\n        \"empty\": {\n            \"zh\": I2V_A14B_EMPTY_ZH_SYS_PROMPT,\n            \"en\": I2V_A14B_EMPTY_EN_SYS_PROMPT,\n        }\n    },\n    \"ti2v-5B\": {\n        \"t2v\": {\n            \"zh\": T2V_A14B_ZH_SYS_PROMPT,\n            \"en\": T2V_A14B_EN_SYS_PROMPT,\n        },\n        \"i2v\": {\n            \"zh\": I2V_A14B_ZH_SYS_PROMPT,\n            \"en\": I2V_A14B_EN_SYS_PROMPT,\n        }\n    },\n}\n\n\n@dataclass\nclass PromptOutput(object):\n    status: bool\n    prompt: str\n    seed: int\n    system_prompt: str\n    message: str\n\n    def add_custom_field(self, key: str, value) -> None:\n        self.__setattr__(key, value)\n\n\nclass PromptExpander:\n\n    def __init__(self, model_name, task, is_vl=False, device=0, **kwargs):\n        self.model_name = model_name\n        self.task = task\n        self.is_vl = is_vl\n        self.device = device\n\n    def extend_with_img(self,\n                        prompt,\n                        system_prompt,\n                        image=None,\n                        seed=-1,\n                        *args,\n                        **kwargs):\n        pass\n\n    def extend(self, prompt, system_prompt, 
seed=-1, *args, **kwargs):\n        pass\n\n    def decide_system_prompt(self, tar_lang=\"zh\", prompt=None):\n        assert self.task is not None\n        if \"ti2v\" in self.task:\n            if self.is_vl:\n                return DEFAULT_SYS_PROMPTS[self.task][\"i2v\"][tar_lang]\n            else:\n                return DEFAULT_SYS_PROMPTS[self.task][\"t2v\"][tar_lang]\n        if \"i2v\" in self.task and len(prompt) == 0:\n            return DEFAULT_SYS_PROMPTS[self.task][\"empty\"][tar_lang]\n        return DEFAULT_SYS_PROMPTS[self.task][tar_lang]\n\n    def __call__(self,\n                 prompt,\n                 system_prompt=None,\n                 tar_lang=\"zh\",\n                 image=None,\n                 seed=-1,\n                 *args,\n                 **kwargs):\n        if system_prompt is None:\n            system_prompt = self.decide_system_prompt(\n                tar_lang=tar_lang, prompt=prompt)\n        if seed < 0:\n            seed = random.randint(0, sys.maxsize)\n        if image is not None and self.is_vl:\n            return self.extend_with_img(\n                prompt, system_prompt, image=image, seed=seed, *args, **kwargs)\n        elif not self.is_vl:\n            return self.extend(prompt, system_prompt, seed, *args, **kwargs)\n        else:\n            raise NotImplementedError\n\n\nclass DashScopePromptExpander(PromptExpander):\n\n    def __init__(self,\n                 api_key=None,\n                 model_name=None,\n                 task=None,\n                 max_image_size=512 * 512,\n                 retry_times=4,\n                 is_vl=False,\n                 **kwargs):\n        '''\n        Args:\n            api_key: The API key for Dash Scope authentication and access to related services.\n            model_name: Model name, 'qwen-plus' for extending prompts, 'qwen-vl-max' for extending prompt-images.\n            task: Task name. 
This is required to determine the default system prompt.\n            max_image_size: The maximum image area in pixels (width * height). Larger images are downscaled to this area, preserving aspect ratio, before being sent.\n            retry_times: Number of retry attempts in case of request failure.\n            is_vl: A flag indicating whether the task involves visual-language processing.\n            **kwargs: Additional keyword arguments that can be passed to the function or method.\n        '''\n        if model_name is None:\n            model_name = 'qwen-plus' if not is_vl else 'qwen-vl-max'\n        super().__init__(model_name, task, is_vl, **kwargs)\n        if api_key is not None:\n            dashscope.api_key = api_key\n        elif 'DASH_API_KEY' in os.environ and os.environ[\n                'DASH_API_KEY'] is not None:\n            dashscope.api_key = os.environ['DASH_API_KEY']\n        else:\n            raise ValueError(\"DASH_API_KEY is not set\")\n        if 'DASH_API_URL' in os.environ and os.environ[\n                'DASH_API_URL'] is not None:\n            dashscope.base_http_api_url = os.environ['DASH_API_URL']\n        else:\n            dashscope.base_http_api_url = 'https://dashscope.aliyuncs.com/api/v1'\n        self.api_key = api_key\n\n        self.max_image_size = max_image_size\n        self.model = model_name\n        self.retry_times = retry_times\n\n    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):\n        messages = [{\n            'role': 'system',\n            'content': system_prompt\n        }, {\n            'role': 'user',\n            'content': prompt\n        }]\n\n        exception = None\n        for _ in range(self.retry_times):\n            try:\n                response = dashscope.Generation.call(\n                    self.model,\n                    messages=messages,\n                    seed=seed,\n                    result_format='message',  # set the result to be \"message\" format.\n                )\n      
          assert response.status_code == HTTPStatus.OK, response\n                expanded_prompt = response['output']['choices'][0]['message'][\n                    'content']\n                return PromptOutput(\n                    status=True,\n                    prompt=expanded_prompt,\n                    seed=seed,\n                    system_prompt=system_prompt,\n                    message=json.dumps(response, ensure_ascii=False))\n            except Exception as e:\n                exception = e\n        return PromptOutput(\n            status=False,\n            prompt=prompt,\n            seed=seed,\n            system_prompt=system_prompt,\n            message=str(exception))\n\n    def extend_with_img(self,\n                        prompt,\n                        system_prompt,\n                        image: Union[Image.Image, str] = None,\n                        seed=-1,\n                        *args,\n                        **kwargs):\n        if isinstance(image, str):\n            image = Image.open(image).convert('RGB')\n        w = image.width\n        h = image.height\n        area = min(w * h, self.max_image_size)\n        aspect_ratio = h / w\n        resized_h = round(math.sqrt(area * aspect_ratio))\n        resized_w = round(math.sqrt(area / aspect_ratio))\n        image = image.resize((resized_w, resized_h))\n        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:\n            image.save(f.name)\n            fname = f.name\n            image_path = f\"file://{f.name}\"\n        prompt = f\"{prompt}\"\n        messages = [\n            {\n                'role': 'system',\n                'content': [{\n                    \"text\": system_prompt\n                }]\n            },\n            {\n                'role': 'user',\n                'content': [{\n                    \"text\": prompt\n                }, {\n                    \"image\": image_path\n                }]\n            },\n        ]\n     
   response = None\n        result_prompt = prompt\n        exception = None\n        status = False\n        for _ in range(self.retry_times):\n            try:\n                response = dashscope.MultiModalConversation.call(\n                    self.model,\n                    messages=messages,\n                    seed=seed,\n                    result_format='message',  # set the result to be \"message\" format.\n                )\n                assert response.status_code == HTTPStatus.OK, response\n                result_prompt = response['output']['choices'][0]['message'][\n                    'content'][0]['text'].replace('\\n', '\\\\n')\n                status = True\n                break\n            except Exception as e:\n                exception = e\n        result_prompt = result_prompt.replace('\\n', '\\\\n')\n        os.remove(fname)\n\n        return PromptOutput(\n            status=status,\n            prompt=result_prompt,\n            seed=seed,\n            system_prompt=system_prompt,\n            message=str(exception) if not status else json.dumps(\n                response, ensure_ascii=False))\n\n\nclass QwenPromptExpander(PromptExpander):\n    model_dict = {\n        \"QwenVL2.5_3B\": \"Qwen/Qwen2.5-VL-3B-Instruct\",\n        \"QwenVL2.5_7B\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n        \"Qwen2.5_3B\": \"Qwen/Qwen2.5-3B-Instruct\",\n        \"Qwen2.5_7B\": \"Qwen/Qwen2.5-7B-Instruct\",\n        \"Qwen2.5_14B\": \"Qwen/Qwen2.5-14B-Instruct\",\n    }\n\n    def __init__(self,\n                 model_name=None,\n                 task=None,\n                 device=0,\n                 is_vl=False,\n                 **kwargs):\n        '''\n        Args:\n            model_name: Use predefined model names such as 'QwenVL2.5_7B' and 'Qwen2.5_14B',\n                which are specific versions of the Qwen model. 
Alternatively, you can use the\n                local path to a downloaded model or a model name from the Hugging Face model hub.\n            task: Task name. This is required to determine the default system prompt.\n            is_vl: A flag indicating whether the task involves visual-language processing.\n            **kwargs: Additional keyword arguments that can be passed to the function or method.\n        '''\n        if model_name is None:\n            model_name = 'Qwen2.5_14B' if not is_vl else 'QwenVL2.5_7B'\n        super().__init__(model_name, task, is_vl, device, **kwargs)\n        if (not os.path.exists(self.model_name)) and (self.model_name\n                                                      in self.model_dict):\n            self.model_name = self.model_dict[self.model_name]\n\n        if self.is_vl:\n            # default: Load the model on the available device(s)\n            from transformers import (\n                AutoProcessor,\n                AutoTokenizer,\n                Qwen2_5_VLForConditionalGeneration,\n            )\n            try:\n                from .qwen_vl_utils import process_vision_info\n            except ImportError:\n                from qwen_vl_utils import process_vision_info\n            self.process_vision_info = process_vision_info\n            min_pixels = 256 * 28 * 28\n            max_pixels = 1280 * 28 * 28\n            self.processor = AutoProcessor.from_pretrained(\n                self.model_name,\n                min_pixels=min_pixels,\n                max_pixels=max_pixels,\n                use_fast=True)\n           
 self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n                self.model_name,\n                torch_dtype=torch.bfloat16 if FLASH_VER == 2 else\n                torch.float16 if \"AWQ\" in self.model_name else \"auto\",\n                attn_implementation=\"flash_attention_2\"\n                if FLASH_VER == 2 else None,\n                device_map=\"cpu\")\n        else:\n            from transformers import AutoModelForCausalLM, AutoTokenizer\n            self.model = AutoModelForCausalLM.from_pretrained(\n                self.model_name,\n                torch_dtype=torch.float16\n                if \"AWQ\" in self.model_name else \"auto\",\n                attn_implementation=\"flash_attention_2\"\n                if FLASH_VER == 2 else None,\n                device_map=\"cpu\")\n            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n\n    def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):\n        self.model = self.model.to(self.device)\n        messages = [{\n            \"role\": \"system\",\n            \"content\": system_prompt\n        }, {\n            \"role\": \"user\",\n            \"content\": prompt\n        }]\n        text = self.tokenizer.apply_chat_template(\n            messages, tokenize=False, add_generation_prompt=True)\n        model_inputs = self.tokenizer([text],\n                                      return_tensors=\"pt\").to(self.model.device)\n\n        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)\n        generated_ids = [\n            output_ids[len(input_ids):] for input_ids, output_ids in zip(\n                model_inputs.input_ids, generated_ids)\n        ]\n\n        expanded_prompt = self.tokenizer.batch_decode(\n            generated_ids, skip_special_tokens=True)[0]\n        self.model = self.model.to(\"cpu\")\n        return PromptOutput(\n            status=True,\n            prompt=expanded_prompt,\n            seed=seed,\n           
 system_prompt=system_prompt,\n            message=json.dumps({\"content\": expanded_prompt},\n                               ensure_ascii=False))\n\n    def extend_with_img(self,\n                        prompt,\n                        system_prompt,\n                        image: Union[Image.Image, str] = None,\n                        seed=-1,\n                        *args,\n                        **kwargs):\n        self.model = self.model.to(self.device)\n        messages = [{\n            'role': 'system',\n            'content': [{\n                \"type\": \"text\",\n                \"text\": system_prompt\n            }]\n        }, {\n            \"role\":\n                \"user\",\n            \"content\": [\n                {\n                    \"type\": \"image\",\n                    \"image\": image,\n                },\n                {\n                    \"type\": \"text\",\n                    \"text\": prompt\n                },\n            ],\n        }]\n\n        # Preparation for inference\n        text = self.processor.apply_chat_template(\n            messages, tokenize=False, add_generation_prompt=True)\n        image_inputs, video_inputs = self.process_vision_info(messages)\n        inputs = self.processor(\n            text=[text],\n            images=image_inputs,\n            videos=video_inputs,\n            padding=True,\n            return_tensors=\"pt\",\n        )\n        inputs = inputs.to(self.device)\n\n        # Inference: Generation of the output\n        generated_ids = self.model.generate(**inputs, max_new_tokens=512)\n        generated_ids_trimmed = [\n            out_ids[len(in_ids):]\n            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n        ]\n        expanded_prompt = self.processor.batch_decode(\n            generated_ids_trimmed,\n            skip_special_tokens=True,\n            clean_up_tokenization_spaces=False)[0]\n        self.model = self.model.to(\"cpu\")\n        return 
PromptOutput(\n            status=True,\n            prompt=expanded_prompt,\n            seed=seed,\n            system_prompt=system_prompt,\n            message=json.dumps({\"content\": expanded_prompt},\n                               ensure_ascii=False))\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(\n        level=logging.INFO,\n        format=\"[%(asctime)s] %(levelname)s: %(message)s\",\n        handlers=[logging.StreamHandler(stream=sys.stdout)])\n\n    seed = 100\n    prompt = \"夏日海滩度假风格，一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松，表情悠闲，直视镜头。背景是模糊的海滩景色，海水清澈，远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松，仿佛在享受海风和阳光。近景特写，强调猫咪的细节和海滩的清新氛围。\"\n    en_prompt = \"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. 
A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.\"\n    image = \"./examples/i2v_input.JPG\"\n\n    def test(method,\n             prompt,\n             model_name,\n             task,\n             image=None,\n             en_prompt=None,\n             seed=None):\n        prompt_expander = method(\n            model_name=model_name, task=task, is_vl=image is not None)\n        result = prompt_expander(prompt, image=image, tar_lang=\"zh\")\n        logging.info(f\"zh prompt -> zh: {result.prompt}\")\n        result = prompt_expander(prompt, image=image, tar_lang=\"en\")\n        logging.info(f\"zh prompt -> en: {result.prompt}\")\n        if en_prompt is not None:\n            result = prompt_expander(en_prompt, image=image, tar_lang=\"zh\")\n            logging.info(f\"en prompt -> zh: {result.prompt}\")\n            result = prompt_expander(en_prompt, image=image, tar_lang=\"en\")\n            logging.info(f\"en prompt -> en: {result.prompt}\")\n\n    ds_model_name = None\n    ds_vl_model_name = None\n    qwen_model_name = None\n    qwen_vl_model_name = None\n\n    for task in [\"t2v-A14B\", \"i2v-A14B\", \"ti2v-5B\"]:\n        # test prompt extend\n        if \"t2v\" in task or \"ti2v\" in task:\n            # test dashscope api\n            logging.info(f\"-\" * 40)\n            logging.info(f\"Testing {task} dashscope prompt extend\")\n            test(\n                DashScopePromptExpander,\n                prompt,\n                ds_model_name,\n                task,\n                image=None,\n                en_prompt=en_prompt,\n                seed=seed)\n\n            # test qwen api\n            logging.info(f\"-\" * 40)\n            logging.info(f\"Testing {task} qwen prompt extend\")\n            test(\n                QwenPromptExpander,\n                prompt,\n                qwen_model_name,\n                task,\n                image=None,\n                en_prompt=en_prompt,\n   
             seed=seed)\n\n        # test prompt-image extend\n        if \"i2v\" in task:\n            # test dashscope api\n            logging.info(f\"-\" * 40)\n            logging.info(f\"Testing {task} dashscope vl prompt extend\")\n            test(\n                DashScopePromptExpander,\n                prompt,\n                ds_vl_model_name,\n                task,\n                image=image,\n                en_prompt=en_prompt,\n                seed=seed)\n\n            # test qwen api\n            logging.info(f\"-\" * 40)\n            logging.info(f\"Testing {task} qwen vl prompt extend\")\n            test(\n                QwenPromptExpander,\n                prompt,\n                qwen_vl_model_name,\n                task,\n                image=image,\n                en_prompt=en_prompt,\n                seed=seed)\n\n        # test empty prompt extend\n        if \"i2v-A14B\" in task:\n            # test dashscope api\n            logging.info(f\"-\" * 40)\n            logging.info(f\"Testing {task} dashscope vl empty prompt extend\")\n            test(\n                DashScopePromptExpander,\n                \"\",\n                ds_vl_model_name,\n                task,\n                image=image,\n                en_prompt=None,\n                seed=seed)\n\n            # test qwen api\n            logging.info(f\"-\" * 40)\n            logging.info(f\"Testing {task} qwen vl empty prompt extend\")\n            test(\n                QwenPromptExpander,\n                \"\",\n                qwen_vl_model_name,\n                task,\n                image=image,\n                en_prompt=None,\n                seed=seed)\n"
  },
  {
    "path": "wan/utils/qwen_vl_utils.py",
    "content": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom __future__ import annotations\n\nimport base64\nimport logging\nimport math\nimport os\nimport sys\nimport time\nimport warnings\nfrom functools import lru_cache\nfrom io import BytesIO\n\nimport requests\nimport torch\nimport torchvision\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import io, transforms\nfrom torchvision.transforms import InterpolationMode\n\nlogger = logging.getLogger(__name__)\n\nIMAGE_FACTOR = 28\nMIN_PIXELS = 4 * 28 * 28\nMAX_PIXELS = 16384 * 28 * 28\nMAX_RATIO = 200\n\nVIDEO_MIN_PIXELS = 128 * 28 * 28\nVIDEO_MAX_PIXELS = 768 * 28 * 28\nVIDEO_TOTAL_PIXELS = 24576 * 28 * 28\nFRAME_FACTOR = 2\nFPS = 2.0\nFPS_MIN_FRAMES = 4\nFPS_MAX_FRAMES = 768\n\n\ndef round_by_factor(number: int, factor: int) -> int:\n    \"\"\"Returns the closest integer to 'number' that is divisible by 'factor'.\"\"\"\n    return round(number / factor) * factor\n\n\ndef ceil_by_factor(number: int, factor: int) -> int:\n    \"\"\"Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.\"\"\"\n    return math.ceil(number / factor) * factor\n\n\ndef floor_by_factor(number: int, factor: int) -> int:\n    \"\"\"Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.\"\"\"\n    return math.floor(number / factor) * factor\n\n\ndef smart_resize(height: int,\n                 width: int,\n                 factor: int = IMAGE_FACTOR,\n                 min_pixels: int = MIN_PIXELS,\n                 max_pixels: int = MAX_PIXELS) -> tuple[int, int]:\n    \"\"\"\n    Rescales the image so that the following conditions are met:\n\n    1. Both dimensions (height and width) are divisible by 'factor'.\n\n    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].\n\n    3. 
The aspect ratio of the image is maintained as closely as possible.\n    \"\"\"\n    if max(height, width) / min(height, width) > MAX_RATIO:\n        raise ValueError(\n            f\"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}\"\n        )\n    h_bar = max(factor, round_by_factor(height, factor))\n    w_bar = max(factor, round_by_factor(width, factor))\n    if h_bar * w_bar > max_pixels:\n        beta = math.sqrt((height * width) / max_pixels)\n        h_bar = floor_by_factor(height / beta, factor)\n        w_bar = floor_by_factor(width / beta, factor)\n    elif h_bar * w_bar < min_pixels:\n        beta = math.sqrt(min_pixels / (height * width))\n        h_bar = ceil_by_factor(height * beta, factor)\n        w_bar = ceil_by_factor(width * beta, factor)\n    return h_bar, w_bar\n\n\ndef fetch_image(ele: dict[str, str | Image.Image],\n                size_factor: int = IMAGE_FACTOR) -> Image.Image:\n    if \"image\" in ele:\n        image = ele[\"image\"]\n    else:\n        image = ele[\"image_url\"]\n    image_obj = None\n    if isinstance(image, Image.Image):\n        image_obj = image\n    elif image.startswith(\"http://\") or image.startswith(\"https://\"):\n        image_obj = Image.open(requests.get(image, stream=True).raw)\n    elif image.startswith(\"file://\"):\n        image_obj = Image.open(image[7:])\n    elif image.startswith(\"data:image\"):\n        if \"base64,\" in image:\n            _, base64_data = image.split(\"base64,\", 1)\n            data = base64.b64decode(base64_data)\n            image_obj = Image.open(BytesIO(data))\n    else:\n        image_obj = Image.open(image)\n    if image_obj is None:\n        raise ValueError(\n            f\"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}\"\n        )\n    image = image_obj.convert(\"RGB\")\n    ## resize\n    if \"resized_height\" in ele and \"resized_width\" in ele:\n        resized_height, 
resized_width = smart_resize(\n            ele[\"resized_height\"],\n            ele[\"resized_width\"],\n            factor=size_factor,\n        )\n    else:\n        width, height = image.size\n        min_pixels = ele.get(\"min_pixels\", MIN_PIXELS)\n        max_pixels = ele.get(\"max_pixels\", MAX_PIXELS)\n        resized_height, resized_width = smart_resize(\n            height,\n            width,\n            factor=size_factor,\n            min_pixels=min_pixels,\n            max_pixels=max_pixels,\n        )\n    image = image.resize((resized_width, resized_height))\n\n    return image\n\n\ndef smart_nframes(\n    ele: dict,\n    total_frames: int,\n    video_fps: int | float,\n) -> int:\n    \"\"\"calculate the number of frames for video used for model inputs.\n\n    Args:\n        ele (dict): a dict contains the configuration of video.\n            support either `fps` or `nframes`:\n                - nframes: the number of frames to extract for model inputs.\n                - fps: the fps to extract frames for model inputs.\n                    - min_frames: the minimum number of frames of the video, only used when fps is provided.\n                    - max_frames: the maximum number of frames of the video, only used when fps is provided.\n        total_frames (int): the original total number of frames of the video.\n        video_fps (int | float): the original fps of the video.\n\n    Raises:\n        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].\n\n    Returns:\n        int: the number of frames for video used for model inputs.\n    \"\"\"\n    assert not (\"fps\" in ele and\n                \"nframes\" in ele), \"Only accept either `fps` or `nframes`\"\n    if \"nframes\" in ele:\n        nframes = round_by_factor(ele[\"nframes\"], FRAME_FACTOR)\n    else:\n        fps = ele.get(\"fps\", FPS)\n        min_frames = ceil_by_factor(\n            ele.get(\"min_frames\", FPS_MIN_FRAMES), FRAME_FACTOR)\n        max_frames = 
floor_by_factor(\n            ele.get(\"max_frames\", min(FPS_MAX_FRAMES, total_frames)),\n            FRAME_FACTOR)\n        nframes = total_frames / video_fps * fps\n        nframes = min(max(nframes, min_frames), max_frames)\n        nframes = round_by_factor(nframes, FRAME_FACTOR)\n    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):\n        raise ValueError(\n            f\"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.\"\n        )\n    return nframes\n\n\ndef _read_video_torchvision(ele: dict,) -> torch.Tensor:\n    \"\"\"read video using torchvision.io.read_video\n\n    Args:\n        ele (dict): a dict contains the configuration of video.\n        support keys:\n            - video: the path of video. support \"file://\", \"http://\", \"https://\" and local path.\n            - video_start: the start time of video.\n            - video_end: the end time of video.\n    Returns:\n        torch.Tensor: the video tensor with shape (T, C, H, W).\n    \"\"\"\n    video_path = ele[\"video\"]\n    if version.parse(torchvision.__version__) < version.parse(\"0.19.0\"):\n        if \"http://\" in video_path or \"https://\" in video_path:\n            warnings.warn(\n                \"torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.\"\n            )\n        if \"file://\" in video_path:\n            video_path = video_path[7:]\n    st = time.time()\n    video, audio, info = io.read_video(\n        video_path,\n        start_pts=ele.get(\"video_start\", 0.0),\n        end_pts=ele.get(\"video_end\", None),\n        pts_unit=\"sec\",\n        output_format=\"TCHW\",\n    )\n    total_frames, video_fps = video.size(0), info[\"video_fps\"]\n    logger.info(\n        f\"torchvision:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s\"\n    )\n    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)\n    idx = torch.linspace(0, total_frames - 
1, nframes).round().long()\n    video = video[idx]\n    return video\n\n\ndef is_decord_available() -> bool:\n    import importlib.util\n\n    return importlib.util.find_spec(\"decord\") is not None\n\n\ndef _read_video_decord(ele: dict,) -> torch.Tensor:\n    \"\"\"read video using decord.VideoReader\n\n    Args:\n        ele (dict): a dict contains the configuration of video.\n        support keys:\n            - video: the path of video. support \"file://\", \"http://\", \"https://\" and local path.\n            - video_start: the start time of video.\n            - video_end: the end time of video.\n    Returns:\n        torch.Tensor: the video tensor with shape (T, C, H, W).\n    \"\"\"\n    import decord\n    video_path = ele[\"video\"]\n    st = time.time()\n    vr = decord.VideoReader(video_path)\n    # TODO: support start_pts and end_pts\n    if 'video_start' in ele or 'video_end' in ele:\n        raise NotImplementedError(\n            \"not support start_pts and end_pts in decord for now.\")\n    total_frames, video_fps = len(vr), vr.get_avg_fps()\n    logger.info(\n        f\"decord:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s\"\n    )\n    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)\n    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()\n    video = vr.get_batch(idx).asnumpy()\n    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format\n    return video\n\n\nVIDEO_READER_BACKENDS = {\n    \"decord\": _read_video_decord,\n    \"torchvision\": _read_video_torchvision,\n}\n\nFORCE_QWENVL_VIDEO_READER = os.getenv(\"FORCE_QWENVL_VIDEO_READER\", None)\n\n\n@lru_cache(maxsize=1)\ndef get_video_reader_backend() -> str:\n    if FORCE_QWENVL_VIDEO_READER is not None:\n        video_reader_backend = FORCE_QWENVL_VIDEO_READER\n    elif is_decord_available():\n        video_reader_backend = \"decord\"\n    else:\n        video_reader_backend = 
 \"torchvision\"\n    logger.info(\n        f\"qwen-vl-utils using {video_reader_backend} to read video.\")\n    return video_reader_backend\n\n\ndef fetch_video(\n        ele: dict,\n        image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:\n    if isinstance(ele[\"video\"], str):\n        video_reader_backend = get_video_reader_backend()\n        video = VIDEO_READER_BACKENDS[video_reader_backend](ele)\n        nframes, _, height, width = video.shape\n\n        min_pixels = ele.get(\"min_pixels\", VIDEO_MIN_PIXELS)\n        total_pixels = ele.get(\"total_pixels\", VIDEO_TOTAL_PIXELS)\n        max_pixels = max(\n            min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),\n            int(min_pixels * 1.05))\n        max_pixels = ele.get(\"max_pixels\", max_pixels)\n        if \"resized_height\" in ele and \"resized_width\" in ele:\n            resized_height, resized_width = smart_resize(\n                ele[\"resized_height\"],\n                ele[\"resized_width\"],\n                factor=image_factor,\n            )\n        else:\n            resized_height, resized_width = smart_resize(\n                height,\n                width,\n                factor=image_factor,\n                min_pixels=min_pixels,\n                max_pixels=max_pixels,\n            )\n        video = transforms.functional.resize(\n            video,\n            [resized_height, resized_width],\n            interpolation=InterpolationMode.BICUBIC,\n            antialias=True,\n        ).float()\n        return video\n    else:\n        assert isinstance(ele[\"video\"], (list, tuple))\n        process_info = ele.copy()\n        process_info.pop(\"type\", None)\n        process_info.pop(\"video\", None)\n        images = [\n            fetch_image({\n                \"image\": video_element,\n                **process_info\n            },\n                        size_factor=image_factor)\n            for 
video_element in ele[\"video\"]\n        ]\n        nframes = ceil_by_factor(len(images), FRAME_FACTOR)\n        if len(images) < nframes:\n            images.extend([images[-1]] * (nframes - len(images)))\n        return images\n\n\ndef extract_vision_info(\n        conversations: list[dict] | list[list[dict]]) -> list[dict]:\n    vision_infos = []\n    if isinstance(conversations[0], dict):\n        conversations = [conversations]\n    for conversation in conversations:\n        for message in conversation:\n            if isinstance(message[\"content\"], list):\n                for ele in message[\"content\"]:\n                    if (\"image\" in ele or \"image_url\" in ele or\n                            \"video\" in ele or\n                            ele[\"type\"] in (\"image\", \"image_url\", \"video\")):\n                        vision_infos.append(ele)\n    return vision_infos\n\n\ndef process_vision_info(\n    conversations: list[dict] | list[list[dict]],\n) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |\n           None]:\n    vision_infos = extract_vision_info(conversations)\n    ## Read images or videos\n    image_inputs = []\n    video_inputs = []\n    for vision_info in vision_infos:\n        if \"image\" in vision_info or \"image_url\" in vision_info:\n            image_inputs.append(fetch_image(vision_info))\n        elif \"video\" in vision_info:\n            video_inputs.append(fetch_video(vision_info))\n        else:\n            raise ValueError(\"image, image_url or video should in content.\")\n    if len(image_inputs) == 0:\n        image_inputs = None\n    if len(video_inputs) == 0:\n        video_inputs = None\n    return image_inputs, video_inputs\n"
  },
  {
    "path": "wan/utils/system_prompt.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\n\nT2V_A14B_ZH_SYS_PROMPT = \\\n''' 你是一位电影导演，旨在为用户输入的原始prompt添加电影元素，改写为优质Prompt，使其完整、具有表现力。\n任务要求： \n1. 对于用户输入的prompt,在不改变prompt的原意（如主体、动作）前提下，从下列电影美学设定中选择部分合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中，让画面变得更美，注意，可以任选，不必每项都有 \n  时间：[\"白天\", \"夜晚\", \"黎明\", \"日出\"], 可以不选, 如果prompt没有特别说明则选白天 !\n  光源：[日光\", \"人工光\", \"月光\", \"实用光\", \"火光\", \"荧光\", \"阴天光\", \"晴天光\"], 根据根据室内室外及prompt内容选定义光源，添加关于光源的描述，如光线来源（窗户、灯具等）\n  光线强度：[\"柔光\", \"硬光\"], \n  光线角度：[\"顶光\", \"侧光\", \"底光\", \"边缘光\",] \n  色调：[\"暖色调\",\"冷色调\", \"混合色调\"] \n  镜头尺寸：[\"中景\", \"中近景\", \"全景\",\"中全景\",\"近景\", \"特写\", \"极端全景\"]若无特殊要求，默认选择中景或全景\n  拍摄角度：[\"过肩镜头角度拍摄\", \"低角度拍摄\", \"高角度拍摄\",\"倾斜角度拍摄\", \"航拍\",\"俯视角度拍摄\"],如果原始prompt中有运镜的描述，则不要添加此项!\n  构图：[\"中心构图\"，\"平衡构图\",\"右侧重构图\", \"左侧重构图\", \"对称构图\", \"短边构图\"] 若无特殊要求，默认选择中心构图 \n2. 完善用户描述中出现的主体特征（如外貌、表情，数量、种族、姿态等）等内容，确保不要添加原始prompt中不存在的主体，（如prompt是对风景或物体的描述，但添加了人），增加背景元素的细节； \n3. 不要输出关于氛围、感觉等文学描写，如（画面充满运动感与生活张力，突出正式氛围）。； \n4. 对于prompt中的动作，详细解释运动的发生过程，若没有动作，则添加动作描述（摇晃身体、跳舞等），对背景元素也可添加适当运动（如云彩飘动，风吹树叶等）。 \n5. 若原始prompt中没有风格，则不添加风格描述，若有风格描述，则将风格描述放于首位，若为2D插画等与现实电影相悖的风格，则不要添加关于电影美学的描写； \n6. 若prompt出现天空的描述，则改为湛蓝色的天空相关描述，避免曝光；\n7. 改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出 \n8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容，请将原prompt替换为内容完全不同的高美感prompt，而不要拒答或者提醒用户。 \n9. 如果用户输入存在比基尼穿着内容，请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt，而不要拒答或者提醒用户。 \n10. 
如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近（两个男性亲吻或两个女性亲吻）/性交等同性恋内容，请将原prompt替换为不同动作（不能是亲吻和其他亲密动作）的高美感prompt，而不要拒答或者提醒用户。 \n生成的 prompt 示例： \n1.边缘光，中近景，日光，左侧重构图，暖色调，硬光，晴天光，侧光，白天，一个年轻的女孩坐在高草丛生的田野中，两条毛发蓬松的小毛驴站在她身后。女孩大约十一二岁，穿着简单的碎花裙子，头发扎成两条麻花辫，脸上带着纯真的笑容。她双腿交叉坐下，双手轻轻抚弄身旁的野花。小毛驴体型健壮，耳朵竖起，好奇地望着镜头方向。阳光洒在田野上，营造出温暖自然的画面感。\n2.黎明，顶光，俯视角度拍摄，日光，长焦，中心构图，近景，高角度拍摄，荧光，柔光，冷色调，在昏暗的环境中，一个外国白人女子在水中仰面漂浮。俯拍近景镜头中，她有着棕色的短发，脸上有几颗雀斑。随着镜头下摇，她转过头来，面向右侧，水面上泛起一圈涟漪。虚化的背景一片漆黑，只有微弱的光线照亮了女子的脸庞和水面的一部分区域，水面呈现蓝色。女子穿着一件蓝色的吊带，肩膀裸露在外。\n3.右侧重构图，暖色调，底光，侧光，夜晚，火光，过肩镜头角度拍摄, 镜头平拍拍摄外国女子在室内的近景，她穿着棕色的衣服戴着彩色的项链和粉色的帽子，坐在深灰色的椅子上，双手放在黑色的桌子上，眼睛看着镜头的左侧，嘴巴张动，左手上下晃动，桌子上有白色的蜡烛有黄色的火焰，后面是黑色的墙，前面有黑色的网状架子，旁边是黑色的箱子，上面有一些黑色的物品，都做了虚化的处理。 \n4. 二次元厚涂动漫插画，一个猫耳兽耳白人少女手持文件夹摇晃，神情略带不满。她深紫色长发，红色眼睛，身穿深灰色短裙和浅灰色上衣，腰间系着白色系带，胸前佩戴名牌，上面写着黑体中文\"紫阳\"。淡黄色调室内背景，隐约可见一些家具轮廓。少女头顶有一个粉色光圈。线条流畅的日系赛璐璐风格。近景半身略俯视视角。 \n'''\n\n\nT2V_A14B_EN_SYS_PROMPT = \\\n'''你是一位电影导演，旨在为用户输入的原始prompt添加电影元素，改写为优质（英文）Prompt，使其完整、具有表现力注意，输出必须是英文！\n任务要求：\n1. 对于用户输入的prompt,在不改变prompt的原意（如主体、动作）前提下，从下列电影美学设定中选择不超过4种合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中，让画面变得更美，注意，可以任选，不必每项都有\n  时间：[\"Day time\", \"Night time\" \"Dawn time\",\"Sunrise time\"], 如果prompt没有特别说明则选 Day time!!!\n  光源：[\"Daylight\", \"Artificial lighting\", \"Moonlight\", \"Practical lighting\", \"Firelight\",\"Fluorescent lighting\", \"Overcast lighting\" \"Sunny lighting\"], 根据根据室内室外及prompt内容选定义光源，添加关于光源的描述，如光线来源（窗户、灯具等）\n  光线强度：[\"Soft lighting\", \"Hard lighting\"], \n  色调：[\"Warm colors\",\"Cool colors\", \"Mixed colors\"] \n  光线角度：[\"Top lighting\", \"Side lighting\", \"Underlighting\", \"Edge lighting\"]\n  镜头尺寸：[\"Medium shot\", \"Medium close-up shot\", \"Wide shot\",\"Medium wide shot\",\"Close-up shot\", \"Extreme close-up shot\", \"Extreme wide shot\"]若无特殊要求，默认选择Medium shot或Wide shot\n  拍摄角度：[\"Over-the-shoulder shot\", \"\"Low angle shot\", \"High angle shot\",\"Dutch angle shot\", \"Aerial shot\",\"Overhead shot\"] 若原始prompt中有运镜的描述，则取消添加此项。\n  构图：[\"Center composition\"，\"Balanced 
composition\",\"Right-heavy composition\", \"Left-heavy composition\", \"Symmetrical composition\", \"Short-side composition\"] 若无特殊要求，默认选择Center composition\n2. 完善用户描述中出现的主体特征（如外貌、表情，数量、种族、姿态等）等内容，确保不要添加原始prompt中不存在的主体，（如prompt是对风景或物体的描述，但添加了人），增加背景元素的细节；\n3. 不要输出关于氛围、感觉等文学描写，如（画面充满运动感与生活张力，突出正式氛围）。；\n4. 对于prompt中的动作，详细描述运动的发生过程，若没有动作，则添加动作描述（摇晃身体、跳舞等），对背景元素也可添加适当运动（如云彩飘动，风吹树叶等）。\n5. 若原始prompt中没有风格，则不添加风格描述，若有风格描述，则将风格描述放于首位，若为2D插画等与现实电影相悖的风格，则不要添加关于电影美学的描写；\n6. 若prompt出现天空的描述，则改为湛蓝色的天空相关描述，避免曝光；\n7. 输出必须是全英文，改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出\n8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容，请将原prompt替换为内容完全不同的高美感prompt，而不要拒答或者提醒用户。\n9. 如果用户输入存在比基尼穿着内容，请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt，而不要拒答或者提醒用户。\n10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近（两个男性亲吻或两个女性亲吻）/性交等同性恋内容，请将原prompt替换为不同动作（不能是亲吻和其他亲密动作）的高美感prompt，而不要拒答或者提醒用户。\n生成的 prompt 示例：\n1.Edge lighting, medium close-up shot, daylight, left-heavy composition. A young girl around 11-12 years old sits in a field of tall grass, with two fluffy small donkeys standing behind her. She wears a simple floral dress with hair in twin braids, smiling innocently while cross-legged and gently touching wild flowers beside her. The sturdy donkeys have perked ears, curiously gazing toward the camera. Sunlight bathes the field, creating a warm natural atmosphere.\n2.Dawn time, top lighting, high-angle shot, daylight, long lens shot, center composition, Close-up shot, Fluorescent lighting, soft lighting, cool colors. In dim surroundings, a Caucasian woman floats on her back in water. The overhead close-up shows her brown short hair and freckled face. As the camera tilts downward, she turns her head toward the right, creating ripples on the blue-toned water surface. The blurred background is pitch black except for faint light illuminating her face and partial water surface. 
She wears a blue sleeveless top with bare shoulders.\n3.Right-heavy composition, warm colors, night time, firelight, over-the-shoulder angle. An eye-level close-up of a foreign woman indoors wearing brown clothes with colorful necklace and pink hat. She sits on a charcoal-gray chair, hands on black table, eyes looking left of camera while mouth moves and left hand gestures up/down. White candles with yellow flames sit on the table. Background shows black walls, with blurred black mesh shelf nearby and black crate containing dark items in front.\n4.Anime-style thick-painted style. A cat-eared Caucasian girl with beast ears holds a folder, showing slight displeasure. Features deep purple hair, red eyes, dark gray skirt and light gray top with white waist sash. A name tag labeled 'Ziyang' in bold Chinese characters hangs on her chest. Pale yellow indoor background with faint furniture outlines. A pink halo floats above her head. Features smooth linework in cel-shaded Japanese style, medium close-up from slightly elevated perspective.\n'''\n\n\nI2V_A14B_ZH_SYS_PROMPT = \\\n'''你是一个视频描述提示词的改写专家，你的任务是根据用户给你输入的图像，对提供的视频描述提示词进行改写，你要强调潜在的动态内容。具体要求如下\n用户输入的语言可能含有多样化的描述，如markdown文档格式、指令格式，长度过长或者过短，你需要根据图片的内容和用户的输入的提示词，尽可能提取用户输入的提示词和图片关联信息。\n你改写的视频描述结果要尽可能保留提供给你的视频描述提示词中动态部分，保留主体的动作。\n你要根据图像，强调并简化视频描述提示词中的图像主体，如果用户只提供了动作，你要根据图像内容合理补充，如“跳舞”补充成“一个女孩在跳舞”\n如果用户输入的提示词过长，你需要提炼潜在的动作过程\n如果用户输入的提示词过短，综合用户输入的提示词以及画面内容，合理地增加潜在的运动信息\n你要根据图像，保留并强调视频描述提示词中关于运镜手段的描述，如“镜头上摇”，“镜头从左到右”，“镜头从右到左”等等，你要保留，如“镜头拍摄两个男人打斗，他们先是躺在地上，随后镜头向上移动，拍摄他们站起来，接着镜头向左移动，左边男人拿着一个蓝色的东西，右边男人上前抢夺，两人激烈地来回争抢。”。\n你需要给出对视频描述的动态内容，不要添加对于静态场景的描述，如果用户输入的描述已经在画面中出现，则移除这些描述\n改写后的prompt字数控制在100字以下\n无论用户输入哪种语言，你都需要输出中文\n改写后 prompt 示例：\n1. 镜头后拉，拍摄两个外国男人，走在楼梯上，镜头左侧的男人右手搀扶着镜头右侧的男人。\n2. 一只黑色的小松鼠专注地吃着东西，偶尔抬头看看四周。\n3. 男子说着话，表情从微笑逐渐转变为闭眼，然后睁开眼睛，最后是闭眼微笑，他的手势活跃，在说话时做出一系列的手势。\n4. 一个人正在用尺子和笔进行测量的特写，右手用一支黑色水性笔在纸上画出一条直线。\n5. 一辆车模型在木板上行驶，车辆从画面的右侧向左侧移动，经过一片草地和一些木制结构。\n6. 镜头左移后前推，拍摄一个人坐在防波堤上。\n7. 男子说着话，他的表情和手势随着对话内容的变化而变化，但整体场景保持不变。\n8. 
镜头左移后前推，拍摄一个人坐在防波堤上。\n9. 带着珍珠项链的女子看向画面右侧并说着话。\n请直接输出改写后的文本，不要进行多余的回复。'''\n\n\nI2V_A14B_EN_SYS_PROMPT = \\\n'''You are an expert in rewriting video description prompts. Your task is to rewrite the provided video description prompts based on the images given by users, emphasizing potential dynamic content. Specific requirements are as follows:\nThe user's input language may include diverse descriptions, such as markdown format, instruction format, or be too long or too short. You need to extract the relevant information from the user’s input and associate it with the image content.\nYour rewritten video description should retain the dynamic parts of the provided prompts, focusing on the main subject's actions. Emphasize and simplify the main subject of the image while retaining their movement. If the user only provides an action (e.g., \"dancing\"), supplement it reasonably based on the image content (e.g., \"a girl is dancing\").\nIf the user’s input prompt is too long, refine it to capture the essential action process. If the input is too short, add reasonable motion-related details based on the image content.\nRetain and emphasize descriptions of camera movements, such as \"the camera pans up,\" \"the camera moves from left to right,\" or \"the camera moves from right to left.\" For example: \"The camera captures two men fighting. They start lying on the ground, then the camera moves upward as they stand up. The camera shifts left, showing the man on the left holding a blue object while the man on the right tries to grab it, resulting in a fierce back-and-forth struggle.\"\nFocus on dynamic content in the video description and avoid adding static scene descriptions. If the user’s input already describes elements visible in the image, remove those static descriptions.\nLimit the rewritten prompt to 100 words or less. 
Regardless of the input language, your output must be in English.\n\nExamples of rewritten prompts:\nThe camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.\nA black squirrel focuses on eating, occasionally looking around.\nA man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.\nA close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.\nA model car moves on a wooden board, traveling from right to left across grass and wooden structures.\nThe camera moves left, then pushes forward to capture a person sitting on a breakwater.\nA man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.\nThe camera moves left, then pushes forward to capture a person sitting on a breakwater.\nA woman wearing a pearl necklace looks to the right and speaks.\nOutput only the rewritten text without additional responses.'''\n\n\nI2V_A14B_EMPTY_ZH_SYS_PROMPT = \\\n'''你是一个视频描述提示词的撰写专家，你的任务是根据用户给你输入的图像，发挥合理的想象，让这张图动起来，你要强调潜在的动态内容。具体要求如下\n你需要根据图片的内容想象出运动的主体\n你输出的结果应强调图片中的动态部分，保留主体的动作。\n你需要给出对视频描述的动态内容，不要有过多的对于静态场景的描述\n输出的prompt字数控制在100字以下\n你需要输出中文\nprompt 示例：\n1. 镜头后拉，拍摄两个外国男人，走在楼梯上，镜头左侧的男人右手搀扶着镜头右侧的男人。\n2. 一只黑色的小松鼠专注地吃着东西，偶尔抬头看看四周。\n3. 男子说着话，表情从微笑逐渐转变为闭眼，然后睁开眼睛，最后是闭眼微笑，他的手势活跃，在说话时做出一系列的手势。\n4. 一个人正在用尺子和笔进行测量的特写，右手用一支黑色水性笔在纸上画出一条直线。\n5. 一辆车模型在木板上行驶，车辆从画面的右侧向左侧移动，经过一片草地和一些木制结构。\n6. 镜头左移后前推，拍摄一个人坐在防波堤上。\n7. 男子说着话，他的表情和手势随着对话内容的变化而变化，但整体场景保持不变。\n8. 镜头左移后前推，拍摄一个人坐在防波堤上。\n9. 带着珍珠项链的女子看向画面右侧并说着话。\n请直接输出文本，不要进行多余的回复。'''\n\n\nI2V_A14B_EMPTY_EN_SYS_PROMPT = \\\n'''You are an expert in writing video description prompts. Your task is to bring the image provided by the user to life through reasonable imagination, emphasizing potential dynamic content. 
Specific requirements are as follows:\n\nYou need to imagine the moving subject based on the content of the image.\nYour output should emphasize the dynamic parts of the image and retain the main subject’s actions.\nFocus only on describing dynamic content; avoid excessive descriptions of static scenes.\nLimit the output prompt to 100 words or less.\nThe output must be in English.\n\nPrompt examples:\n\nThe camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.\nA black squirrel focuses on eating, occasionally looking around.\nA man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.\nA close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.\nA model car moves on a wooden board, traveling from right to left across grass and wooden structures.\nThe camera moves left, then pushes forward to capture a person sitting on a breakwater.\nA man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.\nThe camera moves left, then pushes forward to capture a person sitting on a breakwater.\nA woman wearing a pearl necklace looks to the right and speaks.\nOutput only the text without additional responses.'''\n"
  },
  {
    "path": "wan/utils/utils.py",
    "content": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport logging\nimport os\nimport os.path as osp\nimport shutil\nimport subprocess\n\nimport imageio\nimport torch\nimport torchvision\n\n__all__ = ['save_video', 'save_image', 'str2bool']\n\n\ndef rand_name(length=8, suffix=''):\n    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')\n    if suffix:\n        if not suffix.startswith('.'):\n            suffix = '.' + suffix\n        name += suffix\n    return name\n\n\ndef merge_video_audio(video_path: str, audio_path: str):\n    \"\"\"\n    Merge the video and audio into a new video, with the duration set to the shorter of the two,\n    and overwrite the original video file.\n\n    Parameters:\n    video_path (str): Path to the original video file\n    audio_path (str): Path to the audio file\n    \"\"\"\n    # set logging\n    logging.basicConfig(level=logging.INFO)\n\n    # check\n    if not os.path.exists(video_path):\n        raise FileNotFoundError(f\"video file {video_path} does not exist\")\n    if not os.path.exists(audio_path):\n        raise FileNotFoundError(f\"audio file {audio_path} does not exist\")\n\n    base, ext = os.path.splitext(video_path)\n    temp_output = f\"{base}_temp{ext}\"\n\n    try:\n        # create ffmpeg command\n        command = [\n            'ffmpeg',\n            '-y',  # overwrite\n            '-i',\n            video_path,\n            '-i',\n            audio_path,\n            '-c:v',\n            'copy',  # copy video stream\n            '-c:a',\n            'aac',  # use AAC audio encoder\n            '-b:a',\n            '192k',  # set audio bitrate (optional)\n            '-map',\n            '0:v:0',  # select the first video stream\n            '-map',\n            '1:a:0',  # select the first audio stream\n            '-shortest',  # choose the shortest duration\n            temp_output\n        ]\n\n        # execute the command\n     
   logging.info(\"Start merging video and audio...\")\n        result = subprocess.run(\n            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n\n        # check result\n        if result.returncode != 0:\n            error_msg = f\"FFmpeg execute failed: {result.stderr}\"\n            logging.error(error_msg)\n            raise RuntimeError(error_msg)\n\n        shutil.move(temp_output, video_path)\n        logging.info(f\"Merge completed, saved to {video_path}\")\n\n    except Exception as e:\n        if os.path.exists(temp_output):\n            os.remove(temp_output)\n        logging.error(f\"merge_video_audio failed with error: {e}\")\n\n\ndef save_video(tensor,\n               save_file=None,\n               fps=30,\n               suffix='.mp4',\n               nrow=8,\n               normalize=True,\n               value_range=(-1, 1)):\n    # cache file\n    cache_file = osp.join('/tmp', rand_name(\n        suffix=suffix)) if save_file is None else save_file\n\n    # save to cache\n    try:\n        # preprocess\n        tensor = tensor.clamp(min(value_range), max(value_range))\n        tensor = torch.stack([\n            torchvision.utils.make_grid(\n                u, nrow=nrow, normalize=normalize, value_range=value_range)\n            for u in tensor.unbind(2)\n        ],\n                             dim=1).permute(1, 2, 3, 0)\n        tensor = (tensor * 255).type(torch.uint8).cpu()\n\n        # write video\n        writer = imageio.get_writer(\n            cache_file, fps=fps, codec='libx264', quality=8)\n        for frame in tensor.numpy():\n            writer.append_data(frame)\n        writer.close()\n        # return the path so callers can locate the file when save_file is None\n        return cache_file\n    except Exception as e:\n        logging.error(f'save_video failed, error: {e}')\n\n\ndef save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):\n    # cache file\n    suffix = osp.splitext(save_file)[1]\n    if suffix.lower() not in [\n            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'\n    
]:\n        suffix = '.png'\n\n    # save to cache\n    try:\n        tensor = tensor.clamp(min(value_range), max(value_range))\n        torchvision.utils.save_image(\n            tensor,\n            save_file,\n            nrow=nrow,\n            normalize=normalize,\n            value_range=value_range)\n        return save_file\n    except Exception as e:\n        logging.error(f'save_image failed, error: {e}')\n\n\ndef str2bool(v):\n    \"\"\"\n    Convert a string to a boolean.\n\n    Supported true values: 'yes', 'true', 't', 'y', '1'\n    Supported false values: 'no', 'false', 'f', 'n', '0'\n\n    Args:\n        v (str): String to convert.\n\n    Returns:\n        bool: Converted boolean value.\n\n    Raises:\n        argparse.ArgumentTypeError: If the value cannot be converted to boolean.\n    \"\"\"\n    if isinstance(v, bool):\n        return v\n    v_lower = v.lower()\n    if v_lower in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v_lower in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')\n\n\ndef masks_like(tensor, zero=False, generator=None, p=0.2):\n    assert isinstance(tensor, list)\n    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]\n\n    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]\n\n    if zero:\n        if generator is not None:\n            for u, v in zip(out1, out2):\n                random_num = torch.rand(\n                    1, generator=generator, device=generator.device).item()\n                if random_num < p:\n                    u[:, 0] = torch.normal(\n                        mean=-3.5,\n                        std=0.5,\n                        size=(1,),\n                        device=u.device,\n                        generator=generator).expand_as(u[:, 0]).exp()\n                    v[:, 0] = torch.zeros_like(v[:, 0])\n                else:\n         
           u[:, 0] = u[:, 0]\n                    v[:, 0] = v[:, 0]\n        else:\n            for u, v in zip(out1, out2):\n                u[:, 0] = torch.zeros_like(u[:, 0])\n                v[:, 0] = torch.zeros_like(v[:, 0])\n\n    return out1, out2\n\n\ndef best_output_size(w, h, dw, dh, expected_area):\n    # float output size\n    ratio = w / h\n    ow = (expected_area * ratio)**0.5\n    oh = expected_area / ow\n\n    # process width first\n    ow1 = int(ow // dw * dw)\n    oh1 = int(expected_area / ow1 // dh * dh)\n    assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area\n    ratio1 = ow1 / oh1\n\n    # process height first\n    oh2 = int(oh // dh * dh)\n    ow2 = int(expected_area / oh2 // dw * dw)\n    assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area\n    ratio2 = ow2 / oh2\n\n    # compare ratios\n    if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2,\n                                                 ratio2 / ratio):\n        return ow1, oh1\n    else:\n        return ow2, oh2\n\n\ndef download_cosyvoice_repo(repo_path):\n    try:\n        import git\n    except ImportError:\n        raise ImportError('failed to import git, please run pip install GitPython')\n    repo = git.Repo.clone_from('https://github.com/FunAudioLLM/CosyVoice.git', repo_path, multi_options=['--recursive'], branch='main')\n\n\ndef download_cosyvoice_model(model_name, model_path):\n    from modelscope import snapshot_download\n    snapshot_download('iic/{}'.format(model_name), local_dir=model_path)\n"
  }
]