Repository: thu-ml/Causal-Forcing
Branch: main
Commit: 9d7fcaf94a54
Files: 154
Total size: 1.4 MB
Directory structure:
gitextract_cbc4vx5e/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│ ├── ar_diffusion_tf_chunkwise.yaml
│ ├── ar_diffusion_tf_framewise.yaml
│ ├── causal_cd_chunkwise.yaml
│ ├── causal_cd_framewise.yaml
│ ├── causal_forcing_dmd_chunkwise.yaml
│ ├── causal_forcing_dmd_framewise.yaml
│ ├── causal_forcing_dmd_framewise_1step.yaml
│ ├── causal_forcing_dmd_framewise_2step.yaml
│ ├── causal_ode_chunkwise.yaml
│ ├── causal_ode_framewise.yaml
│ └── default_config.yaml
├── demo.py
├── demo_utils/
│ ├── constant.py
│ ├── memory.py
│ ├── taehv.py
│ ├── utils.py
│ ├── vae.py
│ ├── vae_block3.py
│ └── vae_torch2trt.py
├── get_causal_ode_data_chunkwise.py
├── get_causal_ode_data_framewise.py
├── get_causal_ode_data_kv_optimized.py
├── inference.py
├── long_video/
│ ├── LICENSE
│ ├── README.md
│ ├── app.py
│ ├── configs/
│ │ ├── default_config.yaml
│ │ └── rolling_forcing_dmd.yaml
│ ├── inference.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── causvid.py
│ │ ├── diffusion.py
│ │ ├── dmd.py
│ │ ├── gan.py
│ │ ├── ode_regression.py
│ │ └── sid.py
│ ├── pipeline/
│ │ ├── __init__.py
│ │ ├── bidirectional_diffusion_inference.py
│ │ ├── bidirectional_inference.py
│ │ ├── causal_diffusion_inference.py
│ │ ├── rolling_forcing_inference.py
│ │ └── rolling_forcing_training.py
│ ├── prompts/
│ │ └── example_prompts.txt
│ ├── requirements.txt
│ ├── train.py
│ ├── trainer/
│ │ ├── __init__.py
│ │ ├── diffusion.py
│ │ ├── distillation.py
│ │ ├── gan.py
│ │ └── ode.py
│ ├── utils/
│ │ ├── dataset.py
│ │ ├── distributed.py
│ │ ├── lmdb.py
│ │ ├── loss.py
│ │ ├── misc.py
│ │ ├── scheduler.py
│ │ └── wan_wrapper.py
│ └── wan/
│ ├── README.md
│ ├── __init__.py
│ ├── configs/
│ │ ├── __init__.py
│ │ ├── shared_config.py
│ │ ├── wan_i2v_14B.py
│ │ ├── wan_t2v_14B.py
│ │ └── wan_t2v_1_3B.py
│ ├── distributed/
│ │ ├── __init__.py
│ │ ├── fsdp.py
│ │ └── xdit_context_parallel.py
│ ├── image2video.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── causal_model.py
│ │ ├── clip.py
│ │ ├── model.py
│ │ ├── t5.py
│ │ ├── tokenizers.py
│ │ ├── vae.py
│ │ └── xlm_roberta.py
│ ├── text2video.py
│ └── utils/
│ ├── __init__.py
│ ├── fm_solvers.py
│ ├── fm_solvers_unipc.py
│ ├── prompt_extend.py
│ ├── qwen_vl_utils.py
│ └── utils.py
├── model/
│ ├── __init__.py
│ ├── base.py
│ ├── causvid.py
│ ├── diffusion.py
│ ├── dmd.py
│ ├── gan.py
│ ├── naive_consistency.py
│ ├── ode_regression.py
│ └── sid.py
├── pipeline/
│ ├── __init__.py
│ ├── bidirectional_diffusion_inference.py
│ ├── bidirectional_inference.py
│ ├── bidirectional_training.py
│ ├── causal_diffusion_inference.py
│ ├── causal_inference.py
│ ├── self_forcing_training.py
│ └── teacher_forcing_training.py
├── prompts/
│ ├── demos.txt
│ └── i2v/
│ └── target_crop_info_26-15.json
├── requirements.txt
├── setup.py
├── train.py
├── trainer/
│ ├── __init__.py
│ ├── diffusion.py
│ ├── distillation.py
│ ├── gan.py
│ ├── naive_cd.py
│ └── ode.py
├── utils/
│ ├── create_lmdb_iterative.py
│ ├── dataset.py
│ ├── distributed.py
│ ├── lmdb_.py
│ ├── loss.py
│ ├── merge_and_get_clean.py
│ ├── merge_lmdb.py
│ ├── misc.py
│ ├── ode_generation.py
│ ├── scheduler.py
│ └── wan_wrapper.py
└── wan/
├── README.md
├── __init__.py
├── configs/
│ ├── __init__.py
│ ├── shared_config.py
│ ├── wan_i2v_14B.py
│ ├── wan_t2v_14B.py
│ └── wan_t2v_1_3B.py
├── distributed/
│ ├── __init__.py
│ ├── fsdp.py
│ └── xdit_context_parallel.py
├── image2video.py
├── modules/
│ ├── __init__.py
│ ├── attention.py
│ ├── causal_model.py
│ ├── clip.py
│ ├── model.py
│ ├── t5.py
│ ├── tokenizers.py
│ ├── vae.py
│ └── xlm_roberta.py
├── text2video.py
└── utils/
├── __init__.py
├── fm_solvers.py
├── fm_solvers_unipc.py
├── prompt_extend.py
├── qwen_vl_utils.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__
*.egg-info
wan_models
checkpoints
output
dataset
prompts/vidprom_filtered_extended.txt
logs
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
## Causal Forcing & Causal Forcing++
### Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation
<p align="center">
<p align="center">
<div>
<a href="https://zhuhz22.github.io/" target="_blank">Hongzhou Zhu*</a><sup></sup>,
<a href="https://gracezhao1997.github.io/" target="_blank">Min Zhao*</a><sup></sup> ,
<a href="https://guandehe.github.io/" target="_blank">Guande He</a><sup></sup>,
<a href="https://scholar.google.com/citations?user=dxN1_X0AAAAJ&hl=en" target="_blank">Hang Su</a><sup></sup>,
<a href="https://zhenxuan00.github.io/" target="_blank">Chongxuan Li</a><sup></sup> ,
<a href="https://ml.cs.tsinghua.edu.cn/~jun/index.shtml" target="_blank">Jun Zhu</a><sup></sup>
</div>
<div>
<sup></sup>Tsinghua University & Shengshu & UT Austin & RUC
</div>
</div>
</p>
<h3 align="center"><a href="https://arxiv.org/abs/2602.02214">Causal Forcing</a> | <a href="https://arxiv.org/abs/2605.15141">Causal Forcing++</a> | <a href="https://thu-ml.github.io/CausalForcing.github.io">Website</a> | <a href="https://huggingface.co/zhuhz22/Causal-Forcing/tree/main">Models</a> | <a href="assets/wechat.jpg">WeChat</a> | <a href="https://my.feishu.cn/wiki/AjBSwcjpqiN0ECkodIWcGDcMn4e?from=from_copylink">Document</a> </h3>
</p>
-----
The Causal Forcing series uses **Causal ODE** or **Causal Consistency Distillation** to drive asymmetric DMD as a theoretically correct initialization for real-time interactive video generation.
[Causal Forcing](https://arxiv.org/abs/2602.02214) significantly outperforms Self Forcing in **both visual quality and motion dynamics**, while keeping **the same training budget and inference efficiency**. We support both chunk-wise and **frame-wise** models, with the latter natively unifying T2V and **I2V**.
We further propose [**Causal Forcing++**](https://arxiv.org/abs/2605.15141), replacing ODE with **causal Consistency Distillation** to eliminate ODE data curation and improve performance, releasing the first **1-step/2-step frame-wise** models.
-----
<img width="2090" height="850" alt="overview" src="assets/pipeline.png" />
## Table of Contents
- [Quick Start](#quick-start)
- [Installation](#installation)
- [Inference: 1/2/4-step T2V & I2V](#cli-inference)
- [Long Video](#minute-level-long-video-generation)
- [Training Pipeline](#training) (Causal Forcing / Causal Forcing++)
- [Stage 1: AR Diffusion](#training)
- [Stage 2: Causal ODE (Causal Forcing)](#training) / [🔥Causal CD (Causal Forcing++)](#-stage-2-option-b-causal-consistency-distillation-initialization-causal-forcing)
- [Stage 3: Asymmetric DMD](#stage-3-dmd)
| Models | Checkpoints | Description|
|--------------------|---------------------------------------------------------------------------------------------------------------------------------------------|-------|
| Chunk-wise 4-step | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/chunkwise/causal_forcing.pt) | SOTA AR model on Wan1.3B outperforming Self Forcing|
| Frame-wise 4-step | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/framewise/causal_forcing.pt) | Rich dynamics and high quality |
| Frame-wise 2-step🔥 | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/causal-forcing%2B%2B/framewise-2step.pt) | The first frame-wise 2-step model **even better than 4-step**|
| Frame-wise 1-step | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/causal-forcing%2B%2B/framewise-1step.pt) | Extremely low latency |
| Long Video Generator | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/chunkwise/longvideo.pt) | Minute-long AR video generator |
| HY1.5-TI2V (8B) | 🤗 [Huggingface](https://huggingface.co/MIN-Lab/minWM) | Refer to [this repo](https://github.com/shengshu-ai/minWM). |
| Action-conditioned WM | 🤗 [Huggingface](https://huggingface.co/MIN-Lab/minWM) | Refer to [this repo](https://github.com/shengshu-ai/minWM). |
-----
https://github.com/user-attachments/assets/310f0cfa-e1bb-496d-8941-87f77b3271c0
## 🔥 News
- **2026.5.15**: We release [Causal Forcing++](https://arxiv.org/abs/2605.15141), which supports Causal Consistency Distillation for few-step initialization, and open-source **the first frame-wise 2-step AR model**, comparable to chunk-wise 4-step models!
- **2026.5.10**: Thanks to @[AshadowZ](https://github.com/AshadowZ), now our chunk-wise ODE data curation is **3x faster**!
- **2026.4.16**: **Optimized the Stage 2 🔥consistency distillation🔥 infrastructure for 3× faster training. Try it now!** We have also released the checkpoint.
- **2026.3.15** : [Rolling Sink](https://github.com/haodong2000/RollingSink), [Infinity-RoPE](https://github.com/yesiltepe-hidir/infinity-rope) and [Deep Forcing](https://cvlab-kaist.github.io/DeepForcing/) adopt Causal Forcing as one of the base models!
- **2026.2.28** : Add an [FAQ section](#faq--blog) on frequently asked topics, in particular which initialization is better: AR diffusion or causal ODE distillation.
- **2026.2.11** : We now support **I2V** generation! Feel free to try it [here](#i2v)!
- **2026.2.7** : Causal Forcing now supports [Rolling Forcing](https://github.com/TencentARC/RollingForcing), enabling minute-level long video generation!
- **2026.2.5** : Release causal consistency distillation (Preview) as a substitute for ODE distillation, **with no need to generate ODE paired data**!
- **2026.2.2** : The [paper](https://arxiv.org/abs/2602.02214), [project page](https://thu-ml.github.io/CausalForcing.github.io/), and code are released.
## Quick Start
> The inference environment is identical to Self Forcing.
**NOTE**: Similar to CausVid/Self Forcing, Causal Forcing does not natively support videos longer than 81 frames. As a base training method, it is orthogonal to techniques like Longlive/Rolling Forcing. To use Causal Forcing as a long video baseline, see [this extension](#minute-level-long-video-generation). **Directly using the 5-second trained Causal Forcing model as a baseline for long video generation is extremely unfair**.
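As a rough sanity check on the 81-frame limit, the sketch below converts pixel frames into latent frames and autoregressive steps. It assumes a temporal VAE stride of 4 (with the first frame kept) and a 16 fps output; neither value is stated in this README. The chunk sizes 3 (chunk-wise) and 1 (frame-wise) come from the `--num_frames_per_chunk` note in Stage 2.

```python
# Assumptions (not stated in this README): temporal VAE stride of 4, 16 fps output.
TEMPORAL_STRIDE = 4
FPS = 16


def latent_frames(pixel_frames: int, stride: int = TEMPORAL_STRIDE) -> int:
    """Latent frame count when the first frame is kept and the rest are compressed by `stride`."""
    if (pixel_frames - 1) % stride != 0:
        raise ValueError("pixel_frames must have the form stride * k + 1")
    return (pixel_frames - 1) // stride + 1


def ar_steps(pixel_frames: int, frames_per_chunk: int) -> int:
    """Number of autoregressive generation steps for a given chunk size."""
    n = latent_frames(pixel_frames)
    if n % frames_per_chunk != 0:
        raise ValueError("latent frame count must be divisible by frames_per_chunk")
    return n // frames_per_chunk


print(latent_frames(81))   # 21 latent frames
print(ar_steps(81, 3))     # 7 chunks in the chunk-wise setting
print(ar_steps(81, 1))     # 21 steps in the frame-wise setting
print(81 / FPS)            # 5.0625 seconds under the assumed frame rate
```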
### Installation
```bash
conda create -n causal_forcing python=3.10 -y
conda activate causal_forcing
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
pip install flash-attn --no-build-isolation
python setup.py develop
```
### Download Checkpoints
```bash
hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download Wan-AI/Wan2.1-T2V-14B --local-dir wan_models/Wan2.1-T2V-14B
# Causal Forcing
hf download zhuhz22/Causal-Forcing chunkwise/causal_forcing.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing framewise/causal_forcing.pt --local-dir checkpoints
# Causal Forcing++
hf download zhuhz22/Causal-Forcing causal-forcing++/framewise-2step.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing causal-forcing++/framewise-1step.pt --local-dir checkpoints
```
### CLI Inference
#### T2V
Chunk-wise model:
```bash
python inference.py \
--config_path configs/causal_forcing_dmd_chunkwise.yaml \
--output_folder output/chunkwise \
--checkpoint_path checkpoints/chunkwise/causal_forcing.pt \
--data_path prompts/demos.txt
```
Frame-wise model:
```bash
# =============== Causal Forcing++ ================
# 2-step Causal Forcing++
python inference.py \
--config_path configs/causal_forcing_dmd_framewise_2step.yaml \
--output_folder output/framewise_2step_cf++ \
--checkpoint_path checkpoints/causal-forcing++/framewise-2step.pt \
--data_path prompts/demos.txt \
--use_ema
# 1-step Causal Forcing++
python inference.py \
--config_path configs/causal_forcing_dmd_framewise_1step.yaml \
--output_folder output/framewise_1step_cf++ \
--checkpoint_path checkpoints/causal-forcing++/framewise-1step.pt \
--data_path prompts/demos.txt \
--use_ema
# =============== Causal Forcing ================
# 4-step Causal Forcing
python inference.py \
--config_path configs/causal_forcing_dmd_framewise.yaml \
--output_folder output/framewise \
--checkpoint_path checkpoints/framewise/causal_forcing.pt \
--data_path prompts/demos.txt \
--use_ema
```
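If you script these runs, the commands above can be collected in a small lookup table. The sketch below is only a convenience wrapper: the variant names and the `build_inference_cmd` helper are hypothetical, while the config paths, checkpoint paths, output folders, and the `--use_ema` flag are copied from the commands above (the chunk-wise command above does not pass `--use_ema`, so the sketch mirrors that).

```python
import shlex

# Hypothetical variant names -> (config, checkpoint, output folder, use_ema),
# copied from the CLI commands in this README.
VARIANTS = {
    "chunkwise_4step": (
        "configs/causal_forcing_dmd_chunkwise.yaml",
        "checkpoints/chunkwise/causal_forcing.pt",
        "output/chunkwise",
        False,
    ),
    "framewise_4step": (
        "configs/causal_forcing_dmd_framewise.yaml",
        "checkpoints/framewise/causal_forcing.pt",
        "output/framewise",
        True,
    ),
    "framewise_2step": (
        "configs/causal_forcing_dmd_framewise_2step.yaml",
        "checkpoints/causal-forcing++/framewise-2step.pt",
        "output/framewise_2step_cf++",
        True,
    ),
    "framewise_1step": (
        "configs/causal_forcing_dmd_framewise_1step.yaml",
        "checkpoints/causal-forcing++/framewise-1step.pt",
        "output/framewise_1step_cf++",
        True,
    ),
}


def build_inference_cmd(variant: str, data_path: str = "prompts/demos.txt") -> list:
    """Return the argument list for one inference run (hypothetical helper)."""
    config, checkpoint, output_folder, use_ema = VARIANTS[variant]
    cmd = [
        "python", "inference.py",
        "--config_path", config,
        "--output_folder", output_folder,
        "--checkpoint_path", checkpoint,
        "--data_path", data_path,
    ]
    if use_ema:
        cmd.append("--use_ema")
    return cmd


for name in VARIANTS:
    # shlex.join quotes arguments safely, e.g. the "++" folder names.
    print(shlex.join(build_inference_cmd(name)))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root.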
#### I2V
> Our frame-wise setting natively supports I2V: simply use your conditional image as the first (initial) latent frame.
```bash
python inference.py \
--config_path configs/causal_forcing_dmd_framewise.yaml \
--output_folder output/framewise \
--checkpoint_path checkpoints/framewise/causal_forcing.pt \
--data_path prompts/i2v \
--i2v \
--use_ema
```
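Conceptually, this works because a frame-wise autoregressive model generates each latent frame from the frames before it, so fixing frame 0 turns T2V into I2V. The toy sketch below illustrates the idea with plain Python lists. It is not the repository's implementation: `generate`, `toy_model`, and the list-based "latents" are hypothetical stand-ins, and it assumes the conditional image has already been encoded into a latent.

```python
from typing import Callable, List, Optional

Latent = List[float]  # toy stand-in for one latent frame


def generate(
    num_frames: int,
    denoise_next_frame: Callable[[List[Latent]], Latent],
    first_frame: Optional[Latent] = None,
) -> List[Latent]:
    """Frame-wise autoregressive rollout.

    T2V: every frame is generated. I2V: frame 0 is the given (already encoded)
    conditional image and generation starts at frame 1.
    """
    frames: List[Latent] = []
    if first_frame is not None:
        frames.append(list(first_frame))
    while len(frames) < num_frames:
        frames.append(denoise_next_frame(frames))
    return frames


def toy_model(history: List[Latent]) -> Latent:
    """Hypothetical generator: zeros without history, otherwise last frame plus one."""
    if not history:
        return [0.0, 0.0]
    return [value + 1.0 for value in history[-1]]


t2v = generate(3, toy_model)
i2v = generate(3, toy_model, first_frame=[5.0, 5.0])
print(t2v)  # [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
print(i2v)  # [[5.0, 5.0], [6.0, 6.0], [7.0, 7.0]]
```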
### Minute-level Long Video Generation
Built on [Rolling Forcing](https://github.com/TencentARC/RollingForcing), we implemented minute-level long video generation. See [here](./long_video) for details.
[Infinity-RoPE](https://github.com/yesiltepe-hidir/infinity-rope), [Deep Forcing](https://cvlab-kaist.github.io/DeepForcing/) and [Rolling Sink](https://github.com/haodong2000/RollingSink) also adopt Causal Forcing as one of their base models, enabling interactive (prompt-switchable) long video generation at the minute scale. You can also try them out at their repos.
## Training
<details>
<summary> Stage 1: Autoregressive Diffusion Training (Can skip by using our pretrained checkpoints. Click to expand.)</summary>
First download the dataset (we provide a 6K toy dataset here):
```bash
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
python utils/merge_and_get_clean.py
```
> If the download gets stuck, press Ctrl+C and then resume it.
> For training on your own dataset, refer to [this issue](https://github.com/thu-ml/Causal-Forcing/issues/8).
Then train the AR-diffusion model:
- Framewise:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/ar_diffusion_tf_framewise.yaml \
--logdir logs/ar_diffusion_framewise
```
- Chunkwise:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/ar_diffusion_tf_chunkwise.yaml \
--logdir logs/ar_diffusion_chunkwise
```
> We recommend training for at least 2K steps; more steps (e.g., 5K to 10K) lead to better performance.
Inference to test training results:
```bash
python inference.py \
--config_path configs/ar_diffusion_tf_{framewise OR chunkwise}.yaml \
--output_folder output/{framewise OR chunkwise}_ar_diffusion \
--checkpoint_path checkpoints/{framewise OR chunkwise}/ar_diffusion.pt \
--data_path prompts/demos.txt
```
</details>
<details>
<summary> Stage 2 (Option a): Causal ODE Initialization (Can skip by using our pretrained checkpoints. Click to expand.)</summary>
🔥 Thanks to @[AshadowZ](https://github.com/AshadowZ), now our chunk-wise ODE data curation is **3x faster**!
You can also use `bf16` to accelerate generation.
If you have skipped Stage 1, you need to download the pretrained models:
```bash
hf download zhuhz22/Causal-Forcing framewise/ar_diffusion.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/ar_diffusion.pt --local-dir checkpoints
```
In this stage, first generate ODE paired data:
```bash
# for the frame-wise model
torchrun --nproc_per_node=8 \
get_causal_ode_data_framewise.py \
--generator_ckpt checkpoints/framewise/ar_diffusion.pt \
--rawdata_path dataset/clean_data \
--output_folder dataset/ODE6KCausal_framewise_latents
python utils/create_lmdb_iterative.py \
--data_path dataset/ODE6KCausal_framewise_latents \
--lmdb_path dataset/ODE6KCausal_framewise
# for the chunk-wise model
torchrun --nproc_per_node=8 \
get_causal_ode_data_chunkwise.py \
--generator_ckpt checkpoints/chunkwise/ar_diffusion.pt \
--rawdata_path dataset/clean_data \
--output_folder dataset/ODE6KCausal_chunkwise_latents
python utils/create_lmdb_iterative.py \
--data_path dataset/ODE6KCausal_chunkwise_latents \
--lmdb_path dataset/ODE6KCausal_chunkwise
#🔥NEW: Or you can use the optimized code (3x speedup) by @AshadowZ
# Choose ONE option in each {A OR B} placeholder; --num_frames_per_chunk is 3 for chunk-wise, 1 for frame-wise.
# Replace the --rawdata_path value with the path to your own raw data.
torchrun --nproc_per_node=8 \
get_causal_ode_data_kv_optimized.py \
--generator_ckpt checkpoints/{chunkwise OR framewise}/ar_diffusion.pt \
--rawdata_path /mnt/vepfs/base2/zhaomin/group_0001_lmdb \
--output_folder dataset/ODE6KCausal_{chunkwise OR framewise}_latents \
--num_frames_per_chunk {3 OR 1} \
--generation_mode blockwise_kv
```
Alternatively, download our prepared dataset (~300 GB) directly:
```bash
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
python utils/merge_lmdb.py
```
> If the download gets stuck, press Ctrl+C and then resume it.
And then train ODE initialization models:
- Frame-wise:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_ode_framewise.yaml \
--logdir logs/causal_ode_framewise
```
- Chunk-wise:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_ode_chunkwise.yaml \
--logdir logs/causal_ode_chunkwise
```
> We recommend training for at least 1K steps; more steps (e.g., 5K to 10K) lead to better performance.
Inference to test training results:
The same as [here](#cli-inference).
</details>
### 🔥 Stage 2 (Option b): Causal Consistency Distillation Initialization ([Causal Forcing++](https://arxiv.org/abs/2605.15141))
Since creating ODE-paired data is very time-consuming, we also provide an alternative that matches or improves on ODE distillation while requiring only ground-truth data, **with no ODE data generation!**
> Thanks to [@chijw's effort](https://github.com/thu-ml/Causal-Forcing/pull/20), now the EMA mechanism is more efficient!
- Frame-wise:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_cd_framewise.yaml \
--logdir logs/causal_cd_framewise
```
- Chunk-wise:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_cd_chunkwise.yaml \
--logdir logs/causal_cd_chunkwise
```
> We recommend training for at least 3K steps; more steps (e.g., 5K to 10K) lead to better performance.
<details>
<summary>
You can also download the checkpoints directly (click to expand):
</summary>
```bash
hf download zhuhz22/Causal-Forcing framewise/causal_cd.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_cd.pt --local-dir checkpoints
```
</details>
Inference to test training results: the same as [here](#cli-inference).
### Stage 3: DMD
First download the dataset:
```bash
hf download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
```
<details>
<summary>
If you have skipped Stage 2, you need to download our pretrained checkpoints (click to expand):
</summary>
```bash
hf download zhuhz22/Causal-Forcing framewise/causal_ode.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_ode.pt --local-dir checkpoints
# 🔥 Or you can also use the improved version (Causal Forcing++): causal CD instead of causal ODE
hf download zhuhz22/Causal-Forcing framewise/causal_cd.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_cd.pt --local-dir checkpoints
```
</details>
And then train DMD models:
- Frame-wise model:
```bash
# =============== Causal Forcing ================
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_forcing_dmd_framewise.yaml \
--logdir logs/causal_forcing_dmd_framewise
# =============== Causal Forcing++ ================
# 2-step Causal Forcing++
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_forcing_dmd_framewise_2step.yaml \
--logdir logs/causal_forcing_dmd_framewise_2step
# 1-step Causal Forcing++
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_forcing_dmd_framewise_1step.yaml \
--logdir logs/causal_forcing_dmd_framewise_1step
```
- Chunk-wise model:
```bash
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint $MASTER_ADDR \
train.py \
--config_path configs/causal_forcing_dmd_chunkwise.yaml \
--logdir logs/causal_forcing_dmd_chunkwise
```
> For a batch size of 64 (bs64), we recommend training for no more than 1K steps; otherwise the motion dynamics may degrade. If the batch size is smaller (e.g., 8), more training steps are preferable.
Such models are the final models used to generate videos.
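To relate the batch-size recommendation to the launch commands: `--nnodes=8 --nproc_per_node=8` starts 64 processes. The sketch below assumes one GPU per process, a per-GPU batch size of 1, and no gradient accumulation. These values are assumptions for illustration; the actual settings live in the YAML configs and are not confirmed by this README.

```python
def global_batch_size(
    nnodes: int,
    nproc_per_node: int,
    per_gpu_batch: int = 1,     # assumption: check the YAML config you train with
    grad_accum_steps: int = 1,  # assumption: no gradient accumulation
) -> int:
    """Effective batch size per optimizer step under data parallelism."""
    return nnodes * nproc_per_node * per_gpu_batch * grad_accum_steps


print(global_batch_size(8, 8))  # 64, matching the 8x8 torchrun launch above
print(global_batch_size(1, 8))  # 8, a single 8-GPU node
```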
## FAQ & Blog
See the [FAQ](https://my.feishu.cn/wiki/AjBSwcjpqiN0ECkodIWcGDcMn4e) and the [blog](https://zhuanlan.zhihu.com/p/2002114039493461457). (currently in Chinese)
<details>
<summary> Typical Questions (click to expand)
</summary>
**Why use a bidirectional teacher in the DMD stage?**
- Q: In the DMD stage, do you still use a bidirectional teacher? Why not an AR teacher?
- A: Yes. DMD only requires the student to match the teacher’s final distribution, not the generation trajectory, so a bidirectional teacher is fine. Also, bidirectional diffusion models are typically stronger than AR diffusion, so they make a better teacher.
- Q: Then why must the ODE (or Consistency Distillation) stage use an AR teacher?
- A: Because ODE/CD requires the student and teacher to follow the same trajectory, so their structures must be matched; an AR student cannot be trajectory-aligned with a bidirectional teacher.
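To see why only the teacher's output distribution matters for DMD, consider a one-dimensional toy case (an illustration, not the repository's training code). Both the teacher and the student are assumed to produce unit-variance Gaussians, so their noised scores are available in closed form, and the DMD-style update direction is the student ("fake") score minus the teacher ("real") score. All names are hypothetical.

```python
import random


def gaussian_score(x, mean, var):
    """Score (gradient of the log-density) of N(mean, var) evaluated at x."""
    return -(x - mean) / var


def dmd_toy(teacher_mean=2.0, student_mean=-1.0, sigma=0.5, lr=0.2, steps=200, seed=0):
    """Move the student's mean using only score differences on its own samples."""
    rng = random.Random(seed)
    var_t = 1.0 + sigma ** 2  # variance of a unit Gaussian after adding noise of std sigma
    for _ in range(steps):
        # Sample from the *student's* noised distribution; no teacher trajectory is used.
        x_t = student_mean + rng.gauss(0.0, var_t ** 0.5)
        fake = gaussian_score(x_t, student_mean, var_t)
        real = gaussian_score(x_t, teacher_mean, var_t)
        student_mean -= lr * (fake - real)
    return student_mean


final_mean = dmd_toy()
print(f"student mean after training: {final_mean:.4f}")
```

The update never pairs a specific noise with a specific teacher output; it only asks how the teacher's density scores the student's samples. ODE or consistency distillation, by contrast, regresses onto points of the teacher's own trajectory, which is why the teacher's structure must match the student's there.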
🔥🔥 **ODE initialization or multi-step AR diffusion initialization?**
- Q: Which is better as initialization: a “proper” ODE initialization or directly using multi-step AR diffusion?
- A: We compared this in Appendix C2. Overall, proper ODE init is better: multi-step AR diffusion init + DMD occasionally yields grid-like or waxy/greasy results. A key reason is that DMD is inherently few-step, so the right comparison is in the few-step regime, where a few-step diffusion teacher is much weaker than an ODE-distilled teacher. Without ODE distillation, DMD must both close the step gap and handle an added conditioning gap from self-rollout: early few-step errors corrupt the history and get amplified across frames (large exposure bias), which puts more pressure on the DMD stage. It can still converge, but typically with worse quality than ODE initialization. Also, with ODE init, DMD can be trained for very few steps (e.g., ~100), reducing the risk of dynamics degradation from long DMD training.
**Can frame-level non-injectivity appear in the actual training dataset?**
- Q: Regarding the “one-to-many” analysis in the ODE stage: since a single frame’s latent has very high dimensionality, isn’t the probability of being exactly identical extremely small?
- A: Yes, but the key point here is not whether the dataset literally contains identical samples; it is whether a well-defined function exists in the mathematical sense. Our vision modalities live in a continuous space, so even in 1D, getting two samples to be exactly identical is extremely unlikely. However, the theoretical existence of exact collisions is enough to break the function property and make the mapping ill-defined.
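The consequence of a one-to-many mapping can be shown with a deliberately tiny regression example (a toy, not the paper's setting): if the same input is paired with two different targets, no function fits both, and the squared-error minimizer is their average. The values below are made up for illustration.

```python
def best_constant_prediction(targets):
    """Squared-error minimizer for one input that was observed with several targets."""
    return sum(targets) / len(targets)


def mse(prediction, targets):
    """Mean squared error of a single prediction against all targets."""
    return sum((prediction - t) ** 2 for t in targets) / len(targets)


targets_for_same_input = [-1.0, 1.0]  # made-up one-to-many pair
pred = best_constant_prediction(targets_for_same_input)
print(f"prediction={pred}, loss={mse(pred, targets_for_same_input)}")
```

The optimal prediction matches neither target and the loss cannot reach zero, which is the sense in which the regression target is ill-defined.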
</details>
## Acknowledgements
- This codebase is built on top of the open-source implementations of [CausVid](https://github.com/tianweiy/CausVid), [Self Forcing](https://github.com/guandeh17/Self-Forcing), and [Rolling Forcing](https://github.com/TencentARC/RollingForcing), and on the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
- Thanks to @[chijw](https://github.com/chijw) for improving the EMA mechanism.
- Thanks to @[AshadowZ](https://github.com/AshadowZ) for improving the causal ODE data curation efficiency.
- Causal Forcing++ with 1/2 steps applies the first-frame 4-step technique proposed by [ASD](https://github.com/BigAandSmallq/SAD). We thank ASD for its contribution.
## References
If you find the method useful, please cite:
```
@article{zhu2026causal,
title={Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation},
author={Zhu, Hongzhou and Zhao, Min and He, Guande and Su, Hang and Li, Chongxuan and Zhu, Jun},
journal={arXiv preprint arXiv:2602.02214},
year={2026}
}
@article{zhao2026causal,
title={Causal Forcing++: Scalable Few-Step Autoregressive Diffusion Distillation for Real-Time Interactive Video Generation},
author={Zhao, Min and Zhu, Hongzhou and Zheng, Kaiwen and Zhou, Zihan and Yan, Bokai and Li, Xinyuan and Yang, Xiao and Li, Chongxuan and Zhu, Jun},
journal={arXiv preprint arXiv:2605.15141},
year={2026}
}
```
================================================
FILE: configs/ar_diffusion_tf_chunkwise.yaml
================================================
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
warp_denoising_step: true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 2000000000
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: diffusion
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data
teacher_forcing: true # TF/DF
================================================
FILE: configs/ar_diffusion_tf_framewise.yaml
================================================
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
warp_denoising_step: true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 2000000000
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: diffusion
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data
teacher_forcing: true # TF/DF
================================================
FILE: configs/causal_cd_chunkwise.yaml
================================================
generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: consistency_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data
grad_accum: 1
spatial_sf: true
teacher_forcing: True
causal: True
is_causal: True
discrete_cd_N: 48
================================================
FILE: configs/causal_cd_framewise.yaml
================================================
generator_ckpt: checkpoints/framewise/ar_diffusion.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: consistency_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data
grad_accum: 1
spatial_sf: true
teacher_forcing: True
causal: True
is_causal: True
discrete_cd_N: 48
================================================
FILE: configs/causal_forcing_dmd_chunkwise.yaml
================================================
generator_ckpt: checkpoints/chunkwise/causal_ode.pt # 🔥 or checkpoints/chunkwise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt
================================================
FILE: configs/causal_forcing_dmd_framewise.yaml
================================================
generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt
================================================
FILE: configs/causal_forcing_dmd_framewise_1step.yaml
================================================
generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
denoising_step_list_first_chunk: # Causal Forcing++ with 1/2 steps applies the first-frame 4-step technique proposed by [ASD](https://github.com/BigAandSmallq/SAD). We thank ASD for its contribution.
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt
================================================
FILE: configs/causal_forcing_dmd_framewise_2step.yaml
================================================
generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 500
denoising_step_list_first_chunk: # Causal Forcing++ with 1/2 steps applies the first-frame 4-step technique proposed by [ASD](https://github.com/BigAandSmallq/SAD). We thank ASD for its contribution.
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt
================================================
FILE: configs/causal_ode_chunkwise.yaml
================================================
generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt
generator_grad:
  model: true
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true
generator_task: causal_video
generator_fsdp_wrap_strategy: size
text_encoder_fsdp_wrap_strategy: size
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
wandb_name: ode
sharding_strategy: hybrid_full
lr: 2.0e-06
beta1: 0.9
beta2: 0.999
data_path: dataset/ODE6KCausal_chunkwise
batch_size: 1
log_iters: 1000
trainer: ode
gradient_checkpointing: true
num_frame_per_block: 3
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: configs/causal_ode_framewise.yaml
================================================
generator_ckpt: checkpoints/framewise/ar_diffusion.pt
generator_grad:
  model: true
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true
generator_task: causal_video
generator_fsdp_wrap_strategy: size
text_encoder_fsdp_wrap_strategy: size
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
wandb_name: ode
sharding_strategy: hybrid_full
lr: 2.0e-06
beta1: 0.9
beta2: 0.999
data_path: dataset/ODE6KCausal_framewise
batch_size: 1
log_iters: 1000
trainer: ode
gradient_checkpointing: true
num_frame_per_block: 1
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: configs/default_config.yaml
================================================
independent_first_frame: false
warp_denoising_step: false
weight_decay: 0.01
same_step_across_blocks: true
discriminator_lr_multiplier: 1.0
last_step_only: false
i2v: false
num_training_frames: 21
gc_interval: 100
context_noise: 0
causal: true
ckpt_step: 0
prompt_name: MovieGenVideoBench
prompt_path: prompts/MovieGenVideoBench.txt
eval_first_n: 64
num_samples: 1
height: 480
width: 832
num_frames: 81
================================================
FILE: demo.py
================================================
"""
Demo for Causal-Forcing.
"""
import os
import re
import random
import time
import base64
import argparse
import hashlib
import subprocess
import urllib.request
from io import BytesIO
from PIL import Image
import numpy as np
import torch
from omegaconf import OmegaConf
from flask import Flask, render_template, jsonify
from flask_socketio import SocketIO, emit
import queue
from threading import Thread, Event
from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
from demo_utils.utils import generate_timestamp
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=5001)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/framewise/causal_forcing.pt')
parser.add_argument("--config_path", type=str, default='./configs/causal_forcing_dmd_framewise.yaml')
parser.add_argument('--trt', action='store_true')
args = parser.parse_args()
print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40
# Load models
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
text_encoder = WanTextEncoder()
# Global variables for dynamic model switching
current_vae_decoder = None
current_use_taehv = False
fp8_applied = False
torch_compile_applied = False
# Running index of frames written to disk (a module-level `global` statement is a no-op)
frame_number = 0
anim_name = ""
frame_rate = 6
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Initialize VAE decoder based on the selected option"""
    global current_vae_decoder, current_use_taehv
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        current_vae_decoder = VAETRTWrapper()
        return current_vae_decoder
    if use_taehv:
        from demo_utils.taehv import TAEHV
        # Check if taew2_1.pth exists in checkpoints folder, download if missing
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
                print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
            except Exception as e:
                print(f"Failed to download taew2_1.pth: {e}")
                raise
        class DotDict(dict):
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__
        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)
            def decode(self, latents, return_dict=None):
                # n, c, t, h, w = latents.shape
                # low-memory, set parallel=True for faster + higher memory
                return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)
        current_vae_decoder = TAEHVDiffusersWrapper()
    else:
        current_vae_decoder = VAEDecoderWrapper()
        vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
        decoder_state_dict = {}
        for key, value in vae_state_dict.items():
            if 'decoder.' in key or 'conv2' in key:
                decoder_state_dict[key] = value
        current_vae_decoder.load_state_dict(decoder_state_dict)
    current_vae_decoder.eval()
    current_vae_decoder.to(dtype=torch.float16)
    current_vae_decoder.requires_grad_(False)
    current_vae_decoder.to(gpu)
    current_use_taehv = use_taehv
    print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
    return current_vae_decoder
# Initialize with default VAE
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
transformer = WanDiffusionWrapper(is_causal=True)
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
transformer.load_state_dict(state_dict['generator_ema'])
text_encoder.eval()
transformer.eval()
transformer.to(dtype=torch.float16)
text_encoder.to(dtype=torch.bfloat16)
text_encoder.requires_grad_(False)
transformer.requires_grad_(False)
pipeline = CausalInferencePipeline(
    config,
    device=gpu,
    generator=transformer,
    text_encoder=text_encoder,
    vae=vae_decoder
)
if low_memory:
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
    text_encoder.to(gpu)
transformer.to(gpu)
# Flask and SocketIO setup
app = Flask(__name__)
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
socketio = SocketIO(app, cors_allowed_origins="*")
generation_active = False
stop_event = Event()
frame_send_queue = queue.Queue()
sender_thread = None
models_compiled = False
def tensor_to_base64_frame(frame_tensor):
    """Convert a single frame tensor to base64 image string."""
    global frame_number, anim_name
    # Clamp and normalize to 0-255
    frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
    frame = frame.to(torch.uint8).cpu().numpy()
    # CHW -> HWC
    if len(frame.shape) == 3:
        frame = np.transpose(frame, (1, 2, 0))
    # Convert to PIL Image
    if frame.shape[2] == 3:  # RGB
        image = Image.fromarray(frame, 'RGB')
    else:  # Handle other formats
        image = Image.fromarray(frame)
    # Convert to base64
    buffer = BytesIO()
    image.save(buffer, format='JPEG', quality=100)
    # Also write the frame to disk so an MP4 can be assembled afterwards
    os.makedirs("./images/%s" % anim_name, exist_ok=True)
    frame_number += 1
    image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"
def frame_sender_worker():
    """Background thread that processes frame send queue non-blocking."""
    global frame_send_queue, generation_active, stop_event
    print("📡 Frame sender thread started")
    while True:
        frame_data = None
        try:
            # Get frame data from queue
            frame_data = frame_send_queue.get(timeout=1.0)
            if frame_data is None:  # Shutdown signal
                frame_send_queue.task_done()  # Mark shutdown signal as done
                break
            frame_tensor, frame_index, block_index, job_id = frame_data
            # Convert tensor to base64
            base64_frame = tensor_to_base64_frame(frame_tensor)
            # Send via SocketIO
            try:
                socketio.emit('frame_ready', {
                    'data': base64_frame,
                    'frame_index': frame_index,
                    'block_index': block_index,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"⚠️ Failed to send frame {frame_index}: {e}")
            frame_send_queue.task_done()
        except queue.Empty:
            # Check if we should continue running
            if not generation_active and frame_send_queue.empty():
                break
        except Exception as e:
            print(f"❌ Frame sender error: {e}")
            # Make sure to mark task as done even if there's an error
            if frame_data is not None:
                try:
                    frame_send_queue.task_done()
                except Exception as e:
                    print(f"❌ Failed to mark frame task as done: {e}")
            break
    print("📡 Frame sender thread stopped")
@torch.no_grad()
def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
"""Generate video and push frames immediately to frontend."""
global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name
try:
generation_active = True
stop_event.clear()
job_id = generate_timestamp()
# Start frame sender thread if not already running
if sender_thread is None or not sender_thread.is_alive():
sender_thread = Thread(target=frame_sender_worker, daemon=True)
sender_thread.start()
# Emit progress updates
def emit_progress(message, progress):
try:
socketio.emit('progress', {
'message': message,
'progress': progress,
'job_id': job_id
})
except Exception as e:
print(f"❌ Failed to emit progress: {e}")
emit_progress('Starting generation...', 0)
# Handle VAE decoder switching
if use_taehv != current_use_taehv:
emit_progress('Switching VAE decoder...', 2)
print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
# Update pipeline with new VAE decoder
pipeline.vae = current_vae_decoder
# Handle FP8 quantization
if enable_fp8 and not fp8_applied:
emit_progress('Applying FP8 quantization...', 3)
print("🔧 Applying FP8 quantization to transformer")
from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
fp8_applied = True
# Text encoding
emit_progress('Encoding text prompt...', 8)
conditional_dict = text_encoder(text_prompts=[prompt])
for key, value in conditional_dict.items():
conditional_dict[key] = value.to(dtype=torch.float16)
if low_memory:
gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
move_model_to_device_with_memory_preservation(
text_encoder,target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
# Handle torch.compile if enabled
torch_compile_applied = enable_torch_compile
if enable_torch_compile and not models_compiled:
# Compile transformer and decoder
transformer.compile(mode="max-autotune-no-cudagraphs")
if not current_use_taehv and not low_memory and not args.trt:
current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
# Initialize generation
emit_progress('Initializing generation...', 12)
rnd = torch.Generator(gpu).manual_seed(seed)
# all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
# Generation parameters
num_blocks = 7
current_start_frame = 0
num_input_frames = 0
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
if current_use_taehv:
vae_cache = None
else:
vae_cache = ZERO_VAE_CACHE
for i in range(len(vae_cache)):
vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)
total_frames_sent = 0
generation_start_time = time.time()
emit_progress('Generating frames... (frontend handles timing)', 15)
for idx, current_num_frames in enumerate(all_num_frames):
if not generation_active or stop_event.is_set():
break
progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15
# Special message for first block with torch.compile
if idx == 0 and torch_compile_applied and not models_compiled:
emit_progress(
f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
models_compiled = True
else:
emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")
block_start_time = time.time()
noisy_input = noise[:, current_start_frame -
num_input_frames:current_start_frame + current_num_frames - num_input_frames]
# Denoising loop
denoising_start = time.time()
for index, current_timestep in enumerate(pipeline.denoising_step_list):
if not generation_active or stop_event.is_set():
break
timestep = torch.ones([1, current_num_frames], device=noise.device,
dtype=torch.int64) * current_timestep
if index < len(pipeline.denoising_step_list) - 1:
_, denoised_pred = transformer(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=pipeline.kv_cache1,
crossattn_cache=pipeline.crossattn_cache,
current_start=current_start_frame * pipeline.frame_seq_length
)
next_timestep = pipeline.denoising_step_list[index + 1]
noisy_input = pipeline.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
).unflatten(0, denoised_pred.shape[:2])
else:
_, denoised_pred = transformer(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=pipeline.kv_cache1,
crossattn_cache=pipeline.crossattn_cache,
current_start=current_start_frame * pipeline.frame_seq_length
)
if not generation_active or stop_event.is_set():
break
denoising_time = time.time() - denoising_start
print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
# Record output
# all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
# Update KV cache for next block
if idx != len(all_num_frames) - 1:
transformer(
noisy_image_or_video=denoised_pred,
conditional_dict=conditional_dict,
timestep=torch.zeros_like(timestep),
kv_cache=pipeline.kv_cache1,
crossattn_cache=pipeline.crossattn_cache,
current_start=current_start_frame * pipeline.frame_seq_length,
)
# Decode to pixels and send frames immediately
print(f"🎨 Decoding block {idx+1} to pixels...")
decode_start = time.time()
if args.trt:
all_current_pixels = []
for i in range(denoised_pred.shape[1]):
is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
torch.tensor(0.0).cuda().half()
outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
# outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
current_pixels, vae_cache = outputs[0], outputs[1:]
print(current_pixels.max(), current_pixels.min())
all_current_pixels.append(current_pixels.clone())
pixels = torch.cat(all_current_pixels, dim=1)
if idx == 0:
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
else:
if current_use_taehv:
if vae_cache is None:
vae_cache = denoised_pred
else:
denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
vae_cache = denoised_pred[:, -3:, :, :, :]
pixels = current_vae_decoder.decode(denoised_pred)
print(f"denoised_pred shape: {denoised_pred.shape}")
print(f"pixels shape: {pixels.shape}")
if idx == 0:
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
else:
pixels = pixels[:, 12:, :, :, :]
else:
pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
if idx == 0:
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
decode_time = time.time() - decode_start
print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")
# Queue frames for non-blocking sending
block_frames = pixels.shape[1]
print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
queue_start = time.time()
for frame_idx in range(block_frames):
if not generation_active or stop_event.is_set():
break
frame_tensor = pixels[0, frame_idx].cpu()
# Queue frame data in non-blocking way
frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
total_frames_sent += 1
queue_time = time.time() - queue_start
block_time = time.time() - block_start_time
print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")
current_start_frame += current_num_frames
generation_time = time.time() - generation_start_time
print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")
# Wait for all frames to be sent before completing
emit_progress('Waiting for all frames to be sent...', 97)
print("⏳ Waiting for all frames to be sent...")
frame_send_queue.join() # Wait for all queued frames to be processed
print("✅ All frames sent successfully!")
generate_mp4_from_images("./images", "./videos/" + anim_name + ".mp4", frame_rate)
# Final progress update
emit_progress('Generation complete!', 100)
try:
socketio.emit('generation_complete', {
'message': 'Video generation completed!',
'total_frames': total_frames_sent,
'generation_time': f"{generation_time:.2f}s",
'job_id': job_id
})
except Exception as e:
print(f"❌ Failed to emit generation complete: {e}")
except Exception as e:
print(f"❌ Generation failed: {e}")
try:
socketio.emit('error', {
'message': f'Generation failed: {str(e)}',
'job_id': job_id
})
except Exception as e:
print(f"❌ Failed to emit error: {e}")
finally:
generation_active = False
stop_event.set()
# Clean up sender thread
try:
frame_send_queue.put(None)
except Exception as e:
print(f"❌ Failed to put None in frame_send_queue: {e}")
def generate_mp4_from_images(image_directory, output_video_path, fps=24):
"""
Generate an MP4 video from a numbered JPEG sequence matching `<anim_name>_%03d.jpg`.
:param image_directory: Root directory; frames are read from `<image_directory>/<anim_name>/`.
:param output_video_path: Path where the output MP4 will be saved.
:param fps: Frames per second for the output video.
"""
global anim_name
# Construct the ffmpeg command
cmd = [
'ffmpeg',
'-framerate', str(fps),
'-i', os.path.join(image_directory, anim_name, anim_name + '_%03d.jpg'), # Adjust the pattern if necessary
'-c:v', 'libx264',
'-pix_fmt', 'yuv420p',
output_video_path
]
try:
subprocess.run(cmd, check=True)
print(f"Video saved to {output_video_path}")
except subprocess.CalledProcessError as e:
print(f"An error occurred: {e}")
def calculate_sha256(data):
# Convert data to bytes if it's not already
if isinstance(data, str):
data = data.encode()
# Calculate SHA-256 hash
sha256_hash = hashlib.sha256(data).hexdigest()
return sha256_hash
# Socket.IO event handlers
@socketio.on('connect')
def handle_connect():
print('Client connected')
emit('status', {'message': 'Connected to frontend-buffered demo server'})
@socketio.on('disconnect')
def handle_disconnect():
print('Client disconnected')
@socketio.on('start_generation')
def handle_start_generation(data):
global generation_active, frame_number, anim_name, frame_rate
frame_number = 0
if generation_active:
emit('error', {'message': 'Generation already in progress'})
return
prompt = data.get('prompt', '')
seed = data.get('seed', -1)
if seed == -1:
seed = random.randint(0, 2**32 - 1)  # randint bounds are inclusive; stay within the unsigned 32-bit range
# Extract words up to the first punctuation or newline
words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
if not words_up_to_punctuation:
words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()
# Calculate SHA-256 hash of the entire prompt
sha256_hash = calculate_sha256(prompt)
# Create anim_name from the extracted words, the seed, and the first 10 characters of the hash
anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"
generation_active = True
generation_start_time = time.time()
enable_torch_compile = data.get('enable_torch_compile', False)
enable_fp8 = data.get('enable_fp8', False)
use_taehv = data.get('use_taehv', False)
frame_rate = data.get('fps', 6)
if not prompt:
generation_active = False  # reset the flag set above so later requests are not rejected
emit('error', {'message': 'Prompt is required'})
return
# Start generation in background thread
socketio.start_background_task(generate_video_stream, prompt, seed,
enable_torch_compile, enable_fp8, use_taehv)
emit('status', {'message': 'Generation started - frames will be sent immediately'})
@socketio.on('stop_generation')
def handle_stop_generation():
global generation_active, stop_event, frame_send_queue
generation_active = False
stop_event.set()
# Signal sender thread to stop (will be processed after current frames)
try:
frame_send_queue.put(None)
except Exception as e:
print(f"❌ Failed to put None in frame_send_queue: {e}")
emit('status', {'message': 'Generation stopped'})
# Web routes
@app.route('/')
def index():
return render_template('demo.html')
@app.route('/api/status')
def api_status():
return jsonify({
'generation_active': generation_active,
'free_vram_gb': get_cuda_free_memory_gb(gpu),
'fp8_applied': fp8_applied,
'torch_compile_applied': torch_compile_applied,
'current_use_taehv': current_use_taehv
})
if __name__ == '__main__':
print(f"🚀 Starting demo on http://{args.host}:{args.port}")
socketio.run(app, host=args.host, port=args.port, debug=False)
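The generation loop above queues frames, waits on `frame_send_queue.join()`, and only then puts a `None` sentinel to stop the sender. The sender thread itself is not part of this excerpt, so the following is a minimal standard-library sketch of that handshake under the assumption that the sender calls `task_done()` for every item it takes, including the sentinel. `sender_loop` and `stream_frames` are hypothetical names used only for illustration.

```python
import queue
import threading


def sender_loop(q, sent):
    """Hypothetical sender: drains the queue until it sees the None sentinel."""
    while True:
        item = q.get()
        try:
            if item is None:
                break  # sentinel received, stop the thread
            sent.append(item)  # stand-in for emitting the frame to the client
        finally:
            # Runs for frames and for the sentinel, so q.join() can never hang.
            q.task_done()


def stream_frames(num_frames):
    """Queue frames, wait until all are sent, then shut the sender down."""
    q = queue.Queue()
    sent = []
    worker = threading.Thread(target=sender_loop, args=(q, sent), daemon=True)
    worker.start()

    for idx in range(num_frames):
        q.put(("frame", idx))  # cheap for the producer, the sender does the work

    q.join()      # blocks until every queued frame has been marked done
    q.put(None)   # sentinel: tell the sender to exit
    worker.join(timeout=5)
    return sent


if __name__ == "__main__":
    print(stream_frames(3))
```

If the sender skipped `task_done()` for any item, `join()` would block forever, which is why the sketch puts the call in a `finally` clause.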
================================================
FILE: demo_utils/constant.py
================================================
import torch
ZERO_VAE_CACHE = [
torch.zeros(1, 16, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 192, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832)
]
feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
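As a rough cross-check of the table above, the sketch below rebuilds the list of cache shapes from grouped counts, derives the input names the same way `ALL_INPUTS_NAMES` does, and estimates the memory footprint. It uses plain tuples instead of tensors, and the 2 bytes per element figure is an assumption based on the TensorRT path in demo.py casting its inputs with `.half()`. `CACHE_GROUPS`, `expand_shapes`, `input_names`, and `cache_bytes` are illustrative names, not part of the repository.

```python
from math import prod

# (count, shape) groups transcribed from ZERO_VAE_CACHE
CACHE_GROUPS = [
    (1, (1, 16, 2, 60, 104)),
    (11, (1, 384, 2, 60, 104)),
    (1, (1, 192, 2, 120, 208)),
    (6, (1, 384, 2, 120, 208)),
    (6, (1, 192, 2, 240, 416)),
    (7, (1, 96, 2, 480, 832)),
]


def expand_shapes(groups):
    """Flatten (count, shape) groups into one shape per cache entry."""
    return [shape for count, shape in groups for _ in range(count)]


def input_names(shapes):
    """Mirror ALL_INPUTS_NAMES: two fixed inputs plus one name per cache entry."""
    return ["z", "use_cache"] + [f"vae_cache_{i}" for i in range(len(shapes))]


def cache_bytes(shapes, bytes_per_element=2):
    """Total size of the cache, assuming half precision by default."""
    return sum(prod(shape) for shape in shapes) * bytes_per_element


if __name__ == "__main__":
    shapes = expand_shapes(CACHE_GROUPS)
    print(len(shapes), "cache tensors,", len(input_names(shapes)), "inputs")
    print(f"{cache_bytes(shapes) / 1024 ** 3:.2f} GiB at 2 bytes per element")
```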
================================================
FILE: demo_utils/memory.py
================================================
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel
import torch
cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
gpu_complete_modules = []
class DynamicSwapInstaller:
@staticmethod
def _install_module(module: torch.nn.Module, **kwargs):
original_class = module.__class__
module.__dict__['forge_backup_original_class'] = original_class
def hacked_get_attr(self, name: str):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
p = _parameters[name]
if p is None:
return None
if p.__class__ == torch.nn.Parameter:
return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
else:
return p.to(**kwargs)
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name].to(**kwargs)
return super(original_class, self).__getattr__(name)
module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
'__getattr__': hacked_get_attr,
})
return
@staticmethod
def _uninstall_module(module: torch.nn.Module):
if 'forge_backup_original_class' in module.__dict__:
module.__class__ = module.__dict__.pop('forge_backup_original_class')
return
@staticmethod
def install_model(model: torch.nn.Module, **kwargs):
for m in model.modules():
DynamicSwapInstaller._install_module(m, **kwargs)
return
@staticmethod
def uninstall_model(model: torch.nn.Module):
for m in model.modules():
DynamicSwapInstaller._uninstall_module(m)
return
def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
if hasattr(model, 'scale_shift_table'):
model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
return
for k, p in model.named_modules():
if hasattr(p, 'weight'):
p.to(target_device)
return
def get_cuda_free_memory_gb(device=None):
if device is None:
device = gpu
memory_stats = torch.cuda.memory_stats(device)
bytes_active = memory_stats['active_bytes.all.current']
bytes_reserved = memory_stats['reserved_bytes.all.current']
bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
bytes_inactive_reserved = bytes_reserved - bytes_active
bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
return bytes_total_available / (1024 ** 3)
def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
for m in model.modules():
if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
torch.cuda.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=target_device)
model.to(device=target_device)
torch.cuda.empty_cache()
return
def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
for m in model.modules():
if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
torch.cuda.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=cpu)
model.to(device=cpu)
torch.cuda.empty_cache()
return
def unload_complete_models(*args):
for m in gpu_complete_modules + list(args):
m.to(device=cpu)
print(f'Unloaded {m.__class__.__name__} as complete.')
gpu_complete_modules.clear()
torch.cuda.empty_cache()
return
def load_model_as_complete(model, target_device, unload=True):
if unload:
unload_complete_models()
model.to(device=target_device)
print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
gpu_complete_modules.append(model)
return
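`DynamicSwapInstaller` works by replacing an object's class at runtime with a generated subclass whose `__getattr__` transforms parameters on access. The sketch below shows the same class-swap trick without torch: `Box` is a hypothetical stand-in for `torch.nn.Module` that keeps its values in a private `_parameters` dict, and `transform` stands in for the `p.to(**kwargs)` call. Unlike the original, the fallback calls `original_class.__getattr__` directly, because `Box` has no base class that defines `__getattr__`.

```python
class Box:
    """Stand-in for nn.Module: values live in a private dict, not as attributes."""

    def __init__(self, **params):
        self.__dict__['_parameters'] = dict(params)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        params = self.__dict__.get('_parameters', {})
        if name in params:
            return params[name]
        raise AttributeError(name)


def install(obj, transform):
    """Swap obj's class for a subclass that transforms parameters on access."""
    original_class = obj.__class__
    obj.__dict__['backup_original_class'] = original_class

    def hooked_getattr(self, name):
        params = self.__dict__.get('_parameters', {})
        if name in params:
            return transform(params[name])  # e.g. move to another device
        return original_class.__getattr__(self, name)

    obj.__class__ = type(
        'DynamicSwap_' + original_class.__name__,
        (original_class,),
        {'__getattr__': hooked_getattr},
    )


def uninstall(obj):
    """Restore the original class if a swap was installed."""
    if 'backup_original_class' in obj.__dict__:
        obj.__class__ = obj.__dict__.pop('backup_original_class')


if __name__ == "__main__":
    b = Box(w=2)
    install(b, lambda v: v * 10)
    print(type(b).__name__, b.w)
    uninstall(b)
    print(type(b).__name__, b.w)
```

The stored value is never modified; only the view returned by attribute access changes, so uninstalling the hook restores the original behavior.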
================================================
FILE: demo_utils/taehv.py
================================================
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N * T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
# need to fix :(
out = []
# iterate over input timesteps and also iterate over blocks.
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
# it's actually a ***graph traversal*** problem! so let's make a queue
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
# in addition to manually managing our queue, we also need to manually manage our progressbar.
# we'll update it for every source node that we consume.
progress_bar = tqdm(range(T), disable=not show_progress_bar)
# we'll also need a separate addressable memory per node as well
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
# new source node consumed
progress_bar.update(1)
if i == len(model):
# reached end of the graph, append result to output list
out.append(xt)
else:
# fetch the block to process
b = model[i]
if isinstance(b, MemBlock):
# mem blocks are simple since we're visiting the graph in causal order
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
# pool blocks are miserable
if mem[i] is None:
mem[i] = [] # pool memory is itself a queue of inputs to pool
mem[i].append(xt)
if len(mem[i]) > b.stride:
# pool mem is in invalid state, we should have pooled before this
raise ValueError(f"TPool memory holds {len(mem[i])} inputs, more than its stride {b.stride}")
elif len(mem[i]) < b.stride:
# pool mem is not yet full, go back to processing the work queue
pass
else:
# pool mem is ready, run the pool block
N, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
# reset the pool mem
mem[i] = []
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
# each tgrow has multiple successor nodes
for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_next, i + 1))
else:
# normal block with no funny business
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
Args:
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(
scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(
scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(
scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(
checkpoint_path, map_location="cpu", weights_only=True)))
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed.
Args:
sd: state dict to patch
"""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if sd[key].shape[0] > new_sd[key].shape[0]:
# take the last-timestep output channels
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def encode_video(self, x, parallel=True, show_progress_bar=True):
"""Encode a sequence of frames.
Args:
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=False):
"""Decode a sequence of frames.
Args:
x: input NTCHW latent (C=16) tensor with ~Gaussian values.
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
# return x[:, self.frames_to_trim:]
return x
def forward(self, x):
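# the original body referenced an undefined attribute (self.c); fail with a clear message instead
raise NotImplementedError("TAEHV has no single forward pass; call encode_video() or decode_video()")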
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import os
import sys
import cv2 # no highly esteemed deed is commemorated here
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, 'writer'):
self.writer.release()
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
print(
f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None]
# convert to a device tensor with values scaled to [0, 1]
vid_dev = video.to(dev, dtype).div_(255.0)
if video.numel() < 100_000_000:
print(f" {video_path} seems small enough, will process all frames in parallel")
vid_enc = taehv.encode_video(vid_dev)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc)
print(f" Decoded {video_path} -> {vid_dec.shape}")
else:
print(f" {video_path} seems large, will process each frame sequentially")
vid_enc = taehv.encode_video(vid_dev, parallel=False)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
video_out = VideoTensorWriter(
video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
print(f" Saved to {video_out_path}")
if __name__ == "__main__":
main()
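The sequential branch of `apply_model_with_memblocks` is described in its comments as a graph traversal driven by a work queue. The sketch below reproduces that control flow on plain integers so the ordering is easy to follow. `Mem`, `Pool`, and `Grow` are toy stand-ins for `MemBlock`, `TPool`, and `TGrow` with made-up arithmetic (add, sum, and `x * 10 + k`); only the queue and per-block memory handling are meant to mirror the original.

```python
from collections import namedtuple

WorkItem = namedtuple("WorkItem", ("value", "block_index"))


class Mem:
    """Combines the current item with the previous input seen at this block."""
    def __call__(self, x, past):
        return x + past


class Pool:
    """Merges `stride` consecutive items into one (here by summing)."""
    def __init__(self, stride):
        self.stride = stride

    def __call__(self, items):
        return sum(items)


class Grow:
    """Expands one item into `stride` items."""
    def __init__(self, stride):
        self.stride = stride

    def __call__(self, x):
        return [x * 10 + k for k in range(self.stride)]


def run_sequential(blocks, inputs):
    out = []
    queue = [WorkItem(x, 0) for x in inputs]
    mem = [None] * len(blocks)  # separate memory slot per block
    while queue:
        x, i = queue.pop(0)
        if i == len(blocks):
            out.append(x)  # reached the end of the graph
            continue
        b = blocks[i]
        if isinstance(b, Mem):
            past = 0 if mem[i] is None else mem[i]
            mem[i] = x  # remember this input for the next timestep
            queue.insert(0, WorkItem(b(x, past), i + 1))
        elif isinstance(b, Pool):
            if mem[i] is None:
                mem[i] = []
            mem[i].append(x)
            if len(mem[i]) == b.stride:  # buffer full: pool and reset
                pooled = b(mem[i])
                mem[i] = []
                queue.insert(0, WorkItem(pooled, i + 1))
        elif isinstance(b, Grow):
            # push successors in reverse so they pop in causal order
            for y in reversed(b(x)):
                queue.insert(0, WorkItem(y, i + 1))
        else:
            queue.insert(0, WorkItem(b(x), i + 1))
    return out


if __name__ == "__main__":
    print(run_sequential([Pool(2), Mem(), Grow(2)], [1, 2, 3, 4]))
```

Inserting successors at the front of the queue makes the traversal depth-first, so each input is pushed as far through the blocks as possible before the next input is consumed. That is what keeps only one timestep's activations in flight at a time.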
================================================
FILE: demo_utils/utils.py
================================================
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel
import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision
from PIL import Image
def min_resize(x, m):
if x.shape[0] < x.shape[1]:
s0 = m
s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
else:
s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
s1 = m
new_max = max(s1, s0)
raw_max = max(x.shape[0], x.shape[1])
if new_max < raw_max:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_LANCZOS4
y = cv2.resize(x, (s1, s0), interpolation=interpolation)
return y
def d_resize(x, y):
H, W, C = y.shape
new_min = min(H, W)
raw_min = min(x.shape[0], x.shape[1])
if new_min < raw_min:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_LANCZOS4
y = cv2.resize(x, (W, H), interpolation=interpolation)
return y
def resize_and_center_crop(image, target_width, target_height):
if target_height == image.shape[0] and target_width == image.shape[1]:
return image
pil_image = Image.fromarray(image)
original_width, original_height = pil_image.size
scale_factor = max(target_width / original_width, target_height / original_height)
resized_width = int(round(original_width * scale_factor))
resized_height = int(round(original_height * scale_factor))
resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
left = (resized_width - target_width) / 2
top = (resized_height - target_height) / 2
right = (resized_width + target_width) / 2
bottom = (resized_height + target_height) / 2
cropped_image = resized_image.crop((left, top, right, bottom))
return np.array(cropped_image)
def resize_and_center_crop_pytorch(image, target_width, target_height):
B, C, H, W = image.shape
if H == target_height and W == target_width:
return image
scale_factor = max(target_width / W, target_height / H)
resized_width = int(round(W * scale_factor))
resized_height = int(round(H * scale_factor))
resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
top = (resized_height - target_height) // 2
left = (resized_width - target_width) // 2
cropped = resized[:, :, top:top + target_height, left:left + target_width]
return cropped
def resize_without_crop(image, target_width, target_height):
if target_height == image.shape[0] and target_width == image.shape[1]:
return image
pil_image = Image.fromarray(image)
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
return np.array(resized_image)
def just_crop(image, w, h):
if h == image.shape[0] and w == image.shape[1]:
return image
original_height, original_width = image.shape[:2]
k = min(original_height / h, original_width / w)
new_width = int(round(w * k))
new_height = int(round(h * k))
x_start = (original_width - new_width) // 2
y_start = (original_height - new_height) // 2
cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
return cropped_image
def write_to_json(data, file_path):
temp_file_path = file_path + ".tmp"
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
json.dump(data, temp_file, indent=4)
os.replace(temp_file_path, file_path)
return
def read_from_json(file_path):
with open(file_path, 'rt', encoding='utf-8') as file:
data = json.load(file)
return data
def get_active_parameters(m):
return {k: v for k, v in m.named_parameters() if v.requires_grad}
def cast_training_params(m, dtype=torch.float32):
result = {}
for n, param in m.named_parameters():
if param.requires_grad:
param.data = param.to(dtype)
result[n] = param
return result
def separate_lora_AB(parameters, B_patterns=None):
parameters_normal = {}
parameters_B = {}
if B_patterns is None:
B_patterns = ['.lora_B.', '__zero__']
for k, v in parameters.items():
if any(B_pattern in k for B_pattern in B_patterns):
parameters_B[k] = v
else:
parameters_normal[k] = v
return parameters_normal, parameters_B
def set_attr_recursive(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], value)
return
def print_tensor_list_size(tensors):
total_size = 0
total_elements = 0
if isinstance(tensors, dict):
tensors = tensors.values()
for tensor in tensors:
total_size += tensor.nelement() * tensor.element_size()
total_elements += tensor.nelement()
total_size_MB = total_size / (1024 ** 2)
total_elements_B = total_elements / 1e9
print(f"Total number of tensors: {len(tensors)}")
print(f"Total size of tensors: {total_size_MB:.2f} MB")
print(f"Total number of parameters: {total_elements_B:.3f} billion")
return
@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
batch_size = a.size(0)
if b is None:
b = torch.zeros_like(a)
if mask_a is None:
mask_a = torch.rand(batch_size) < probability_a
mask_a = mask_a.to(a.device)
mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
result = torch.where(mask_a, a, b)
return result
@torch.no_grad()
def zero_module(module):
for p in module.parameters():
p.detach().zero_()
return module
@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
data = m.weight.data.clone()
assert int(data.shape[1]) >= k
data[:, :k] = data[:, :k] * alpha
m.weight.data = data.contiguous().clone()
return m
def freeze_module(m):
if not hasattr(m, '_forward_inside_frozen_module'):
m._forward_inside_frozen_module = m.forward
m.requires_grad_(False)
m.forward = torch.no_grad()(m.forward)
return m
def get_latest_safetensors(folder_path):
safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
if not safetensors_files:
raise ValueError('No file to resume!')
latest_file = max(safetensors_files, key=os.path.getmtime)
latest_file = os.path.abspath(os.path.realpath(latest_file))
return latest_file
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
tags = tags_str.split(', ')
tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
prompt = ', '.join(tags)
return prompt
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
if round_to_int:
numbers = np.round(numbers).astype(int)
return numbers.tolist()
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
edges = np.linspace(0, 1, n + 1)
points = np.random.uniform(edges[:-1], edges[1:])
numbers = inclusive + (exclusive - inclusive) * points
if round_to_int:
numbers = np.round(numbers).astype(int)
return numbers.tolist()
def soft_append_bcthw(history, current, overlap=0):
if overlap <= 0:
return torch.cat([history, current], dim=2)
assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
return output.to(history)
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
b, c, t, h, w = x.shape
per_row = b
for p in [6, 5, 4, 3, 2]:
if b % p == 0:
per_row = p
break
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
x = x.detach().cpu().to(torch.uint8)
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
return x
def save_bcthw_as_png(x, output_filename):
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
x = x.detach().cpu().to(torch.uint8)
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
torchvision.io.write_png(x, output_filename)
return output_filename
def save_bchw_as_png(x, output_filename):
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
x = x.detach().cpu().to(torch.uint8)
x = einops.rearrange(x, 'b c h w -> c h (b w)')
torchvision.io.write_png(x, output_filename)
return output_filename
def add_tensors_with_padding(tensor1, tensor2):
if tensor1.shape == tensor2.shape:
return tensor1 + tensor2
shape1 = tensor1.shape
shape2 = tensor2.shape
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
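# allocate on the inputs' device and dtype so GPU or half-precision tensors are handled correctly
padded_tensor1 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)
padded_tensor2 = torch.zeros(new_shape, dtype=tensor2.dtype, device=tensor2.device)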
padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
result = padded_tensor1 + padded_tensor2
return result
def print_free_mem():
torch.cuda.empty_cache()
free_mem, total_mem = torch.cuda.mem_get_info(0)
free_mem_mb = free_mem / (1024 ** 2)
total_mem_mb = total_mem / (1024 ** 2)
print(f"Free memory: {free_mem_mb:.2f} MB")
print(f"Total memory: {total_mem_mb:.2f} MB")
return
def print_gpu_parameters(device, state_dict, log_count=1):
summary = {"device": device, "keys_count": len(state_dict)}
logged_params = {}
for i, (key, tensor) in enumerate(state_dict.items()):
if i >= log_count:
break
logged_params[key] = tensor.flatten()[:3].tolist()
summary["params"] = logged_params
print(str(summary))
return
def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
from PIL import Image, ImageDraw, ImageFont
txt = Image.new("RGB", (width, height), color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype(font_path, size=size)
if not text.strip():  # also covers whitespace-only text, which would leave `words` empty below
return np.array(txt)
# Split text into lines that fit within the image width
lines = []
words = text.split()
current_line = words[0]
for word in words[1:]:
line_with_word = f"{current_line} {word}"
if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
current_line = line_with_word
else:
lines.append(current_line)
current_line = word
lines.append(current_line)
# Draw the text line by line
y = 0
line_height = draw.textbbox((0, 0), "A", font=font)[3]
for line in lines:
if y + line_height > height:
break # stop drawing if the next line will be outside the image
draw.text((0, y), line, fill="black", font=font)
y += line_height
return np.array(txt)
def blue_mark(x):
x = x.copy()
c = x[:, :, 2]
b = cv2.blur(c, (9, 9))
x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
return x
def green_mark(x):
x = x.copy()
x[:, :, 2] = -1
x[:, :, 0] = -1
return x
def frame_mark(x):
x = x.copy()
x[:64] = -1
x[-64:] = -1
x[:, :8] = 1
x[:, -8:] = 1
return x
@torch.inference_mode()
def pytorch2numpy(imgs):
results = []
for x in imgs:
y = x.movedim(0, -1)
y = y * 127.5 + 127.5
y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
results.append(y)
return results
@torch.inference_mode()
def numpy2pytorch(imgs):
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
h = h.movedim(-1, 1)
return h
@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
if zero_out:
return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
else:
return torch.cat([x, x[:count]], dim=0)
def weighted_mse(a, b, weight):
return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
x = (x - x_min) / (x_max - x_min)
x = max(0.0, min(x, 1.0))
x = x ** sigma
return y_min + x * (y_max - y_min)
def expand_to_dims(x, target_dims):
return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
if tensor is None:
return None
first_dim = tensor.shape[0]
if first_dim == batch_size:
return tensor
if batch_size % first_dim != 0:
raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
repeat_times = batch_size // first_dim
return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
def dim5(x):
return expand_to_dims(x, 5)
def dim4(x):
return expand_to_dims(x, 4)
def dim3(x):
return expand_to_dims(x, 3)
def crop_or_pad_yield_mask(x, length):
B, F, C = x.shape
device = x.device
dtype = x.dtype
if F < length:
y = torch.zeros((B, length, C), dtype=dtype, device=device)
mask = torch.zeros((B, length), dtype=torch.bool, device=device)
y[:, :F, :] = x
mask[:, :F] = True
return y, mask
return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
def extend_dim(x, dim, minimal_length, zero_pad=False):
original_length = int(x.shape[dim])
if original_length >= minimal_length:
return x
if zero_pad:
padding_shape = list(x.shape)
padding_shape[dim] = minimal_length - original_length
padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
else:
idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
last_element = x[idx]
padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
return torch.cat([x, padding], dim=dim)
def lazy_positional_encoding(t, repeats=None):
if not isinstance(t, list):
t = [t]
from diffusers.models.embeddings import get_timestep_embedding
te = torch.tensor(t)
te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
if repeats is None:
return te
te = te[:, None, :].expand(-1, repeats, -1)
return te
def state_dict_offset_merge(A, B, C=None):
result = {}
keys = A.keys()
for key in keys:
A_value = A[key]
B_value = B[key].to(A_value)
if C is None:
result[key] = A_value + B_value
else:
C_value = C[key].to(A_value)
result[key] = A_value + B_value - C_value
return result
def state_dict_weighted_merge(state_dicts, weights):
if len(state_dicts) != len(weights):
raise ValueError("Number of state dictionaries must match number of weights")
if not state_dicts:
return {}
total_weight = sum(weights)
if total_weight == 0:
raise ValueError("Sum of weights cannot be zero")
normalized_weights = [w / total_weight for w in weights]
keys = state_dicts[0].keys()
result = {}
for key in keys:
result[key] = state_dicts[0][key] * normalized_weights[0]
for i in range(1, len(state_dicts)):
state_dict_value = state_dicts[i][key].to(result[key])
result[key] += state_dict_value * normalized_weights[i]
return result
def group_files_by_folder(all_files):
grouped_files = {}
for file in all_files:
folder_name = os.path.basename(os.path.dirname(file))
if folder_name not in grouped_files:
grouped_files[folder_name] = []
grouped_files[folder_name].append(file)
list_of_lists = list(grouped_files.values())
return list_of_lists
def generate_timestamp():
now = datetime.datetime.now()
timestamp = now.strftime('%y%m%d_%H%M%S')
milliseconds = f"{int(now.microsecond / 1000):03d}"
random_number = random.randint(0, 9999)
return f"{timestamp}_{milliseconds}_{random_number}"
def write_PIL_image_with_png_info(image, metadata, path):
from PIL.PngImagePlugin import PngInfo
png_info = PngInfo()
for key, value in metadata.items():
png_info.add_text(key, value)
image.save(path, "PNG", pnginfo=png_info)
return image
def torch_safe_save(content, path):
torch.save(content, path + '_tmp')
os.replace(path + '_tmp', path)
return path
def move_optimizer_to_device(optimizer, device):
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
================================================
FILE: demo_utils/vae.py
================================================
from typing import List
from einops import rearrange
import tensorrt as trt
import torch
import torch.nn as nn
from demo_utils.constant import ALL_INPUTS_NAMES, ZERO_VAE_CACHE
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, Upsample
CACHE_T = 2
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache_1, feat_cache_2):
h = self.shortcut(x)
feat_cache = feat_cache_1
out_feat_cache = []
for layer in self.residual:
if isinstance(layer, CausalConv3d):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache is not None:
# fewer than 2 frames in this chunk: prepend the last cached frame from the previous chunk
cache_x = torch.cat([
feat_cache[:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache)
out_feat_cache.append(cache_x)
feat_cache = feat_cache_2
else:
x = layer(x)
return x + h, *out_feat_cache
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, is_first_frame, feat_cache):
if self.mode == 'upsample3d':
b, c, t, h, w = x.size()
# x, out_feat_cache = torch.cond(
# is_first_frame,
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
# lambda: self.temporal_conv(x, feat_cache),
# )
# x, out_feat_cache = torch.cond(
# is_first_frame,
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
# lambda: self.temporal_conv(x, feat_cache),
# )
x, out_feat_cache = self.temporal_conv(x, is_first_frame, feat_cache)
out_feat_cache = torch.cond(
is_first_frame,
lambda: feat_cache.clone().contiguous(),
lambda: out_feat_cache.clone().contiguous(),
)
# if is_first_frame:
# x = torch.cat([torch.zeros_like(x), x], dim=2)
# out_feat_cache = feat_cache.clone()
# else:
# x, out_feat_cache = self.temporal_conv(x, feat_cache)
else:
out_feat_cache = None
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x, out_feat_cache
def temporal_conv(self, x, is_first_frame, feat_cache):
b, c, t, h, w = x.size()
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache is not None:
cache_x = torch.cat([
torch.zeros_like(cache_x),
cache_x
], dim=2)
x = torch.cond(
is_first_frame,
lambda: torch.cat([torch.zeros_like(x), x], dim=1).contiguous(),
lambda: self.time_conv(x, feat_cache).contiguous(),
)
# x = self.time_conv(x, feat_cache)
out_feat_cache = cache_x
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
return x.contiguous(), out_feat_cache.contiguous()
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class VAEDecoderWrapperSingle(nn.Module):
def __init__(self):
super().__init__()
self.decoder = VAEDecoder3d()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=torch.float32)
self.std = torch.tensor(std, dtype=torch.float32)
self.z_dim = 16
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
def forward(
self,
z: torch.Tensor,
is_first_frame: torch.Tensor,
*feat_cache: List[torch.Tensor]
):
# from [batch_size, num_frames, num_channels, height, width]
# to [batch_size, num_channels, num_frames, height, width]
z = z.permute(0, 2, 1, 3, 4)
assert z.shape[2] == 1
feat_cache = list(feat_cache)
is_first_frame = is_first_frame.bool()
device, dtype = z.device, z.dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
x = self.conv2(z)
out, feat_cache = self.decoder(x, is_first_frame, feat_cache=feat_cache)
out = out.clamp_(-1, 1)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
out = out.permute(0, 2, 1, 3, 4)
return out, feat_cache
class VAEDecoder3d(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.cache_t = 2
self.decoder_conv_num = 32
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(
self,
x: torch.Tensor,
is_first_frame: torch.Tensor,
feat_cache: List[torch.Tensor]
):
idx = 0
out_feat_cache = []
# conv1
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# fewer than 2 frames in this chunk: prepend the last cached frame from the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
out_feat_cache.append(cache_x)
idx += 1
# middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
idx += 2
out_feat_cache.append(out_feat_cache_1)
out_feat_cache.append(out_feat_cache_2)
else:
x = layer(x)
# upsamples
for layer in self.upsamples:
if isinstance(layer, Resample):
x, cache_x = layer(x, is_first_frame, feat_cache[idx])
if cache_x is not None:
out_feat_cache.append(cache_x)
idx += 1
else:
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
idx += 2
out_feat_cache.append(out_feat_cache_1)
out_feat_cache.append(out_feat_cache_2)
# head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# fewer than 2 frames in this chunk: prepend the last cached frame from the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
out_feat_cache.append(cache_x)
idx += 1
else:
x = layer(x)
return x, out_feat_cache
class VAETRTWrapper():
def __init__(self):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
self.engine: trt.ICudaEngine = rt.deserialize_cuda_engine(f.read())
self.context: trt.IExecutionContext = self.engine.create_execution_context()
self.stream = torch.cuda.current_stream().cuda_stream
# ──────────────────────────────
# 2️⃣ Feed the engine with tensors
# (name-based API in TRT ≥10)
# ──────────────────────────────
self.dtype_map = {
trt.float32: torch.float32,
trt.float16: torch.float16,
trt.int8: torch.int8,
trt.int32: torch.int32,
}
test_input = torch.zeros(1, 16, 1, 60, 104).cuda().half()
is_first_frame = torch.tensor(1.0).cuda().half()
test_cache_inputs = [c.cuda().half() for c in ZERO_VAE_CACHE]
test_inputs = [test_input, is_first_frame] + test_cache_inputs
# keep references so buffers stay alive
self.device_buffers, self.outputs = {}, []
# ---- inputs ----
for i, name in enumerate(ALL_INPUTS_NAMES):
tensor, scale = test_inputs[i], 1 / 127
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
# dynamic shapes
if -1 in self.engine.get_tensor_shape(name):
# new name-based API
self.context.set_input_shape(name, tuple(tensor.shape))
# replaces bindings[]
self.context.set_tensor_address(name, int(tensor.data_ptr()))
self.device_buffers[name] = tensor # keep pointer alive
# ---- (after all input shapes are known) infer output shapes ----
# propagates shapes
self.context.infer_shapes()
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
# replaces binding_is_input
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = tuple(self.context.get_tensor_shape(name))
dtype = self.dtype_map[self.engine.get_tensor_dtype(name)]
out = torch.empty(shape, dtype=dtype, device="cuda").contiguous()
self.context.set_tensor_address(name, int(out.data_ptr()))
self.outputs.append(out)
self.device_buffers[name] = out
# helper to quant-convert on the fly
def quantize_if_needed(self, t, expected_dtype, scale):
if expected_dtype == trt.int8 and t.dtype != torch.int8:
t = torch.clamp((t / scale).round(), -128, 127).to(torch.int8).contiguous()
return t # keep pointer alive
def forward(self, *test_inputs):
for i, name in enumerate(ALL_INPUTS_NAMES):
tensor, scale = test_inputs[i], 1 / 127
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
self.context.set_tensor_address(name, int(tensor.data_ptr()))
self.device_buffers[name] = tensor
self.context.execute_async_v3(stream_handle=self.stream)
torch.cuda.current_stream().synchronize()
return self.outputs
================================================
FILE: demo_utils/vae_block3.py
================================================
from typing import List
from einops import rearrange
import torch
import torch.nn as nn
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, ResidualBlock, Upsample
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
self.cache_t = 2
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# fewer than 2 frames in this chunk: prepend the last cached frame from the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class VAEDecoderWrapper(nn.Module):
def __init__(self):
super().__init__()
self.decoder = VAEDecoder3d()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=torch.float32)
self.std = torch.tensor(std, dtype=torch.float32)
self.z_dim = 16
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
def forward(
self,
z: torch.Tensor,
*feat_cache: List[torch.Tensor]
):
# from [batch_size, num_frames, num_channels, height, width]
# to [batch_size, num_channels, num_frames, height, width]
z = z.permute(0, 2, 1, 3, 4)
feat_cache = list(feat_cache)
print("Length of feat_cache: ", len(feat_cache))
device, dtype = z.device, z.dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
if i == 0:
out, feat_cache = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_cache)
else:
out_, feat_cache = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_cache)
out = torch.cat([out, out_], 2)
out = out.float().clamp_(-1, 1)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
out = out.permute(0, 2, 1, 3, 4)
return out, feat_cache
class VAEDecoder3d(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.cache_t = 2
self.decoder_conv_num = 32
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(
self,
x: torch.Tensor,
feat_cache: List[torch.Tensor]
):
feat_idx = [0]
# conv1
idx = feat_idx[0]
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# fewer than 2 frames in this chunk: prepend the last cached frame from the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
# middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# upsamples
for layer in self.upsamples:
x = layer(x, feat_cache, feat_idx)
# head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# fewer than 2 frames in this chunk: prepend the last cached frame from the previous chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x, feat_cache
================================================
FILE: demo_utils/vae_torch2trt.py
================================================
from demo_utils.vae import (
VAEDecoderWrapperSingle, # main nn.Module
ZERO_VAE_CACHE # helper constants shipped with your code base
)
import pycuda.driver as cuda
import pycuda.autoinit # noqa
import sys
from pathlib import Path
import torch
import tensorrt as trt
from utils.dataset import ShardingLMDBDataset
data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=0
)
# ─────────────────────────────────────────────────────────
# 1️⃣ Bring the PyTorch model into scope
# (VAEDecoderWrapperSingle is imported from demo_utils/vae.py above)
# ─────────────────────────────────────────────────────────
# --- dummy tensors (shapes match the fixed optimisation profile set below) ---
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
dummy_cache_input = [
torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
for s in ZERO_VAE_CACHE # keep exactly the same ordering
]
inputs = [dummy_input, is_first_frame, *dummy_cache_input]
# ─────────────────────────────────────────────────────────
# 2️⃣ Export → ONNX
# ─────────────────────────────────────────────────────────
model = VAEDecoderWrapperSingle().half().cuda().eval()
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
decoder_state_dict = {}
for key, value in vae_state_dict.items():
if 'decoder.' in key or 'conv2' in key:
decoder_state_dict[key] = value
model.load_state_dict(decoder_state_dict)
model = model.half().cuda().eval() # only batch dim dynamic
onnx_path = Path("vae_decoder.onnx")
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
all_inputs_names = ["z", "use_cache"] + feat_names
with torch.inference_mode():
torch.onnx.export(
model,
tuple(inputs), # must be a tuple
onnx_path.as_posix(),
input_names=all_inputs_names,
output_names=["rgb_out", "cache_out"],
opset_version=17,
do_constant_folding=True,
dynamo=True
)
print(f"✅ ONNX graph saved to {onnx_path.resolve()}")
# (Optional) quick sanity-check with ONNX-Runtime
try:
import onnxruntime as ort
sess = ort.InferenceSession(onnx_path.as_posix(),
providers=["CUDAExecutionProvider"])
ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
_ = sess.run(None, ort_inputs)
print("✅ ONNX graph is executable")
except Exception as e:
print("⚠️ ONNX check failed:", e)
# ─────────────────────────────────────────────────────────
# 3️⃣ Build the TensorRT engine
# ─────────────────────────────────────────────────────────
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(parser.get_error(i))
sys.exit("❌ ONNX → TRT parsing failed")
config = builder.create_builder_config()
def set_workspace(config, bytes_):
"""Version-agnostic workspace limit."""
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
config.max_workspace_size = bytes_
else: # TRT 10+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, bytes_)
# …
config = builder.create_builder_config()
set_workspace(config, 4 << 30) # 4 GB
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
# ---- INT8 (optional) ----
# provide a calibrator if you need an INT8 engine; comment this
# block if you only care about FP16.
# ─────────────────────────────────────────────────────────
# helper: version-agnostic workspace limit
# ─────────────────────────────────────────────────────────
def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
"""
TRT < 10.x → config.max_workspace_size
TRT ≥ 10.x → config.set_memory_pool_limit(...)
"""
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
config.max_workspace_size = bytes_
else: # TRT 10+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
bytes_)
# ─────────────────────────────────────────────────────────
# (optional) INT-8 calibrator
# ─────────────────────────────────────────────────────────
# Only keep this block if you really need INT-8 (it requires PyCUDA, imported at the top of this file)
class VAECalibrator(trt.IInt8EntropyCalibrator2):
def __init__(self, loader, cache="calibration.cache", max_batches=10):
super().__init__()
self.loader = iter(loader)
self.batch_size = loader.batch_size or 1
self.max_batches = max_batches
self.count = 0
self.cache_file = cache
self.stream = cuda.Stream()
self.dev_ptrs = {}
# --- TRT 10 needs BOTH spellings ---
def get_batch_size(self):
return self.batch_size
def getBatchSize(self):
return self.batch_size
def get_batch(self, names):
if self.count >= self.max_batches:
return None
# Randomly sample how many warm-up frames to decode first: 0 to 10, inclusive
import random
vae_idx = random.randint(0, 10)
data = next(self.loader)
latent = data['ode_latent'][0][:, :1]
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
feat_cache = ZERO_VAE_CACHE
for i in range(vae_idx):
inputs = [latent, is_first_frame, *feat_cache]
with torch.inference_mode():
outputs = model(*inputs)
latent = data['ode_latent'][0][:, i + 1:i + 2]
is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
feat_cache = outputs[1:]
# -------- ensure context is current --------
z_np = latent.cpu().numpy().astype('float32')
ptrs = [] # list[int] – one entry per name
for name in names: # <-- match TRT's binding order
if name == "z":
arr = z_np
elif name == "use_cache":
arr = is_first_frame.cpu().numpy().astype('float32')
else:
idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17
arr = feat_cache[idx].cpu().numpy().astype('float32')
if name not in self.dev_ptrs:
self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)
cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
ptrs.append(int(self.dev_ptrs[name])) # ***int() is required***
self.stream.synchronize()
self.count += 1
print(f"Calibration batch {self.count}/{self.max_batches}")
return ptrs
# --- calibration-cache helpers (both spellings) ---
def read_calibration_cache(self):
try:
with open(self.cache_file, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def readCalibrationCache(self):
return self.read_calibration_cache()
def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
f.write(cache)
def writeCalibrationCache(self, cache):
self.write_calibration_cache(cache)
# ─────────────────────────────────────────────────────────
# Builder-config + optimisation profile
# ─────────────────────────────────────────────────────────
config = builder.create_builder_config()
set_workspace(config, 4 << 30) # 4 GB
# ► enable FP16 if possible
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
# ► enable INT-8 (delete this block if you don’t need it)
if cuda is not None:
config.set_flag(trt.BuilderFlag.INT8)
# supply any representative batch you like – here we reuse the latent z
calib = VAECalibrator(dataloader)
# TRT-10 renamed the setter:
if hasattr(config, "set_int8_calibrator"): # TRT 10+
config.set_int8_calibrator(calib)
else: # TRT ≤ 9
config.int8_calibrator = calib
# ---- optimisation profile ----
profile = builder.create_optimization_profile()
profile.set_shape(all_inputs_names[0], # latent z
min=(1, 1, 16, 60, 104),
opt=(1, 1, 16, 60, 104),
max=(1, 1, 16, 60, 104))
profile.set_shape("use_cache", # scalar flag
min=(1,), opt=(1,), max=(1,))
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)
config.add_optimization_profile(profile)
# ─────────────────────────────────────────────────────────
# Build the engine (API changed in TRT-10)
# ─────────────────────────────────────────────────────────
print("⚙️ Building engine … (can take a minute)")
if hasattr(builder, "build_serialized_network"): # TRT 10+
serialized_engine = builder.build_serialized_network(network, config)
assert serialized_engine is not None, "build_serialized_network() failed"
plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.write_bytes(serialized_engine)
engine_bytes = serialized_engine # keep for smoke-test
else: # TRT ≤ 9
engine = builder.build_engine(network, config)
assert engine is not None, "build_engine() returned None"
plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.write_bytes(engine.serialize())
engine_bytes = engine.serialize()
print(f"✅ TensorRT engine written to {plan_path.resolve()}")
# ─────────────────────────────────────────────────────────
# 4️⃣ Quick smoke-test with the brand-new engine
# ─────────────────────────────────────────────────────────
with trt.Runtime(TRT_LOGGER) as rt:
engine = rt.deserialize_cuda_engine(engine_bytes)
context = engine.create_execution_context()
stream = torch.cuda.current_stream().cuda_stream
# pre-allocate device buffers once
device_buffers, outputs = {}, []
dtype_map = {trt.float32: torch.float32,
trt.float16: torch.float16,
trt.int8: torch.int8,
trt.int32: torch.int32}
for name, tensor in zip(all_inputs_names, inputs):
if -1 in engine.get_tensor_shape(name): # dynamic input
context.set_input_shape(name, tensor.shape)
context.set_tensor_address(name, int(tensor.data_ptr()))
device_buffers[name] = tensor
context.infer_shapes() # propagate ⇢ outputs
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = tuple(context.get_tensor_shape(name))
dtype = dtype_map[engine.get_tensor_dtype(name)]
out = torch.empty(shape, dtype=dtype, device="cuda")
context.set_tensor_address(name, int(out.data_ptr()))
outputs.append(out)
print(f"output {name} shape: {shape}")
context.execute_async_v3(stream_handle=stream)
torch.cuda.current_stream().synchronize()
print("✅ TRT execution OK – first output shape:", outputs[0].shape)
================================================
FILE: get_causal_ode_data_chunkwise.py
================================================
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
from utils.scheduler import FlowMatchScheduler
from utils.distributed import launch_distributed_job
import torch.distributed as dist
from tqdm import tqdm
import argparse
import torch
import math
import os
from utils.dataset import LatentLMDBDataset
def init_model(device):
model = WanDiffusionWrapper(is_causal=True).to(device).to(torch.float32)
model.model.num_frame_per_block = 3 # !!
encoder = WanTextEncoder().to(device).to(torch.float32)
scheduler = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
scheduler.sigmas = scheduler.sigmas.to(device)
sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
unconditional_dict = encoder(
text_prompts=[sample_neg_prompt]
)
return model, encoder, scheduler, unconditional_dict
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--output_folder", type=str)
parser.add_argument("--rawdata_path", type=str)
parser.add_argument("--generator_ckpt", type=str)
parser.add_argument("--guidance_scale", type=float, default=6.0)
args = parser.parse_args()
launch_distributed_job()
global_rank = dist.get_rank()
device = torch.cuda.current_device()
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model, encoder, scheduler, unconditional_dict = init_model(device=device)
state_dict = torch.load(args.generator_ckpt, map_location="cpu")
gen_sd = state_dict["generator"]
fixed = {}
for k, v in gen_sd.items():
if k.startswith("model._fsdp_wrapped_module."):
k = k.replace("model._fsdp_wrapped_module.", "", 1)
if k.startswith("model."):
k = k.replace("model.", "", 1)
fixed[k] = v
state_dict = fixed
model.model.load_state_dict(
state_dict, strict=True
)
dataset = LatentLMDBDataset(args.rawdata_path)
if global_rank == 0:
os.makedirs(args.output_folder, exist_ok=True)
total_steps = int(math.ceil(len(dataset) / dist.get_world_size()))
for index in tqdm(
range(total_steps), disable=(dist.get_rank() != 0),
):
prompt_index = index * dist.get_world_size() + dist.get_rank()
if prompt_index >= len(dataset):
continue
sample = dataset[prompt_index]
prompt = sample["prompts"]
clean_latent = sample["clean_latent"].to(device).unsqueeze(0)
conditional_dict = encoder(
text_prompts=prompt
)
latents = torch.randn(
[1, 21, 16, 60, 104], dtype=torch.float32, device=device
)
noisy_input = []
for progress_id, t in enumerate(tqdm(scheduler.timesteps, disable=(dist.get_rank() != 0))):
timestep = t * \
torch.ones([1, 21], device=device, dtype=torch.float32)
noisy_input.append(latents)
f_cond, x0_pred_cond = model(
latents, conditional_dict, timestep, clean_x = clean_latent
)
f_uncond, x0_pred_uncond = model(
latents, unconditional_dict, timestep, clean_x = clean_latent
)
flow_pred = f_uncond + args.guidance_scale * (
f_cond - f_uncond
)
latents = scheduler.step(
flow_pred.flatten(0, 1),
timestep.flatten(0, 1),
latents.flatten(0, 1)
).unflatten(dim=0, sizes=flow_pred.shape[:2])
noisy_input.append(latents)
noisy_input.append(clean_latent)
noisy_inputs = torch.stack(noisy_input, dim=1)
noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -2, -1]]
stored_data = noisy_inputs
torch.save(
{prompt: stored_data.cpu().detach()},
os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
)
dist.barrier()
if __name__ == "__main__":
main()
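The slicing `noisy_inputs[:, [0, 12, 24, 36, -2, -1]]` above keeps six states from each ODE trajectory. The sketch below shows which entries those are. It assumes `scheduler.timesteps` has exactly 48 entries, one per inference step (this is not verified against `FlowMatchScheduler` here), and it uses placeholder strings instead of latent tensors.

```python
# Toy stand-in for the list built in the sampling loop above.
NUM_STEPS = 48  # assumption: len(scheduler.timesteps) == num_inference_steps

# One entry is recorded *before* each scheduler step ...
trajectory = [f"x_step{i}" for i in range(NUM_STEPS)]
# ... then the latents after the final step, then the clean latent.
trajectory.append("x_final")
trajectory.append("x_clean")

keep = [0, 12, 24, 36, -2, -1]
selected = [trajectory[i] for i in keep]

print(len(trajectory))  # 50 under the assumption above
print(selected)
```

Under that assumption the stored tensor holds the initial noise, three intermediate states, the final ODE sample, and the clean latent, in that order.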
================================================
FILE: get_causal_ode_data_framewise.py
================================================
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
from utils.scheduler import FlowMatchScheduler
from utils.distributed import launch_distributed_job
import torch.distributed as dist
from tqdm import tqdm
import argparse
import torch
import math
import os
from utils.dataset import LatentLMDBDataset
def init_model(device):
model = WanDiffusionWrapper(is_causal=True).to(device).to(torch.float32)
model.model.num_frame_per_block = 1 # !!
encoder = WanTextEncoder().to(device).to(torch.float32)
scheduler = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
scheduler.sigmas = scheduler.sigmas.to(device)
sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
unconditional_dict = encoder(
text_prompts=[sample_neg_prompt]
)
return model, encoder, scheduler, unconditional_dict
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--output_folder", type=str)
parser.add_argument("--rawdata_path", type=str)
parser.add_argument("--generator_ckpt", type=str)
parser.add_argument("--guidance_scale", type=float, default=6.0)
args = parser.parse_args()
launch_distributed_job()
global_rank = dist.get_rank()
device = torch.cuda.current_device()
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model, encoder, scheduler, unconditional_dict = init_model(device=device)
state_dict = torch.load(args.generator_ckpt, map_location="cpu")
gen_sd = state_dict["generator"]
fixed = {}
for k, v in gen_sd.items():
if k.startswith("model._fsdp_wrapped_module."):
k = k.replace("model._fsdp_wrapped_module.", "", 1)
if k.startswith("model."):
k = k.replace("model.", "", 1)
fixed[k] = v
state_dict = fixed
model.model.load_state_dict(
state_dict, strict=True
)
dataset = LatentLMDBDataset(args.rawdata_path)
if global_rank == 0:
os.makedirs(args.output_folder, exist_ok=True)
total_steps = int(math.ceil(len(dataset) / dist.get_world_size()))
for index in tqdm(
range(total_steps), disable=(dist.get_rank() != 0),
):
prompt_index = index * dist.get_world_size() + dist.get_rank()
if prompt_index >= len(dataset):
continue
sample = dataset[prompt_index]
prompt = sample["prompts"]
clean_latent = sample["clean_latent"].to(device).unsqueeze(0)
conditional_dict = encoder(
text_prompts=prompt
)
latents = torch.randn(
[1, 21, 16, 60, 104], dtype=torch.float32, device=device
)
noisy_input = []
for progress_id, t in enumerate(tqdm(scheduler.timesteps, disable=(dist.get_rank() != 0))):
timestep = t * \
torch.ones([1, 21], device=device, dtype=torch.float32)
noisy_input.append(latents)
f_cond, x0_pred_cond = model(
latents, conditional_dict, timestep, clean_x = clean_latent
)
f_uncond, x0_pred_uncond = model(
latents, unconditional_dict, timestep, clean_x = clean_latent
)
flow_pred = f_uncond + args.guidance_scale * (
f_cond - f_uncond
)
latents = scheduler.step(
flow_pred.flatten(0, 1),
timestep.flatten(0, 1),
latents.flatten(0, 1)
).unflatten(dim=0, sizes=flow_pred.shape[:2])
noisy_input.append(latents)
noisy_input.append(clean_latent)
noisy_inputs = torch.stack(noisy_input, dim=1)
noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -2, -1]]
stored_data = noisy_inputs
torch.save(
{prompt: stored_data.cpu().detach()},
os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
)
dist.barrier()
if __name__ == "__main__":
main()
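The sampling loop above combines the conditional and unconditional flow predictions with classifier-free guidance, `f_uncond + guidance_scale * (f_cond - f_uncond)`. Here is a minimal sketch of that formula on plain Python lists. The helper name `cfg_combine` and the toy values are illustrative only and do not appear in the repository.

```python
def cfg_combine(cond, uncond, scale):
    """Element-wise uncond + scale * (cond - uncond) over two equal-length lists."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


flow_cond = [1.0, 2.0]    # toy conditional prediction
flow_uncond = [0.5, 1.0]  # toy unconditional (negative-prompt) prediction

print(cfg_combine(flow_cond, flow_uncond, 0.0))  # scale 0 returns the unconditional branch
print(cfg_combine(flow_cond, flow_uncond, 1.0))  # scale 1 returns the conditional branch
print(cfg_combine(flow_cond, flow_uncond, 6.0))  # scale > 1 extrapolates past it
```

With the script's default `--guidance_scale` of 6.0, the prediction is pushed well beyond the conditional branch, away from the negative prompt.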
================================================
FILE: get_causal_ode_data_kv_optimized.py
================================================
import argparse
import math
import os
import torch
import torch.distributed as dist
from tqdm import tqdm
from utils.dataset import LatentLMDBDataset
from utils.distributed import launch_distributed_job
from utils.ode_generation import (
CausalODETrajectoryGenerator,
merge_cfg_prompt_embeds,
)
from utils.scheduler import FlowMatchScheduler
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
DEFAULT_NEGATIVE_PROMPT = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
"最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
"画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
"杂乱的背景,三条腿,背景人很多,倒着走"
)
DEFAULT_NUM_INFERENCE_STEPS = 48
DEFAULT_SCHEDULER_SHIFT = 5.0
DEFAULT_TARGET_NUM_FRAMES = 21
DEFAULT_TRAJECTORY_INDICES = [0, 12, 24, 36, -2, -1]
def normalize_generator_state_dict(state_dict: dict) -> dict:
if "generator" in state_dict:
state_dict = state_dict["generator"]
elif "generator_ema" in state_dict:
state_dict = state_dict["generator_ema"]
fixed = {}
for k, v in state_dict.items():
if k.startswith("model._fsdp_wrapped_module."):
k = k.replace("model._fsdp_wrapped_module.", "", 1)
if k.startswith("model."):
k = k.replace("model.", "", 1)
fixed[k] = v
return fixed
def init_model(
device,
num_frame_per_block: int,
scheduler_shift: float,
num_inference_steps: int,
negative_prompt: str,
):
model = WanDiffusionWrapper(is_causal=True).to(device).to(torch.float32)
model.model.num_frame_per_block = num_frame_per_block
encoder = WanTextEncoder().to(device).to(torch.float32)
scheduler = FlowMatchScheduler(
shift=scheduler_shift,
sigma_min=0.0,
extra_one_step=True,
)
scheduler.set_timesteps(
num_inference_steps=num_inference_steps,
denoising_strength=1.0,
)
scheduler.sigmas = scheduler.sigmas.to(device)
unconditional_dict = encoder(text_prompts=[negative_prompt])
return model, encoder, scheduler, unconditional_dict
def prepare_clean_latent(
sample: dict,
target_num_frames: int | None,
device,
) -> torch.Tensor:
clean_latent = sample["clean_latent"].to(device).unsqueeze(0)
if target_num_frames is None:
return clean_latent
if clean_latent.shape[1] < target_num_frames:
raise ValueError(
"clean_latent has fewer frames than requested: "
f"{clean_latent.shape[1]} < {target_num_frames}"
)
if clean_latent.shape[1] != target_num_frames:
clean_latent = clean_latent[:, :target_num_frames, ...]
return clean_latent.contiguous()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--output_folder", type=str, required=True)
parser.add_argument("--rawdata_path", type=str, required=True)
parser.add_argument("--generator_ckpt", type=str, required=True)
# Stored as args.num_frame_per_block, the attribute read below.
parser.add_argument("--num_frames_per_chunk", dest="num_frame_per_block", type=int, required=True)
parser.add_argument("--guidance_scale", type=float, default=6.0)
parser.add_argument(
"--generation_mode",
type=str,
default="full",
choices=["full", "blockwise_kv"],
)
args = parser.parse_args()
launch_distributed_job()
global_rank = dist.get_rank()
device = torch.cuda.current_device()
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
model, encoder, scheduler, unconditional_dict = init_model(
device=device,
num_frame_per_block=args.num_frame_per_block,
scheduler_shift=DEFAULT_SCHEDULER_SHIFT,
num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
)
state_dict = torch.load(args.generator_ckpt, map_location="cpu")
model.model.load_state_dict(
normalize_generator_state_dict(state_dict),
strict=True,
)
dataset = LatentLMDBDataset(
args.rawdata_path,
max_pair=int(1e8),
)
if global_rank == 0:
os.makedirs(args.output_folder, exist_ok=True)
trajectory_generator = CausalODETrajectoryGenerator(
model=model,
scheduler=scheduler,
num_frame_per_block=args.num_frame_per_block,
num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
guidance_scale=args.guidance_scale,
)
total_steps = int(math.ceil(len(dataset) / dist.get_world_size()))
for index in tqdm(
range(total_steps), disable=(global_rank != 0),
):
prompt_index = index * dist.get_world_size() + global_rank
if prompt_index >= len(dataset):
continue
output_path = os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
sample = dataset[prompt_index]
prompt = sample["prompts"]
clean_latent = prepare_clean_latent(
sample=sample,
target_num_frames=DEFAULT_TARGET_NUM_FRAMES,
device=device,
)
conditional_dict = encoder(text_prompts=[prompt])
paired_conditional_dict = merge_cfg_prompt_embeds(
conditional_dict=conditional_dict,
unconditional_dict=unconditional_dict,
)
initial_noise = torch.randn_like(clean_latent, dtype=torch.float32)
stored_data = trajectory_generator.generate(
clean_latent=clean_latent,
paired_conditional_dict=paired_conditional_dict,
trajectory_indices=DEFAULT_TRAJECTORY_INDICES,
generation_mode=args.generation_mode,
initial_noise=initial_noise,
)
torch.save(
{prompt: stored_data.cpu().detach()},
output_path,
)
dist.barrier()
if __name__ == "__main__":
main()
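The main loop above assigns samples to ranks round-robin: `prompt_index = index * world_size + rank`, skipping indices past the end of the dataset. The sketch below reproduces that arithmetic without `torch.distributed`. The function name `shard_indices` and the sizes used are hypothetical.

```python
import math


def shard_indices(num_samples, world_size, rank):
    """Indices one rank would process under the round-robin scheme above."""
    total_steps = math.ceil(num_samples / world_size)
    picked = []
    for step in range(total_steps):
        idx = step * world_size + rank
        if idx >= num_samples:
            continue  # ranks past the tail simply skip this step
        picked.append(idx)
    return picked


shards = [shard_indices(10, 4, rank) for rank in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Every index is handled by exactly one rank, and each rank runs the same number of loop iterations before reaching the final `dist.barrier()`.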
================================================
FILE: inference.py
================================================
import argparse
import torch
import os
from omegaconf import OmegaConf
from tqdm import tqdm
from torchvision import transforms
from torchvision.io import write_video
from einops import rearrange
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import json
from pipeline import (
CausalDiffusionInferencePipeline,
CausalInferencePipeline,
)
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to the config file")
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21, help="Number of latent frames to generate")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
parser.add_argument("--report_timing", action="store_true",
help="Report first-chunk latency and FPS. Only tested on A800, for the Causal Forcing++ latency; no claims are made for other hardware such as H100. For H100 results, refer to those reported in the Self Forcing paper.")
args = parser.parse_args()
# Initialize distributed inference
if "LOCAL_RANK" in os.environ:
dist.init_process_group(backend='nccl')
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
world_size = dist.get_world_size()
else:
device = torch.device("cuda")
local_rank = 0
world_size = 1
set_seed(args.seed)
print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40
torch.set_grad_enabled(False)
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
# Initialize pipeline
if hasattr(config, 'denoising_step_list'):
# Few-step inference
pipeline = CausalInferencePipeline(config, device=device)
else:
# Multi-step diffusion inference
pipeline = CausalDiffusionInferencePipeline(config, device=device)
if args.checkpoint_path:
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
key = 'generator_ema' if args.use_ema else 'generator'
gen_sd = state_dict[key]
try:
pipeline.generator.load_state_dict(gen_sd)
except RuntimeError:
fixed = {}
for k, v in gen_sd.items():
if k.startswith("model._fsdp_wrapped_module."):
k = k.replace("model._fsdp_wrapped_module.", "model.", 1)
fixed[k] = v
pipeline.generator.load_state_dict(fixed, strict=False)
pipeline = pipeline.to(dtype=torch.bfloat16)
if low_memory:
DynamicSwapInstaller.install_model(pipeline.text_encoder, device=gpu)
else:
pipeline.text_encoder.to(device=gpu)
pipeline.generator.to(device=gpu)
pipeline.vae.to(device=gpu)
# Create dataset
if args.i2v:
assert not dist.is_initialized(), "I2V does not support distributed inference yet"
transform = transforms.Compose([
transforms.Resize((480, 832)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
dataset = TextDataset(prompt_path=args.data_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")
if args.report_timing and num_prompts < 2:
print(f"[WARN] --report_timing requires at least 2 prompts "
f"(got {num_prompts}); timing disabled.")
args.report_timing = False
if dist.is_initialized():
sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
# Create output directory (only on main process to avoid race conditions)
if local_rank == 0:
os.makedirs(args.output_folder, exist_ok=True)
if dist.is_initialized():
dist.barrier()
def encode(self, videos: torch.Tensor) -> torch.Tensor:
device, dtype = videos[0].device, videos[0].dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
output = [
self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
for u in videos
]
output = torch.stack(output, dim=0)
return output
for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
idx = batch_data['idx'].item()
if isinstance(batch_data, dict):
batch = batch_data
elif isinstance(batch_data, list):
batch = batch_data[0] # First (and only) item in the batch
all_video = []
num_generated_frames = 0 # Number of generated (latent) frames
if args.i2v:
assert config.num_frame_per_block == 1, "Current I2V only supports the frame-wise model."
# For image-to-video, batch contains image and caption
prompt = batch['prompts'][0] # Get caption from batch
output_path = os.path.join(args.output_folder, f'{prompt[:100]}.mp4')
if os.path.exists(output_path):
print('Video has been generated. Pass!')
continue
# Process the image
image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
# Encode the input image as the first latent
initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
prompts = [prompt]
sampled_noise = torch.randn(
[1, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
)
else:
# For text-to-video, batch is just the text prompt
prompt = batch['prompts'][0]
output_path = os.path.join(args.output_folder, f'{prompt[:100]}.mp4')
if os.path.exists(output_path):
print('Video has been generated. Pass!')
continue
extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
if extended_prompt is not None:
prompts = [extended_prompt]
else:
prompts = [prompt]
initial_latent = None
sampled_noise = torch.randn(
[1, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
)
sample_report_timing = args.report_timing and i >= 1
video, latents = pipeline.inference(
noise=sampled_noise,
text_prompts=prompts,
return_latents=True,
initial_latent=initial_latent,
report_timing=sample_report_timing,
)
if sample_report_timing:
latency = pipeline.first_chunk_time
elapsed = pipeline.last_generation_time
num_pixel_frames = video.shape[1]
fps = num_pixel_frames / elapsed if elapsed > 0 else float('inf')
print(f"[Sample {i}] {num_pixel_frames} frames, "
f"latency ↓ {latency:.2f}s, FPS ↑ {fps:.2f}")
# Only tested on A800, for the latency and throughput reported in the Causal Forcing++ paper.
# No claims are made for other hardware such as H100.
# For H100 results, refer to those reported in the Self Forcing paper.
# We do not guarantee that our FPS/latency measurement protocol is identical to the one used in the Self Forcing paper.
current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
all_video.append(current_video)
num_generated_frames += latents.shape[1]
# Final output video
clean_latent = latents[0].cpu()
video = 255.0 * torch.cat(all_video, dim=1)
# Clear VAE cache
pipeline.vae.model.clear_cache()
output_path = os.path.join(args.output_folder, f'{prompt[:100]}.mp4')
write_video(output_path, video[0], fps=16)
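The checkpoint-loading fallback earlier in this script rewrites keys saved from an FSDP-wrapped generator so that they match the unwrapped module. Unlike the ODE data scripts, it keeps the leading `model.` segment. The sketch below shows that remapping on a toy dictionary. The key names and the helper `remap_fsdp_keys` are made up for illustration.

```python
FSDP_PREFIX = "model._fsdp_wrapped_module."


def remap_fsdp_keys(state_dict):
    """Replace the FSDP wrapper prefix with a plain 'model.' prefix."""
    fixed = {}
    for key, value in state_dict.items():
        if key.startswith(FSDP_PREFIX):
            key = "model." + key[len(FSDP_PREFIX):]
        fixed[key] = value
    return fixed


toy = {
    "model._fsdp_wrapped_module.blocks.0.weight": 1,  # saved under FSDP
    "model.head.bias": 2,                             # already unwrapped
}
print(remap_fsdp_keys(toy))
# {'model.blocks.0.weight': 1, 'model.head.bias': 2}
```

Because the script's fallback loads with `strict=False`, any keys that still fail to match are ignored silently. Inspecting the return value of `load_state_dict` is one way to catch that.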
================================================
FILE: long_video/LICENSE
================================================
Tencent is pleased to support the community by making RollingForcing available.
Copyright (C) 2025 Tencent. All rights reserved.
The open-source software and/or models included in this distribution may have been modified by Tencent (“Tencent Modifications”). All Tencent Modifications are Copyright (C) Tencent.
RollingForcing is licensed under the License Terms of RollingForcing, except for the third-party components listed below, which remain licensed under their respective original terms. RollingForcing does not impose any additional restrictions beyond those specified in the original licenses of these third-party components. Users are required to comply with all applicable terms and conditions of the original licenses and to ensure that the use of these third-party components conforms to all relevant laws and regulations.
For the avoidance of doubt, RollingForcing refers solely to training code, inference code, parameters, and weights made publicly available by Tencent in accordance with the License Terms of RollingForcing.
Terms of the License Terms of RollingForcing:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and /or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
- You agree to use RollingForcing only for academic purposes, and refrain from using it for any commercial or production purposes under any circumstances.
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Dependencies and Licenses:
This open-source project, RollingForcing, builds upon the following open-source models and/or software components, each of which remains licensed under its original license. Certain models or software may include modifications made by Tencent (“Tencent Modifications”), which are Copyright (C) Tencent.
In case you believe there have been errors in the attribution below, you may submit the concerns to us for review and correction.
Open Source Model Licensed under the Apache-2.0:
--------------------------------------------------------------------
1. Wan-AI/Wan2.1-T2V-1.3B
Copyright (c) 2025 Wan Team
Terms of the Apache-2.0:
--------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Definitions.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
================================================
FILE: long_video/README.md
================================================
Building on [Rolling Forcing](https://github.com/TencentARC/RollingForcing), we implement minute-level long video generation.
### Installation
```bash
conda activate causal_forcing
pip install tensorboard opencv-python packaging
```
### CLI inference
```bash
hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download zhuhz22/Causal-Forcing chunkwise/longvideo.pt --local-dir ../checkpoints
python inference.py \
--config_path configs/rolling_forcing_dmd.yaml \
--output_folder videos/rolling_forcing_dmd \
--checkpoint_path ../checkpoints/chunkwise/longvideo.pt \
--data_path prompts/example_prompts.txt \
--num_output_frames 252 \
--use_ema
```
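As a rough guide to what `--num_output_frames 252` means in playback time, the sketch below converts a latent-frame count into decoded frames and seconds. It assumes the Wan2.1 VAE compresses time by 4x while keeping the first frame uncompressed (consistent with 21 latent frames and `num_frames: 81` in `configs/default_config.yaml`) and the 16 fps that `app.py` uses when writing videos. The helper names are illustrative and are not part of this repository.

```python
def latent_to_pixel_frames(num_latent_frames: int, temporal_stride: int = 4) -> int:
    """Decoded frame count, assuming the first latent frame maps to one
    pixel frame and every later latent frame expands to `temporal_stride`."""
    if num_latent_frames < 1:
        raise ValueError("num_latent_frames must be at least 1")
    return 1 + temporal_stride * (num_latent_frames - 1)


def video_seconds(num_latent_frames: int, fps: int = 16) -> float:
    """Approximate clip duration at the given playback rate."""
    return latent_to_pixel_frames(num_latent_frames) / fps


def is_valid_frame_count(n: int, block: int = 3, lo: int = 21, hi: int = 252) -> bool:
    """Mirrors the range check in app.py: a multiple of the block size
    (num_frame_per_block: 3) between 21 and 252 latent frames."""
    return n % block == 0 and lo <= n <= hi


for n in (21, 252):
    frames = latent_to_pixel_frames(n)
    print(f"{n} latent frames -> {frames} frames, about {video_seconds(n):.1f} s")
# 21 latent frames -> 81 frames, about 5.1 s
# 252 latent frames -> 1005 frames, about 62.8 s
```

Under these assumptions, 252 latent frames correspond to roughly one minute of video, which matches the "minute-level" claim above.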
### Training
```bash
hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download Wan-AI/Wan2.1-T2V-14B --local-dir wan_models/Wan2.1-T2V-14B
torchrun --nproc_per_node=8 \
--rdzv_backend=c10d \
--rdzv_endpoint 127.0.0.1:29500 \
train.py \
-- \
--config_path configs/rolling_forcing_dmd.yaml \
--logdir logs/rolling_forcing_dmd
```
> We recommend training for 3000 steps.
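
The training config `configs/rolling_forcing_dmd.yaml` sets `timestep_shift: 5.0` and `warp_denoising_step: true` for its four-entry `denoising_step_list`. The sketch below applies the shift commonly used by flow-matching schedulers, `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`, to those four values. Whether the trainer warps its steps with exactly this formula is an assumption made for illustration, and `shift_timestep` is a hypothetical helper rather than a function from this repository.

```python
def shift_timestep(t: float, shift: float = 5.0, num_train_timestep: int = 1000) -> float:
    """Map a nominal timestep to its shifted value (assumed formula)."""
    sigma = t / num_train_timestep                      # normalise to [0, 1]
    shifted = shift * sigma / (1.0 + (shift - 1.0) * sigma)
    return shifted * num_train_timestep                 # back to timestep units


denoising_step_list = [1000, 750, 500, 250]             # values from the config
warped = [shift_timestep(t) for t in denoising_step_list]
print([round(t, 1) for t in warped])
# [1000.0, 937.5, 833.3, 625.0]
```

Under this assumption, a shift above 1 moves the intermediate steps toward the high-noise end of the schedule while leaving the first step at 1000.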
### Acknowledgements
We adopt [Rolling Forcing](https://github.com/TencentARC/RollingForcing) as our long video generation framework and only change the ODE initialization part.
================================================
FILE: long_video/app.py
================================================
import os
import argparse
import time
from typing import Optional
import torch
from torchvision.io import write_video
from omegaconf import OmegaConf
from einops import rearrange
import gradio as gr
from pipeline import CausalDiffusionInferencePipeline, CausalInferencePipeline
# -----------------------------
# Globals (loaded once per process)
# -----------------------------
_PIPELINE: Optional[torch.nn.Module] = None
_DEVICE: Optional[torch.device] = None
def _ensure_gpu():
    if not torch.cuda.is_available():
        raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
    # Bind to GPU:0 by default
    torch.cuda.set_device(0)
def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
    global _PIPELINE, _DEVICE
    if _PIPELINE is not None:
        return _PIPELINE
    _ensure_gpu()
    _DEVICE = torch.device("cuda:0")
    # Load and merge configs
    config = OmegaConf.load(config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
    # Choose pipeline type based on config
    pipeline = CausalInferencePipeline(config, device=_DEVICE)
    # Load checkpoint if provided
    if checkpoint_path and os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if use_ema and 'generator_ema' in state_dict:
            state_dict_to_load = state_dict['generator_ema']
            # Remove possible FSDP prefix
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict_to_load.items():
                new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
            state_dict_to_load = new_state_dict
        else:
            state_dict_to_load = state_dict.get('generator', state_dict)
        pipeline.generator.load_state_dict(state_dict_to_load, strict=False)
    # The codebase assumes bfloat16 on GPU
    pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
    pipeline.eval()
    # Quick sanity path check for Wan models to give friendly errors
    wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
    if not os.path.isdir(wan_dir):
        raise gr.Error(
            "Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
            "Please download it first, e.g.:\n"
            "huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
        )
    _PIPELINE = pipeline
    return _PIPELINE
def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
    os.makedirs(output_dir, exist_ok=True)
    def predict(prompt: str, num_frames: int) -> str:
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a non-empty text prompt.")
        num_frames = int(num_frames)
        if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
            raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")
        pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)
        # Prepare inputs
        prompts = [prompt.strip()]
        noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)
        torch.set_grad_enabled(False)
        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
            video = pipeline.inference_rolling_forcing(
                noise=noise,
                text_prompts=prompts,
                return_latents=False,
                initial_latent=None,
            )
        # video: [B=1, T, C, H, W] in [0,1]
        video = rearrange(video, 'b t c h w -> b t h w c')[0]
        video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()
        # Save to a unique filepath
        safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
        ts = int(time.time())
        filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
        write_video(filepath, video_uint8, fps=16)
        print(f"Saved generated video to {filepath}")
        return filepath
    return predict
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
                        help='Path to the model config')
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
                        help='Path to rolling forcing checkpoint (.pt). If missing, will run with base weights only if available.')
    parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
    parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
    parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Gradio server host')
    parser.add_argument('--server_port', type=int, default=7860, help='Gradio server port')
    args = parser.parse_args()
    predict = build_predict(
        config_path=args.config_path,
        checkpoint_path=args.checkpoint_path,
        output_dir=args.output_dir,
        use_ema=not args.no_ema,
    )
    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
            gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
        ],
        outputs=gr.Video(label="Generated Video", format="mp4"),
        title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
        description=(
            "Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
            "**Note:** although Rolling Forcing generates videos autoregressively, the current Gradio demo does not support streaming outputs, so the entire video will be generated before it is displayed.\n"
            "\n"
            "If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing)--your support is crucial for sustaining this open-source project. "
            "You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
        ),
        allow_flagging='never',
    )
    try:
        # Gradio <= 3.x
        demo.queue(concurrency_count=1, max_size=2)
    except TypeError:
        # Gradio >= 4.x
        demo.queue(max_size=2)
    demo.launch(server_name=args.server_name, server_port=args.server_port, show_error=True)
if __name__ == "__main__":
    main()
================================================
FILE: long_video/configs/default_config.yaml
================================================
independent_first_frame: false
warp_denoising_step: false
weight_decay: 0.01
same_step_across_blocks: true
discriminator_lr_multiplier: 1.0
last_step_only: false
i2v: false
num_training_frames: 27
gc_interval: 100
context_noise: 0
causal: true
ckpt_step: 0
prompt_name: MovieGenVideoBench
prompt_path: prompts/MovieGenVideoBench.txt
eval_first_n: 64
num_samples: 1
height: 480
width: 832
num_frames: 81
================================================
FILE: long_video/configs/rolling_forcing_dmd.yaml
================================================
generator_ckpt: ../checkpoints/chunkwise/causal_ode.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
sharding_strategy: hybrid_full
lr: 1.5e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
data_path: ../prompts/vidprom_filtered_extended.txt
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 64
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: long_video/inference.py
================================================
import argparse
import torch
import os
from omegaconf import OmegaConf
from collections import OrderedDict
from tqdm import tqdm
from torchvision import transforms
from torchvision.io import write_video
from einops import rearrange
import torch.distributed as dist
import imageio
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pipeline import (
CausalDiffusionInferencePipeline,
CausalInferencePipeline
)
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to the config file")
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint file (.pt)")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21,
                    help="Number of latent frames to generate")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
parser.add_argument("--save_with_index", action="store_true",
help="Whether to save the video using the index or prompt as the filename")
args = parser.parse_args()
# Initialize distributed inference
if "LOCAL_RANK" in os.environ:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()
    set_seed(args.seed + local_rank)
else:
    device = torch.device("cuda")
    local_rank = 0
    world_size = 1
    set_seed(args.seed)
torch.set_grad_enabled(False)
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
# Initialize pipeline
if hasattr(config, 'denoising_step_list'):
    # Few-step inference
    pipeline = CausalInferencePipeline(config, device=device)
else:
    # Multi-step diffusion inference
    pipeline = CausalDiffusionInferencePipeline(config, device=device)
if args.checkpoint_path:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    if args.use_ema:
        state_dict_to_load = state_dict['generator_ema']
        def remove_fsdp_prefix(state_dict):
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                if "_fsdp_wrapped_module." in key:
                    new_key = key.replace("_fsdp_wrapped_module.", "")
                    new_state_dict[new_key] = value
                else:
                    new_state_dict[key] = value
            return new_state_dict
        state_dict_to_load = remove_fsdp_prefix(state_dict_to_load)
    else:
        state_dict_to_load = state_dict['generator']
    pipeline.generator.load_state_dict(state_dict_to_load)
pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
# Create dataset
if args.i2v:
    assert not dist.is_initialized(), "I2V does not support distributed inference yet"
    transform = transforms.Compose([
        transforms.Resize((480, 832)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
    dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")
if dist.is_initialized():
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
    sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
# Create output directory (only on main process to avoid race conditions)
if local_rank == 0:
    os.makedirs(args.output_folder, exist_ok=True)
if dist.is_initialized():
    dist.barrier()
def encode(self, videos: torch.Tensor) -> torch.Tensor:
    device, dtype = videos[0].device, videos[0].dtype
    scale = [self.mean.to(device=device, dtype=dtype),
             1.0 / self.std.to(device=device, dtype=dtype)]
    output = [
        self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
        for u in videos
    ]
    output = torch.stack(output, dim=0)
    return output
for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
    idx = batch_data['idx'].item()
    # For DataLoader batch_size=1, the batch_data is already a single item, but in a batch container
    # Unpack the batch data for convenience
    if isinstance(batch_data, dict):
        batch = batch_data
    elif isinstance(batch_data, list):
        batch = batch_data[0]  # First (and only) item in the batch
    all_video = []
    num_generated_frames = 0  # Number of generated (latent) frames
    if args.i2v:
        # For image-to-video, batch contains image and caption
        prompt = batch['prompts'][0]  # Get caption from batch
        prompts = [prompt] * args.num_samples
        # Process the image
        image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=
SYMBOL INDEX (1104 symbols across 103 files)
FILE: demo.py
function initialize_vae_decoder (line 60) | def initialize_vae_decoder(use_taehv=False, use_trt=False):
function tensor_to_base64_frame (line 162) | def tensor_to_base64_frame(frame_tensor):
function frame_sender_worker (line 190) | def frame_sender_worker():
function generate_video_stream (line 242) | def generate_video_stream(prompt, seed, enable_torch_compile=False, enab...
function generate_mp4_from_images (line 509) | def generate_mp4_from_images(image_directory, output_video_path, fps=24):
function calculate_sha256 (line 533) | def calculate_sha256(data):
function handle_connect (line 543) | def handle_connect():
function handle_disconnect (line 549) | def handle_disconnect():
function handle_start_generation (line 554) | def handle_start_generation(data):
function handle_stop_generation (line 597) | def handle_stop_generation():
function index (line 614) | def index():
function api_status (line 619) | def api_status():
FILE: demo_utils/memory.py
class DynamicSwapInstaller (line 13) | class DynamicSwapInstaller:
method _install_module (line 15) | def _install_module(module: torch.nn.Module, **kwargs):
method _uninstall_module (line 43) | def _uninstall_module(module: torch.nn.Module):
method install_model (line 49) | def install_model(model: torch.nn.Module, **kwargs):
method uninstall_model (line 55) | def uninstall_model(model: torch.nn.Module):
function fake_diffusers_current_device (line 61) | def fake_diffusers_current_device(model: torch.nn.Module, target_device:...
function get_cuda_free_memory_gb (line 72) | def get_cuda_free_memory_gb(device=None):
function move_model_to_device_with_memory_preservation (line 85) | def move_model_to_device_with_memory_preservation(model, target_device, ...
function offload_model_from_device_for_memory_preservation (line 101) | def offload_model_from_device_for_memory_preservation(model, target_devi...
function unload_complete_models (line 117) | def unload_complete_models(*args):
function load_model_as_complete (line 127) | def load_model_as_complete(model, target_device, unload=True):
FILE: demo_utils/taehv.py
function conv (line 16) | def conv(n_in, n_out, **kwargs):
class Clamp (line 20) | class Clamp(nn.Module):
method forward (line 21) | def forward(self, x):
class MemBlock (line 25) | class MemBlock(nn.Module):
method __init__ (line 26) | def __init__(self, n_in, n_out):
method forward (line 33) | def forward(self, x, past):
class TPool (line 37) | class TPool(nn.Module):
method __init__ (line 38) | def __init__(self, n_f, stride):
method forward (line 43) | def forward(self, x):
class TGrow (line 48) | class TGrow(nn.Module):
method __init__ (line 49) | def __init__(self, n_f, stride):
method forward (line 54) | def forward(self, x):
function apply_model_with_memblocks (line 60) | def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
class TAEHV (line 159) | class TAEHV(nn.Module):
method __init__ (line 163) | def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(...
method patch_tgrow_layers (line 195) | def patch_tgrow_layers(self, sd):
method encode_video (line 210) | def encode_video(self, x, parallel=True, show_progress_bar=True):
method decode_video (line 222) | def decode_video(self, x, parallel=True, show_progress_bar=False):
method forward (line 236) | def forward(self, x):
function main (line 241) | def main():
FILE: demo_utils/utils.py
function min_resize (line 19) | def min_resize(x, m):
function d_resize (line 36) | def d_resize(x, y):
function resize_and_center_crop (line 48) | def resize_and_center_crop(image, target_width, target_height):
function resize_and_center_crop_pytorch (line 66) | def resize_and_center_crop_pytorch(image, target_width, target_height):
function resize_without_crop (line 85) | def resize_without_crop(image, target_width, target_height):
function just_crop (line 94) | def just_crop(image, w, h):
function write_to_json (line 108) | def write_to_json(data, file_path):
function read_from_json (line 116) | def read_from_json(file_path):
function get_active_parameters (line 122) | def get_active_parameters(m):
function cast_training_params (line 126) | def cast_training_params(m, dtype=torch.float32):
function separate_lora_AB (line 135) | def separate_lora_AB(parameters, B_patterns=None):
function set_attr_recursive (line 151) | def set_attr_recursive(obj, attr, value):
function print_tensor_list_size (line 159) | def print_tensor_list_size(tensors):
function batch_mixture (line 180) | def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
function zero_module (line 196) | def zero_module(module):
function supress_lower_channels (line 203) | def supress_lower_channels(m, k, alpha=0.01):
function freeze_module (line 213) | def freeze_module(m):
function get_latest_safetensors (line 221) | def get_latest_safetensors(folder_path):
function generate_random_prompt_from_tags (line 232) | def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=...
function interpolate_numbers (line 239) | def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
function uniform_random_by_intervals (line 246) | def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=Fa...
function soft_append_bcthw (line 255) | def soft_append_bcthw(history, current, overlap=0):
function save_bcthw_as_mp4 (line 269) | def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
function save_bcthw_as_png (line 286) | def save_bcthw_as_png(x, output_filename):
function save_bchw_as_png (line 295) | def save_bchw_as_png(x, output_filename):
function add_tensors_with_padding (line 304) | def add_tensors_with_padding(tensor1, tensor2):
function print_free_mem (line 323) | def print_free_mem():
function print_gpu_parameters (line 333) | def print_gpu_parameters(device, state_dict, log_count=1):
function visualize_txt_as_img (line 348) | def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans...
function blue_mark (line 386) | def blue_mark(x):
function green_mark (line 394) | def green_mark(x):
function frame_mark (line 401) | def frame_mark(x):
function pytorch2numpy (line 411) | def pytorch2numpy(imgs):
function numpy2pytorch (line 422) | def numpy2pytorch(imgs):
function duplicate_prefix_to_suffix (line 429) | def duplicate_prefix_to_suffix(x, count, zero_out=False):
function weighted_mse (line 436) | def weighted_mse(a, b, weight):
function clamped_linear_interpolation (line 440) | def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
function expand_to_dims (line 447) | def expand_to_dims(x, target_dims):
function repeat_to_batch_size (line 451) | def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
function dim5 (line 468) | def dim5(x):
function dim4 (line 472) | def dim4(x):
function dim3 (line 476) | def dim3(x):
function crop_or_pad_yield_mask (line 480) | def crop_or_pad_yield_mask(x, length):
function extend_dim (line 495) | def extend_dim(x, dim, minimal_length, zero_pad=False):
function lazy_positional_encoding (line 513) | def lazy_positional_encoding(t, repeats=None):
function state_dict_offset_merge (line 530) | def state_dict_offset_merge(A, B, C=None):
function state_dict_weighted_merge (line 547) | def state_dict_weighted_merge(state_dicts, weights):
function group_files_by_folder (line 574) | def group_files_by_folder(all_files):
function generate_timestamp (line 587) | def generate_timestamp():
function write_PIL_image_with_png_info (line 595) | def write_PIL_image_with_png_info(image, metadata, path):
function torch_safe_save (line 606) | def torch_safe_save(content, path):
function move_optimizer_to_device (line 612) | def move_optimizer_to_device(optimizer, device):
FILE: demo_utils/vae.py
class ResidualBlock (line 13) | class ResidualBlock(nn.Module):
method __init__ (line 15) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 29) | def forward(self, x, feat_cache_1, feat_cache_2):
class Resample (line 51) | class Resample(nn.Module):
method __init__ (line 53) | def __init__(self, dim, mode):
method forward (line 73) | def forward(self, x, is_first_frame, feat_cache):
method temporal_conv (line 105) | def temporal_conv(self, x, is_first_frame, feat_cache):
method init_weight (line 127) | def init_weight(self, conv):
method init_weight2 (line 139) | def init_weight2(self, conv):
class VAEDecoderWrapperSingle (line 151) | class VAEDecoderWrapperSingle(nn.Module):
method __init__ (line 152) | def __init__(self):
method forward (line 168) | def forward(
class VAEDecoder3d (line 199) | class VAEDecoder3d(nn.Module):
method __init__ (line 200) | def __init__(self,
method forward (line 254) | def forward(
class VAETRTWrapper (line 318) | class VAETRTWrapper():
method __init__ (line 319) | def __init__(self):
method quantize_if_needed (line 376) | def quantize_if_needed(self, t, expected_dtype, scale):
method forward (line 381) | def forward(self, *test_inputs):
FILE: demo_utils/vae_block3.py
class Resample (line 9) | class Resample(nn.Module):
method __init__ (line 11) | def __init__(self, dim, mode):
method forward (line 45) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 106) | def init_weight(self, conv):
method init_weight2 (line 118) | def init_weight2(self, conv):
class VAEDecoderWrapper (line 130) | class VAEDecoderWrapper(nn.Module):
method __init__ (line 131) | def __init__(self):
method forward (line 147) | def forward(
class VAEDecoder3d (line 187) | class VAEDecoder3d(nn.Module):
method __init__ (line 188) | def __init__(self,
method forward (line 242) | def forward(
FILE: demo_utils/vae_torch2trt.py
function set_workspace (line 98) | def set_workspace(config, bytes_):
function set_workspace (line 122) | def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
class VAECalibrator (line 139) | class VAECalibrator(trt.IInt8EntropyCalibrator2):
method __init__ (line 140) | def __init__(self, loader, cache="calibration.cache", max_batches=10):
method get_batch_size (line 151) | def get_batch_size(self):
method getBatchSize (line 154) | def getBatchSize(self):
method get_batch (line 157) | def get_batch(self, names):
method read_calibration_cache (line 202) | def read_calibration_cache(self):
method readCalibrationCache (line 209) | def readCalibrationCache(self):
method write_calibration_cache (line 212) | def write_calibration_cache(self, cache):
method writeCalibrationCache (line 216) | def writeCalibrationCache(self, cache):
FILE: get_causal_ode_data_chunkwise.py
function init_model (line 13) | def init_model(device):
function main (line 32) | def main():
FILE: get_causal_ode_data_framewise.py
function init_model (line 13) | def init_model(device):
function main (line 32) | def main():
FILE: get_causal_ode_data_kv_optimized.py
function normalize_generator_state_dict (line 31) | def normalize_generator_state_dict(state_dict: dict) -> dict:
function init_model (line 47) | def init_model(
function prepare_clean_latent (line 73) | def prepare_clean_latent(
function main (line 92) | def main():
FILE: inference.py
function encode (line 125) | def encode(self, videos: torch.Tensor) -> torch.Tensor:
FILE: long_video/app.py
function _ensure_gpu (line 22) | def _ensure_gpu():
function _load_pipeline (line 29) | def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use...
function build_predict (line 78) | def build_predict(config_path: str, checkpoint_path: Optional[str], outp...
function main (line 120) | def main():
FILE: long_video/inference.py
function remove_fsdp_prefix (line 70) | def remove_fsdp_prefix(state_dict):
function encode (line 114) | def encode(self, videos: torch.Tensor) -> torch.Tensor:
FILE: long_video/model/base.py
class BaseModel (line 12) | class BaseModel(nn.Module):
method __init__ (line 13) | def __init__(self, args, device):
method _initialize_models (line 26) | def _initialize_models(self, args, device):
method _get_timestep (line 48) | def _get_timestep(
class RollingForcingModel (line 98) | class RollingForcingModel(BaseModel):
method __init__ (line 99) | def __init__(self, args, device):
method _run_generator (line 103) | def _run_generator(
method _consistency_backward_simulation (line 182) | def _consistency_backward_simulation(
method _initialize_inference_pipeline (line 214) | def _initialize_inference_pipeline(self):
FILE: long_video/model/causvid.py
class CausVid (line 8) | class CausVid(BaseModel):
method __init__ (line 9) | def __init__(self, args, device):
method _compute_kl_grad (line 47) | def _compute_kl_grad(
method compute_distribution_matching_loss (line 121) | def compute_distribution_matching_loss(
method _run_generator (line 184) | def _run_generator(
method generator_loss (line 255) | def generator_loss(
method critic_loss (line 296) | def critic_loss(
FILE: long_video/model/diffusion.py
class CausalDiffusion (line 8) | class CausalDiffusion(BaseModel):
method __init__ (line 9) | def __init__(self, args, device):
method _initialize_models (line 34) | def _initialize_models(self, args):
method generator_loss (line 44) | def generator_loss(
FILE: long_video/model/dmd.py
class DMD (line 9) | class DMD(RollingForcingModel):
method __init__ (line 10) | def __init__(self, args, device):
method _compute_kl_grad (line 54) | def _compute_kl_grad(
method compute_distribution_matching_loss (line 128) | def compute_distribution_matching_loss(
method generator_loss (line 196) | def generator_loss(
method critic_loss (line 237) | def critic_loss(
FILE: long_video/model/gan.py
class GAN (line 10) | class GAN(RollingForcingModel):
method __init__ (line 11) | def __init__(self, args, device):
method _run_cls_pred_branch (line 69) | def _run_cls_pred_branch(self,
method generator_loss (line 90) | def generator_loss(
method critic_loss (line 174) | def critic_loss(
FILE: long_video/model/ode_regression.py
class ODERegression (line 9) | class ODERegression(BaseModel):
method __init__ (line 10) | def __init__(self, args, device):
method _initialize_models (line 45) | def _initialize_models(self, args, device):
method _prepare_generator_input (line 59) | def _prepare_generator_input(self, ode_latent: torch.Tensor, tf=False,...
method generator_loss (line 95) | def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: d...
FILE: long_video/model/sid.py
class SiD (line 8) | class SiD(RollingForcingModel):
method __init__ (line 9) | def __init__(self, args, device):
method compute_distribution_matching_loss (line 47) | def compute_distribution_matching_loss(
method generator_loss (line 147) | def generator_loss(
method critic_loss (line 188) | def critic_loss(
FILE: long_video/pipeline/bidirectional_diffusion_inference.py
class BidirectionalDiffusionInferencePipeline (line 10) | class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 34) | def inference(
method _initialize_sample_scheduler (line 89) | def _initialize_sample_scheduler(self, noise):
FILE: long_video/pipeline/bidirectional_inference.py
class BidirectionalInferencePipeline (line 7) | class BidirectionalInferencePipeline(torch.nn.Module):
method __init__ (line 8) | def __init__(
method inference (line 33) | def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> t...
FILE: long_video/pipeline/causal_diffusion_inference.py
class CausalDiffusionInferencePipeline (line 10) | class CausalDiffusionInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 49) | def inference(
method _initialize_kv_cache (line 270) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 300) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
method _initialize_sample_scheduler (line 321) | def _initialize_sample_scheduler(self, noise):
FILE: long_video/pipeline/rolling_forcing_inference.py
class CausalInferencePipeline (line 7) | class CausalInferencePipeline(torch.nn.Module):
method __init__ (line 8) | def __init__(
method inference_rolling_forcing (line 45) | def inference_rolling_forcing(
method _initialize_kv_cache (line 338) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 360) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
FILE: long_video/pipeline/rolling_forcing_training.py
class RollingForcingTrainingPipeline (line 8) | class RollingForcingTrainingPipeline:
method __init__ (line 9) | def __init__(self,
method generate_and_sync_list (line 41) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
method generate_list (line 60) | def generate_list(self, num_blocks, num_denoising_steps, device):
method inference_with_rolling_forcing (line 75) | def inference_with_rolling_forcing(
method inference_with_self_forcing (line 256) | def inference_with_self_forcing(
method _initialize_kv_cache (line 436) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 452) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
FILE: long_video/train.py
function main (line 8) | def main():
FILE: long_video/trainer/diffusion.py
class Trainer (line 17) | class Trainer:
method __init__ (line 18) | def __init__(self, config):
method save (line 140) | def save(self):
method train_one_step (line 163) | def train_one_step(self, batch):
method generate_video (line 235) | def generate_video(self, pipeline, prompts, image=None):
method train (line 248) | def train(self):
FILE: long_video/trainer/distillation.py
class Trainer (line 20) | class Trainer:
method __init__ (line 21) | def __init__(self, config):
method save (line 176) | def save(self):
method fwdbwd_one_step (line 203) | def fwdbwd_one_step(self, batch, train_generator):
method generate_video (line 276) | def generate_video(self, pipeline, prompts, image=None):
method train (line 306) | def train(self):
FILE: long_video/trainer/gan.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, config):
method save (line 208) | def save(self):
method fwdbwd_one_step (line 235) | def fwdbwd_one_step(self, batch, train_generator):
method generate_video (line 324) | def generate_video(self, pipeline, prompts, image=None):
method train (line 337) | def train(self):
method all_gather_dict (line 457) | def all_gather_dict(self, target_dict):
FILE: long_video/trainer/ode.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, config):
method save (line 118) | def save(self):
method train_one_step (line 134) | def train_one_step(self):
method train (line 225) | def train(self):
FILE: long_video/utils/dataset.py
class TextDataset (line 12) | class TextDataset(Dataset):
method __init__ (line 13) | def __init__(self, prompt_path, extended_prompt_path=None):
method __len__ (line 24) | def __len__(self):
method __getitem__ (line 27) | def __getitem__(self, idx):
class ODERegressionLMDBDataset (line 37) | class ODERegressionLMDBDataset(Dataset):
method __init__ (line 38) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 45) | def __len__(self):
method __getitem__ (line 48) | def __getitem__(self, idx):
class ShardingLMDBDataset (line 72) | class ShardingLMDBDataset(Dataset):
method __init__ (line 73) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 96) | def __len__(self):
method __getitem__ (line 99) | def __getitem__(self, idx):
class TextImagePairDataset (line 127) | class TextImagePairDataset(Dataset):
method __init__ (line 128) | def __init__(
method __len__ (line 182) | def __len__(self):
method __getitem__ (line 185) | def __getitem__(self, idx):
function cycle (line 217) | def cycle(dl):
FILE: long_video/utils/distributed.py
function fsdp_state_dict (line 11) | def fsdp_state_dict(model):
function fsdp_wrap (line 23) | def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, w...
function barrier (line 70) | def barrier():
function launch_distributed_job (line 75) | def launch_distributed_job(backend: str = "nccl"):
class EMA_FSDP (line 91) | class EMA_FSDP:
method __init__ (line 92) | def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
method _init_shadow (line 98) | def _init_shadow(self, fsdp_module):
method update (line 105) | def update(self, fsdp_module):
method state_dict (line 113) | def state_dict(self):
method load_state_dict (line 116) | def load_state_dict(self, sd):
method copy_to (line 119) | def copy_to(self, fsdp_module):
FILE: long_video/utils/lmdb.py
function get_array_shape_from_lmdb (line 4) | def get_array_shape_from_lmdb(env, array_name):
function store_arrays_to_lmdb (line 11) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
function process_data_dict (line 30) | def process_data_dict(data_dict, seen_prompts):
function retrieve_row_from_lmdb (line 56) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...
FILE: long_video/utils/loss.py
class DenoisingLoss (line 5) | class DenoisingLoss(ABC):
method __call__ (line 7) | def __call__(
class X0PredLoss (line 27) | class X0PredLoss(DenoisingLoss):
method __call__ (line 28) | def __call__(
class VPredLoss (line 38) | class VPredLoss(DenoisingLoss):
method __call__ (line 39) | def __call__(
class NoisePredLoss (line 50) | class NoisePredLoss(DenoisingLoss):
method __call__ (line 51) | def __call__(
class FlowPredLoss (line 61) | class FlowPredLoss(DenoisingLoss):
method __call__ (line 62) | def __call__(
function get_denoising_loss (line 80) | def get_denoising_loss(loss_type: str) -> DenoisingLoss:
FILE: long_video/utils/misc.py
function set_seed (line 6) | def set_seed(seed: int, deterministic: bool = False):
function merge_dict_list (line 25) | def merge_dict_list(dict_list):
FILE: long_video/utils/scheduler.py
class SchedulerInterface (line 5) | class SchedulerInterface(ABC):
method add_noise (line 12) | def add_noise(
method convert_x0_to_noise (line 26) | def convert_x0_to_noise(
method convert_noise_to_x0 (line 52) | def convert_noise_to_x0(
method convert_velocity_to_x0 (line 77) | def convert_velocity_to_x0(
class FlowMatchScheduler (line 106) | class FlowMatchScheduler():
method __init__ (line 108) | def __init__(self, num_inference_steps=100, num_train_timesteps=1000, ...
method set_timesteps (line 118) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
method step (line 143) | def step(self, model_output, timestep, sample, to_final=False):
method add_noise (line 159) | def add_noise(self, original_samples, noise, timestep):
method training_target (line 178) | def training_target(self, sample, noise, timestep):
method training_weight (line 182) | def training_weight(self, timestep):
FILE: long_video/utils/wan_wrapper.py
class WanTextEncoder (line 14) | class WanTextEncoder(torch.nn.Module):
method __init__ (line 15) | def __init__(self) -> None:
method device (line 33) | def device(self):
method forward (line 37) | def forward(self, text_prompts: List[str]) -> dict:
class WanVAEWrapper (line 53) | class WanVAEWrapper(torch.nn.Module):
method __init__ (line 54) | def __init__(self):
method encode_to_latent (line 73) | def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
method decode_to_pixel (line 89) | def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = Fals...
class WanDiffusionWrapper (line 115) | class WanDiffusionWrapper(torch.nn.Module):
method __init__ (line 116) | def __init__(
method enable_gradient_checkpointing (line 144) | def enable_gradient_checkpointing(self) -> None:
method adding_cls_branch (line 147) | def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_di...
method _convert_flow_pred_to_x0 (line 169) | def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch....
method _convert_x0_to_flow_pred (line 196) | def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: tor...
method forward (line 218) | def forward(
method get_scheduler (line 293) | def get_scheduler(self) -> SchedulerInterface:
method post_init (line 307) | def post_init(self):
FILE: long_video/wan/distributed/fsdp.py
function shard_model (line 10) | def shard_model(
FILE: long_video/wan/distributed/xdit_context_parallel.py
function pad_freqs (line 12) | def pad_freqs(original_tensor, target_len):
function rope_apply (line 26) | def rope_apply(x, grid_sizes, freqs):
function usp_dit_forward (line 66) | def usp_dit_forward(
function usp_attn_forward (line 149) | def usp_attn_forward(self,
FILE: long_video/wan/image2video.py
class WanI2V (line 29) | class WanI2V:
method __init__ (line 31) | def __init__(
method generate (line 129) | def generate(self,
FILE: long_video/wan/modules/attention.py
function is_hopper_gpu (line 7) | def is_hopper_gpu():
function flash_attention (line 32) | def flash_attention(
function attention (line 139) | def attention(
FILE: long_video/wan/modules/causal_model.py
function causal_rope_apply (line 27) | def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
class CausalWanSelfAttention (line 58) | class CausalWanSelfAttention(nn.Module):
method __init__ (line 60) | def __init__(self,
method forward (line 87) | def forward(
class CausalWanAttentionBlock (line 309) | class CausalWanAttentionBlock(nn.Module):
method __init__ (line 311) | def __init__(self,
method forward (line 349) | def forward(
class CausalHead (line 405) | class CausalHead(nn.Module):
method __init__ (line 407) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 422) | def forward(self, x, e):
class CausalWanModel (line 436) | class CausalWanModel(ModelMixin, ConfigMixin):
method __init__ (line 448) | def __init__(self,
method _set_gradient_checkpointing (line 569) | def _set_gradient_checkpointing(self, module, value=False):
method _prepare_blockwise_causal_attn_mask (line 573) | def _prepare_blockwise_causal_attn_mask(
method _prepare_teacher_forcing_mask (line 631) | def _prepare_teacher_forcing_mask(
method _prepare_blockwise_causal_attn_mask_i2v (line 719) | def _prepare_blockwise_causal_attn_mask_i2v(
method _forward_inference (line 779) | def _forward_inference(
method _forward_train (line 912) | def _forward_train(
method forward (line 1070) | def forward(
method unpatchify (line 1080) | def unpatchify(self, x, grid_sizes):
method init_weights (line 1105) | def init_weights(self):
FILE: long_video/wan/modules/clip.py
function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
class QuickGELU (line 41) | class QuickGELU(nn.Module):
method forward (line 43) | def forward(self, x):
class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
method forward (line 49) | def forward(self, x):
class SelfAttention (line 53) | class SelfAttention(nn.Module):
method __init__ (line 55) | def __init__(self,
method forward (line 74) | def forward(self, x):
class SwiGLU (line 94) | class SwiGLU(nn.Module):
method __init__ (line 96) | def __init__(self, dim, mid_dim):
method forward (line 106) | def forward(self, x):
class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
method __init__ (line 114) | def __init__(self,
method forward (line 146) | def forward(self, x):
class AttentionPool (line 156) | class AttentionPool(nn.Module):
method __init__ (line 158) | def __init__(self,
method forward (line 186) | def forward(self, x):
class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
method __init__ (line 211) | def __init__(self,
method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
method __init__ (line 305) | def __init__(self, **kwargs):
method forward (line 315) | def forward(self, ids):
class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
method __init__ (line 330) | def __init__(self,
method forward (line 406) | def forward(self, imgs, txt_ids):
method param_groups (line 418) | def param_groups(self):
function _clip (line 434) | def _clip(pretrained=False,
function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
class CLIPModel (line 501) | class CLIPModel:
method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
method visual (line 527) | def visual(self, videos):
FILE: long_video/wan/modules/model.py
function sinusoidal_embedding_1d (line 15) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 29) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 40) | def rope_apply(x, grid_sizes, freqs):
class WanRMSNorm (line 70) | class WanRMSNorm(nn.Module):
method __init__ (line 72) | def __init__(self, dim, eps=1e-5):
method forward (line 78) | def forward(self, x):
method _norm (line 85) | def _norm(self, x):
class WanLayerNorm (line 89) | class WanLayerNorm(nn.LayerNorm):
method __init__ (line 91) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 94) | def forward(self, x):
class WanSelfAttention (line 102) | class WanSelfAttention(nn.Module):
method __init__ (line 104) | def __init__(self,
method forward (line 127) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanT2VCrossAttention (line 159) | class WanT2VCrossAttention(WanSelfAttention):
method forward (line 161) | def forward(self, x, context, context_lens, crossattn_cache=None):
class WanGanCrossAttention (line 197) | class WanGanCrossAttention(WanSelfAttention):
method forward (line 199) | def forward(self, x, context, crossattn_cache=None):
class WanI2VCrossAttention (line 224) | class WanI2VCrossAttention(WanSelfAttention):
method __init__ (line 226) | def __init__(self,
method forward (line 240) | def forward(self, x, context, context_lens):
class WanAttentionBlock (line 275) | class WanAttentionBlock(nn.Module):
method __init__ (line 277) | def __init__(self,
method forward (line 315) | def forward(
class GanAttentionBlock (line 357) | class GanAttentionBlock(nn.Module):
method __init__ (line 359) | def __init__(self,
method forward (line 397) | def forward(
class Head (line 439) | class Head(nn.Module):
method __init__ (line 441) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 456) | def forward(self, x, e):
class MLPProj (line 469) | class MLPProj(torch.nn.Module):
method __init__ (line 471) | def __init__(self, in_dim, out_dim):
method forward (line 479) | def forward(self, image_embeds):
class RegisterTokens (line 484) | class RegisterTokens(nn.Module):
method __init__ (line 485) | def __init__(self, num_registers: int, dim: int):
method forward (line 490) | def forward(self):
method reset_parameters (line 493) | def reset_parameters(self):
class WanModel (line 497) | class WanModel(ModelMixin, ConfigMixin):
method __init__ (line 509) | def __init__(self,
method _set_gradient_checkpointing (line 623) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 626) | def forward(
method _forward (line 637) | def _forward(
method _forward_classify (line 773) | def _forward_classify(
method unpatchify (line 876) | def unpatchify(self, x, grid_sizes, c=None):
method init_weights (line 901) | def init_weights(self):
FILE: long_video/wan/modules/t5.py
function fp16_clamp (line 20) | def fp16_clamp(x):
function init_weights (line 27) | def init_weights(m):
class GELU (line 46) | class GELU(nn.Module):
method forward (line 48) | def forward(self, x):
class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
method forward (line 61) | def forward(self, x):
class T5Attention (line 69) | class T5Attention(nn.Module):
method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
method forward (line 136) | def forward(self, x):
class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
method __init__ (line 146) | def __init__(self,
method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
method __init__ (line 180) | def __init__(self,
method forward (line 206) | def forward(self,
class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
method forward (line 233) | def forward(self, lq, lk):
method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
class T5Encoder (line 267) | class T5Encoder(nn.Module):
method __init__ (line 269) | def __init__(self,
method forward (line 303) | def forward(self, ids, mask=None):
class T5Decoder (line 315) | class T5Decoder(nn.Module):
method __init__ (line 317) | def __init__(self,
method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
class T5Model (line 372) | class T5Model(nn.Module):
method __init__ (line 374) | def __init__(self,
method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
function _t5 (line 415) | def _t5(name,
function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
class T5EncoderModel (line 472) | class T5EncoderModel:
method __init__ (line 474) | def __init__(
method __call__ (line 506) | def __call__(self, texts, device):
FILE: long_video/wan/modules/tokenizers.py
function basic_clean (line 12) | def basic_clean(text):
function whitespace_clean (line 18) | def whitespace_clean(text):
function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
method __call__ (line 49) | def __call__(self, sequence, **kwargs):
method _clean (line 75) | def _clean(self, text):
FILE: long_video/wan/modules/vae.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 28) | def forward(self, x, cache_x=None):
class RMS_norm (line 39) | class RMS_norm(nn.Module):
method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 51) | def forward(self, x):
class Upsample (line 57) | class Upsample(nn.Upsample):
method forward (line 59) | def forward(self, x):
class Resample (line 66) | class Resample(nn.Module):
method __init__ (line 68) | def __init__(self, dim, mode):
method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 162) | def init_weight(self, conv):
method init_weight2 (line 174) | def init_weight2(self, conv):
class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
method __init__ (line 228) | def __init__(self, dim):
method forward (line 240) | def forward(self, x):
class Encoder3d (line 265) | class Encoder3d(nn.Module):
method __init__ (line 267) | def __init__(self,
method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 369) | class Decoder3d(nn.Module):
method __init__ (line 371) | def __init__(self,
method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
function count_conv3d (line 475) | def count_conv3d(model):
class WanVAE_ (line 483) | class WanVAE_(nn.Module):
method __init__ (line 485) | def __init__(self,
method forward (line 511) | def forward(self, x):
method encode (line 517) | def encode(self, x, scale):
method decode (line 545) | def decode(self, z, scale):
method cached_decode (line 571) | def cached_decode(self, z, scale):
method sample (line 595) | def sample(self, imgs, deterministic=False):
method clear_cache (line 602) | def clear_cache(self):
function _video_vae (line 612) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
class WanVAE (line 639) | class WanVAE:
method __init__ (line 641) | def __init__(self,
method encode (line 667) | def encode(self, videos):
method decode (line 677) | def decode(self, zs):
FILE: long_video/wan/modules/xlm_roberta.py
class SelfAttention (line 10) | class SelfAttention(nn.Module):
method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
method forward (line 27) | def forward(self, x, mask):
class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
method forward (line 66) | def forward(self, x, mask):
class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
method __init__ (line 81) | def __init__(self,
method forward (line 118) | def forward(self, ids):
function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,
FILE: long_video/wan/text2video.py
class WanT2V (line 26) | class WanT2V:
method __init__ (line 28) | def __init__(
method generate (line 110) | def generate(self,
FILE: long_video/wan/utils/fm_solvers.py
function get_sampling_sigmas (line 22) | def get_sampling_sigmas(sampling_steps, shift):
function retrieve_timesteps (line 29) | def retrieve_timesteps(
class FlowDPMSolverMultistepScheduler (line 69) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 129) | def __init__(
method step_index (line 202) | def step_index(self):
method begin_index (line 209) | def begin_index(self):
method set_begin_index (line 216) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 226) | def set_timesteps(
method _threshold_sample (line 292) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 330) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 333) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 337) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 341) | def convert_model_output(
method dpm_solver_first_order_update (line 415) | def dpm_solver_first_order_update(
method multistep_dpm_solver_second_order_update (line 486) | def multistep_dpm_solver_second_order_update(
method multistep_dpm_solver_third_order_update (line 596) | def multistep_dpm_solver_third_order_update(
method index_for_timestep (line 679) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 693) | def _init_step_index(self, timestep):
method step (line 706) | def step(
method scale_model_input (line 800) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 815) | def add_noise(
method __len__ (line 856) | def __len__(self):
FILE: long_video/wan/utils/fm_solvers_unipc.py
class FlowUniPCMultistepScheduler (line 20) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 77) | def __init__(
method step_index (line 135) | def step_index(self):
method begin_index (line 142) | def begin_index(self):
method set_begin_index (line 149) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 160) | def set_timesteps(
method _threshold_sample (line 230) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 269) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 272) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 276) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 279) | def convert_model_output(
method multistep_uni_p_bh_update (line 350) | def multistep_uni_p_bh_update(
method multistep_uni_c_bh_update (line 486) | def multistep_uni_c_bh_update(
method index_for_timestep (line 628) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 643) | def _init_step_index(self, timestep):
method step (line 655) | def step(self,
method scale_model_input (line 741) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 758) | def add_noise(
method __len__ (line 799) | def __len__(self):
FILE: long_video/wan/utils/prompt_extend.py
class PromptOutput (line 101) | class PromptOutput(object):
method add_custom_field (line 108) | def add_custom_field(self, key: str, value) -> None:
class PromptExpander (line 112) | class PromptExpander:
method __init__ (line 114) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
method extend_with_img (line 119) | def extend_with_img(self,
method extend (line 128) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method decide_system_prompt (line 131) | def decide_system_prompt(self, tar_lang="ch"):
method __call__ (line 138) | def __call__(self,
class DashScopePromptExpander (line 157) | class DashScopePromptExpander(PromptExpander):
method __init__ (line 159) | def __init__(self,
method extend (line 196) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 232) | def extend_with_img(self,
class QwenPromptExpander (line 300) | class QwenPromptExpander(PromptExpander):
method __init__ (line 309) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
method extend (line 366) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 397) | def extend_with_img(self,
FILE: long_video/wan/utils/qwen_vl_utils.py
function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
function smart_resize (line 54) | def smart_resize(height: int,
function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
function smart_nframes (line 133) | def smart_nframes(
function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
function is_decord_available (line 215) | def is_decord_available() -> bool:
function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
function fetch_video (line 274) | def fetch_video(
function extract_vision_info (line 328) | def extract_vision_info(
function process_vision_info (line 344) | def process_vision_info(
FILE: long_video/wan/utils/utils.py
function rand_name (line 14) | def rand_name(length=8, suffix=''):
function cache_video (line 23) | def cache_video(tensor,
function cache_image (line 64) | def cache_image(tensor,
function str2bool (line 94) | def str2bool(v):
FILE: model/base.py
class BaseModel (line 12) | class BaseModel(nn.Module):
method __init__ (line 13) | def __init__(self, args, device):
method _initialize_models (line 40) | def _initialize_models(self, args, device):
method _get_timestep (line 63) | def _get_timestep(
class SelfForcingModel (line 113) | class SelfForcingModel(BaseModel):
method __init__ (line 114) | def __init__(self, args, device):
method _run_generator (line 118) | def _run_generator(
method _consistency_backward_simulation (line 205) | def _consistency_backward_simulation(
method _initialize_inference_pipeline (line 229) | def _initialize_inference_pipeline(self):
class TeacherForcingModel (line 251) | class TeacherForcingModel(BaseModel):
method __init__ (line 252) | def __init__(self, args, device):
method _run_generator (line 256) | def _run_generator(
method _consistency_backward_simulation_tf (line 344) | def _consistency_backward_simulation_tf(
method _initialize_inference_pipeline_tf (line 372) | def _initialize_inference_pipeline_tf(self):
class BidirectionalModel (line 391) | class BidirectionalModel(BaseModel):
method __init__ (line 392) | def __init__(self, args, device):
method _run_generator (line 396) | def _run_generator(
method _consistency_backward_simulation_bidirectional (line 476) | def _consistency_backward_simulation_bidirectional(
method _initialize_inference_pipeline_bidirectional (line 502) | def _initialize_inference_pipeline_bidirectional(self):
FILE: model/causvid.py
class CausVid (line 8) | class CausVid(BaseModel):
method __init__ (line 9) | def __init__(self, args, device):
method _compute_kl_grad (line 47) | def _compute_kl_grad(
method compute_distribution_matching_loss (line 121) | def compute_distribution_matching_loss(
method _run_generator (line 184) | def _run_generator(
method generator_loss (line 256) | def generator_loss(
method critic_loss (line 297) | def critic_loss(
FILE: model/diffusion.py
class CausalDiffusion (line 8) | class CausalDiffusion(BaseModel):
method __init__ (line 9) | def __init__(self, args, device):
method _initialize_models (line 35) | def _initialize_models(self, args, device):
method generator_loss (line 48) | def generator_loss(
FILE: model/dmd.py
class DMD (line 9) | class DMD(SelfForcingModel):
method __init__ (line 10) | def __init__(self, args, device):
method _compute_kl_grad (line 56) | def _compute_kl_grad(
method compute_distribution_matching_loss (line 130) | def compute_distribution_matching_loss(
method generator_loss (line 199) | def generator_loss(
method critic_loss (line 240) | def critic_loss(
method _prepare_generator_input (line 338) | def _prepare_generator_input(self, ode_latent: torch.Tensor, tf=False,...
FILE: model/gan.py
class GAN (line 10) | class GAN(SelfForcingModel):
method __init__ (line 11) | def __init__(self, args, device):
method _run_cls_pred_branch (line 69) | def _run_cls_pred_branch(self,
method generator_loss (line 90) | def generator_loss(
method critic_loss (line 174) | def critic_loss(
FILE: model/naive_consistency.py
class NaiveConsistency (line 9) | class NaiveConsistency(BaseModel):
method __init__ (line 10) | def __init__(self, args, device):
method _initialize_models (line 66) | def _initialize_models(self, args, device):
method generator_loss (line 85) | def generator_loss(
FILE: model/ode_regression.py
class ODERegression (line 9) | class ODERegression(BaseModel):
method __init__ (line 10) | def __init__(self, args, device):
method _initialize_models (line 45) | def _initialize_models(self, args, device):
method _prepare_generator_input (line 59) | def _prepare_generator_input(self, ode_latent: torch.Tensor, tf=False,...
method generator_loss (line 95) | def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: d...
FILE: model/sid.py
class SiD (line 8) | class SiD(SelfForcingModel):
method __init__ (line 9) | def __init__(self, args, device):
method compute_distribution_matching_loss (line 47) | def compute_distribution_matching_loss(
method generator_loss (line 147) | def generator_loss(
method critic_loss (line 188) | def critic_loss(
FILE: pipeline/bidirectional_diffusion_inference.py
class BidirectionalDiffusionInferencePipeline (line 10) | class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 34) | def inference(
method _initialize_sample_scheduler (line 89) | def _initialize_sample_scheduler(self, noise):
FILE: pipeline/bidirectional_inference.py
class BidirectionalInferencePipeline (line 7) | class BidirectionalInferencePipeline(torch.nn.Module):
method __init__ (line 8) | def __init__(
method inference (line 33) | def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> t...
FILE: pipeline/bidirectional_training.py
class BidirectionalTrainingPipeline (line 8) | class BidirectionalTrainingPipeline:
method __init__ (line 9) | def __init__(self,
method generate_and_sync_list (line 44) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
method inference_with_trajectory (line 64) | def inference_with_trajectory(
FILE: pipeline/causal_diffusion_inference.py
class CausalDiffusionInferencePipeline (line 10) | class CausalDiffusionInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 51) | def inference(
method inference_for_cd (line 277) | def inference_for_cd(
method inference_for_genuine_cd (line 474) | def inference_for_genuine_cd(
method _initialize_kv_cache (line 605) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 635) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
method _initialize_sample_scheduler (line 656) | def _initialize_sample_scheduler(self, noise, sampling_steps=-1):
FILE: pipeline/causal_inference.py
class CausalInferencePipeline (line 10) | class CausalInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 65) | def inference(
method _initialize_kv_cache (line 337) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 359) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
FILE: pipeline/self_forcing_training.py
class SelfForcingTrainingPipeline (line 8) | class SelfForcingTrainingPipeline:
method __init__ (line 9) | def __init__(self,
method generate_and_sync_list (line 48) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
method inference_with_trajectory (line 68) | def inference_with_trajectory(
method _initialize_kv_cache (line 288) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 304) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
FILE: pipeline/teacher_forcing_training.py
class TeacherForcingTrainingPipeline (line 8) | class TeacherForcingTrainingPipeline:
method __init__ (line 9) | def __init__(self,
method generate_and_sync_list (line 44) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
method inference_with_trajectory (line 64) | def inference_with_trajectory(
FILE: train.py
function main (line 9) | def main():
FILE: trainer/diffusion.py
class Trainer (line 20) | class Trainer:
method __init__ (line 21) | def __init__(self, config):
method save (line 168) | def save(self):
method train_one_step (line 191) | def train_one_step(self, batch):
method train (line 261) | def train(self):
FILE: trainer/distillation.py
class Trainer (line 16) | class Trainer:
method __init__ (line 17) | def __init__(self, config):
method save (line 188) | def save(self):
method save_critic (line 212) | def save_critic(self):
method fwdbwd_one_step (line 229) | def fwdbwd_one_step(self, batch, train_generator, clean_latent=None):
method train (line 303) | def train(self):
FILE: trainer/gan.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, config):
method save (line 208) | def save(self):
method fwdbwd_one_step (line 235) | def fwdbwd_one_step(self, batch, train_generator):
method generate_video (line 324) | def generate_video(self, pipeline, prompts, image=None):
method train (line 337) | def train(self):
method all_gather_dict (line 457) | def all_gather_dict(self, target_dict):
FILE: trainer/naive_cd.py
class Trainer (line 20) | class Trainer:
method __init__ (line 21) | def __init__(self, config):
method save (line 183) | def save(self):
method fwdbwd_one_step (line 206) | def fwdbwd_one_step(self, batch, clean_latent=None):
method train (line 251) | def train(self):
FILE: trainer/ode.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, config):
method save (line 132) | def save(self):
method train_one_step (line 148) | def train_one_step(self, loss_scale=1.0):
method train (line 217) | def train(self):
FILE: utils/create_lmdb_iterative.py
function store_arrays_to_lmdb (line 10) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
function get_array_shape_from_lmdb (line 27) | def get_array_shape_from_lmdb(env, array_name):
function process_data_dict (line 34) | def process_data_dict(data_dict, seen_prompts):
function retrieve_row_from_lmdb (line 59) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...
function main (line 78) | def main():
FILE: utils/dataset.py
class TextDataset (line 12) | class TextDataset(Dataset):
method __init__ (line 13) | def __init__(self, prompt_path, extended_prompt_path=None):
method __len__ (line 24) | def __len__(self):
method __getitem__ (line 27) | def __getitem__(self, idx):
class ODERegressionLMDBDataset (line 37) | class ODERegressionLMDBDataset(Dataset):
method __init__ (line 38) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 45) | def __len__(self):
method __getitem__ (line 48) | def __getitem__(self, idx):
class LatentLMDBDataset (line 75) | class LatentLMDBDataset(Dataset):
method __init__ (line 76) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 83) | def __len__(self):
method __getitem__ (line 86) | def __getitem__(self, idx):
class ShardingLMDBDataset (line 110) | class ShardingLMDBDataset(Dataset):
method __init__ (line 111) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 134) | def __len__(self):
method __getitem__ (line 137) | def __getitem__(self, idx):
class TextImagePairDataset (line 166) | class TextImagePairDataset(Dataset):
method __init__ (line 167) | def __init__(
method __len__ (line 221) | def __len__(self):
method __getitem__ (line 224) | def __getitem__(self, idx):
function cycle (line 257) | def cycle(dl):
FILE: utils/distributed.py
function fsdp_state_dict (line 11) | def fsdp_state_dict(model):
function fsdp_wrap (line 23) | def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, w...
function barrier (line 70) | def barrier():
function launch_distributed_job (line 75) | def launch_distributed_job(backend: str = "nccl"):
class EMA_FSDP (line 91) | class EMA_FSDP:
method __init__ (line 92) | def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
method _init_shadow (line 98) | def _init_shadow(self, fsdp_module):
method update (line 103) | def update(self, fsdp_module):
method state_dict (line 109) | def state_dict(self):
method load_state_dict (line 112) | def load_state_dict(self, sd):
method copy_to (line 115) | def copy_to(self, fsdp_module):
method full_state_dict (line 121) | def full_state_dict(self, fsdp_module):
FILE: utils/lmdb_.py
function get_array_shape_from_lmdb (line 4) | def get_array_shape_from_lmdb(env, array_name):
function store_arrays_to_lmdb (line 11) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
function process_data_dict (line 30) | def process_data_dict(data_dict, seen_prompts):
function retrieve_row_from_lmdb (line 56) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...
FILE: utils/loss.py
class DenoisingLoss (line 5) | class DenoisingLoss(ABC):
method __call__ (line 7) | def __call__(
class X0PredLoss (line 27) | class X0PredLoss(DenoisingLoss):
method __call__ (line 28) | def __call__(
class VPredLoss (line 38) | class VPredLoss(DenoisingLoss):
method __call__ (line 39) | def __call__(
class NoisePredLoss (line 50) | class NoisePredLoss(DenoisingLoss):
method __call__ (line 51) | def __call__(
class FlowPredLoss (line 61) | class FlowPredLoss(DenoisingLoss):
method __call__ (line 62) | def __call__(
function get_denoising_loss (line 80) | def get_denoising_loss(loss_type: str) -> DenoisingLoss:
FILE: utils/merge_and_get_clean.py
function read_shape (line 8) | def read_shape(env, name):
function list_array_names (line 14) | def list_array_names(env):
function ensure_empty_dir (line 24) | def ensure_empty_dir(path):
function safe_mapsize (line 28) | def safe_mapsize(env):
function get_bytes (line 32) | def get_bytes(txn, key):
function latents_bytes_to_out (line 37) | def latents_bytes_to_out(row_bytes, in_row_shape, out_row_shape):
function merge_many (line 50) | def merge_many(src_dirs_all, dst_dir):
function rm_dirs (line 140) | def rm_dirs(dirs, desc="remove dirs"):
FILE: utils/merge_lmdb.py
function read_shape (line 8) | def read_shape(env, name):
function list_array_names (line 15) | def list_array_names(env):
function ensure_empty_dir (line 26) | def ensure_empty_dir(path):
function safe_mapsize (line 31) | def safe_mapsize(env):
function get_bytes (line 35) | def get_bytes(txn, key):
function latents_bytes_to_out (line 41) | def latents_bytes_to_out(row_bytes, in_row_shape, out_row_shape):
function merge_many (line 47) | def merge_many(src_dirs_all, dst_dir):
function rm_dirs (line 142) | def rm_dirs(dirs, desc="remove dirs"):
FILE: utils/misc.py
function set_seed (line 6) | def set_seed(seed: int, deterministic: bool = False):
function merge_dict_list (line 25) | def merge_dict_list(dict_list):
FILE: utils/ode_generation.py
function merge_cfg_prompt_embeds (line 6) | def merge_cfg_prompt_embeds(
function normalize_trajectory_indices (line 21) | def normalize_trajectory_indices(
class CausalODETrajectoryGenerator (line 37) | class CausalODETrajectoryGenerator:
method __init__ (line 38) | def __init__(
method _make_kv_cache (line 55) | def _make_kv_cache(self, batch_size: int, device: torch.device) -> lis...
method _make_crossattn_cache (line 81) | def _make_crossattn_cache(self, batch_size: int, device: torch.device)...
method _batched_cfg_step (line 101) | def _batched_cfg_step(
method _update_clean_cache (line 132) | def _update_clean_cache(
method _generate_full (line 156) | def _generate_full(
method _generate_blockwise_kv (line 196) | def _generate_blockwise_kv(
method _assemble_selected_trajectory (line 267) | def _assemble_selected_trajectory(
method generate (line 290) | def generate(
FILE: utils/scheduler.py
class SchedulerInterface (line 5) | class SchedulerInterface(ABC):
method add_noise (line 12) | def add_noise(
method convert_x0_to_noise (line 26) | def convert_x0_to_noise(
method convert_noise_to_x0 (line 52) | def convert_noise_to_x0(
method convert_velocity_to_x0 (line 77) | def convert_velocity_to_x0(
class FlowMatchScheduler (line 106) | class FlowMatchScheduler():
method __init__ (line 108) | def __init__(self, num_inference_steps=100, num_train_timesteps=1000, ...
method set_timesteps (line 118) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
method step (line 143) | def step(self, model_output, timestep, sample, to_final=False):
method add_noise (line 159) | def add_noise(self, original_samples, noise, timestep):
method training_target (line 178) | def training_target(self, sample, noise, timestep):
method training_weight (line 182) | def training_weight(self, timestep):
FILE: utils/wan_wrapper.py
class WanTextEncoder (line 14) | class WanTextEncoder(torch.nn.Module):
method __init__ (line 15) | def __init__(self) -> None:
method device (line 33) | def device(self):
method forward (line 37) | def forward(self, text_prompts: List[str]) -> dict:
class WanVAEWrapper (line 53) | class WanVAEWrapper(torch.nn.Module):
method __init__ (line 54) | def __init__(self):
method encode_to_latent (line 73) | def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
method decode_to_pixel (line 89) | def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = Fals...
class WanDiffusionWrapper (line 115) | class WanDiffusionWrapper(torch.nn.Module):
method __init__ (line 116) | def __init__(
method enable_gradient_checkpointing (line 144) | def enable_gradient_checkpointing(self) -> None:
method adding_cls_branch (line 147) | def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_di...
method _convert_flow_pred_to_x0 (line 169) | def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch....
method _convert_x0_to_flow_pred (line 196) | def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: tor...
method forward (line 218) | def forward(
method get_scheduler (line 294) | def get_scheduler(self) -> SchedulerInterface:
method post_init (line 308) | def post_init(self):
FILE: wan/distributed/fsdp.py
function shard_model (line 10) | def shard_model(
FILE: wan/distributed/xdit_context_parallel.py
function pad_freqs (line 12) | def pad_freqs(original_tensor, target_len):
function rope_apply (line 26) | def rope_apply(x, grid_sizes, freqs):
function usp_dit_forward (line 66) | def usp_dit_forward(
function usp_attn_forward (line 149) | def usp_attn_forward(self,
FILE: wan/image2video.py
class WanI2V (line 29) | class WanI2V:
method __init__ (line 31) | def __init__(
method generate (line 129) | def generate(self,
FILE: wan/modules/attention.py
function is_hopper_gpu (line 7) | def is_hopper_gpu():
function flash_attention (line 32) | def flash_attention(
function attention (line 139) | def attention(
FILE: wan/modules/causal_model.py
function causal_rope_apply (line 31) | def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
class CausalWanSelfAttention (line 62) | class CausalWanSelfAttention(nn.Module):
method __init__ (line 64) | def __init__(self,
method forward (line 90) | def forward(
class CausalWanAttentionBlock (line 248) | class CausalWanAttentionBlock(nn.Module):
method __init__ (line 250) | def __init__(self,
method forward (line 288) | def forward(
class CausalHead (line 343) | class CausalHead(nn.Module):
method __init__ (line 345) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 360) | def forward(self, x, e):
class CausalWanModel (line 374) | class CausalWanModel(ModelMixin, ConfigMixin):
method __init__ (line 386) | def __init__(self,
method _set_gradient_checkpointing (line 507) | def _set_gradient_checkpointing(self, module, value=False):
method _prepare_blockwise_causal_attn_mask (line 511) | def _prepare_blockwise_causal_attn_mask(
method _prepare_teacher_forcing_mask (line 569) | def _prepare_teacher_forcing_mask(
method _prepare_blockwise_causal_attn_mask_i2v (line 657) | def _prepare_blockwise_causal_attn_mask_i2v(
method _forward_inference (line 717) | def _forward_inference(
method _forward_train (line 848) | def _forward_train(
method forward (line 1007) | def forward(
method unpatchify (line 1018) | def unpatchify(self, x, grid_sizes):
method init_weights (line 1043) | def init_weights(self):
FILE: wan/modules/clip.py
function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
class QuickGELU (line 41) | class QuickGELU(nn.Module):
method forward (line 43) | def forward(self, x):
class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
method forward (line 49) | def forward(self, x):
class SelfAttention (line 53) | class SelfAttention(nn.Module):
method __init__ (line 55) | def __init__(self,
method forward (line 74) | def forward(self, x):
class SwiGLU (line 94) | class SwiGLU(nn.Module):
method __init__ (line 96) | def __init__(self, dim, mid_dim):
method forward (line 106) | def forward(self, x):
class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
method __init__ (line 114) | def __init__(self,
method forward (line 146) | def forward(self, x):
class AttentionPool (line 156) | class AttentionPool(nn.Module):
method __init__ (line 158) | def __init__(self,
method forward (line 186) | def forward(self, x):
class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
method __init__ (line 211) | def __init__(self,
method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
method __init__ (line 305) | def __init__(self, **kwargs):
method forward (line 315) | def forward(self, ids):
class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
method __init__ (line 330) | def __init__(self,
method forward (line 406) | def forward(self, imgs, txt_ids):
method param_groups (line 418) | def param_groups(self):
function _clip (line 434) | def _clip(pretrained=False,
function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
class CLIPModel (line 501) | class CLIPModel:
method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
method visual (line 527) | def visual(self, videos):
FILE: wan/modules/model.py
function sinusoidal_embedding_1d (line 15) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 29) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 40) | def rope_apply(x, grid_sizes, freqs):
class WanRMSNorm (line 70) | class WanRMSNorm(nn.Module):
method __init__ (line 72) | def __init__(self, dim, eps=1e-5):
method forward (line 78) | def forward(self, x):
method _norm (line 85) | def _norm(self, x):
class WanLayerNorm (line 89) | class WanLayerNorm(nn.LayerNorm):
method __init__ (line 91) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 94) | def forward(self, x):
class WanSelfAttention (line 102) | class WanSelfAttention(nn.Module):
method __init__ (line 104) | def __init__(self,
method forward (line 127) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanT2VCrossAttention (line 159) | class WanT2VCrossAttention(WanSelfAttention):
method forward (line 161) | def forward(self, x, context, context_lens, crossattn_cache=None):
class WanGanCrossAttention (line 197) | class WanGanCrossAttention(WanSelfAttention):
method forward (line 199) | def forward(self, x, context, crossattn_cache=None):
class WanI2VCrossAttention (line 224) | class WanI2VCrossAttention(WanSelfAttention):
method __init__ (line 226) | def __init__(self,
method forward (line 240) | def forward(self, x, context, context_lens):
class WanAttentionBlock (line 275) | class WanAttentionBlock(nn.Module):
method __init__ (line 277) | def __init__(self,
method forward (line 315) | def forward(
class GanAttentionBlock (line 357) | class GanAttentionBlock(nn.Module):
method __init__ (line 359) | def __init__(self,
method forward (line 397) | def forward(
class Head (line 439) | class Head(nn.Module):
method __init__ (line 441) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 456) | def forward(self, x, e):
class MLPProj (line 469) | class MLPProj(torch.nn.Module):
method __init__ (line 471) | def __init__(self, in_dim, out_dim):
method forward (line 479) | def forward(self, image_embeds):
class RegisterTokens (line 484) | class RegisterTokens(nn.Module):
method __init__ (line 485) | def __init__(self, num_registers: int, dim: int):
method forward (line 490) | def forward(self):
method reset_parameters (line 493) | def reset_parameters(self):
class WanModel (line 497) | class WanModel(ModelMixin, ConfigMixin):
method __init__ (line 509) | def __init__(self,
method _set_gradient_checkpointing (line 623) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 626) | def forward(
method _forward (line 637) | def _forward(
method _forward_classify (line 773) | def _forward_classify(
method unpatchify (line 876) | def unpatchify(self, x, grid_sizes, c=None):
method init_weights (line 901) | def init_weights(self):
FILE: wan/modules/t5.py
function fp16_clamp (line 20) | def fp16_clamp(x):
function init_weights (line 27) | def init_weights(m):
class GELU (line 46) | class GELU(nn.Module):
method forward (line 48) | def forward(self, x):
class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
method forward (line 61) | def forward(self, x):
class T5Attention (line 69) | class T5Attention(nn.Module):
method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
method forward (line 136) | def forward(self, x):
class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
method __init__ (line 146) | def __init__(self,
method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
method __init__ (line 180) | def __init__(self,
method forward (line 206) | def forward(self,
class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
method forward (line 233) | def forward(self, lq, lk):
method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
class T5Encoder (line 267) | class T5Encoder(nn.Module):
method __init__ (line 269) | def __init__(self,
method forward (line 303) | def forward(self, ids, mask=None):
class T5Decoder (line 315) | class T5Decoder(nn.Module):
method __init__ (line 317) | def __init__(self,
method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
class T5Model (line 372) | class T5Model(nn.Module):
method __init__ (line 374) | def __init__(self,
method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
function _t5 (line 415) | def _t5(name,
function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
class T5EncoderModel (line 472) | class T5EncoderModel:
method __init__ (line 474) | def __init__(
method __call__ (line 506) | def __call__(self, texts, device):
FILE: wan/modules/tokenizers.py
function basic_clean (line 12) | def basic_clean(text):
function whitespace_clean (line 18) | def whitespace_clean(text):
function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
method __call__ (line 49) | def __call__(self, sequence, **kwargs):
method _clean (line 75) | def _clean(self, text):
FILE: wan/modules/vae.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 28) | def forward(self, x, cache_x=None):
class RMS_norm (line 39) | class RMS_norm(nn.Module):
method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 51) | def forward(self, x):
class Upsample (line 57) | class Upsample(nn.Upsample):
method forward (line 59) | def forward(self, x):
class Resample (line 66) | class Resample(nn.Module):
method __init__ (line 68) | def __init__(self, dim, mode):
method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 162) | def init_weight(self, conv):
method init_weight2 (line 174) | def init_weight2(self, conv):
class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
method __init__ (line 228) | def __init__(self, dim):
method forward (line 240) | def forward(self, x):
class Encoder3d (line 265) | class Encoder3d(nn.Module):
method __init__ (line 267) | def __init__(self,
method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 369) | class Decoder3d(nn.Module):
method __init__ (line 371) | def __init__(self,
method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
function count_conv3d (line 475) | def count_conv3d(model):
class WanVAE_ (line 483) | class WanVAE_(nn.Module):
method __init__ (line 485) | def __init__(self,
method forward (line 511) | def forward(self, x):
method encode (line 517) | def encode(self, x, scale):
method decode (line 545) | def decode(self, z, scale):
method cached_decode (line 571) | def cached_decode(self, z, scale):
method sample (line 595) | def sample(self, imgs, deterministic=False):
method clear_cache (line 602) | def clear_cache(self):
function _video_vae (line 612) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
class WanVAE (line 639) | class WanVAE:
method __init__ (line 641) | def __init__(self,
method encode (line 667) | def encode(self, videos):
method decode (line 677) | def decode(self, zs):
FILE: wan/modules/xlm_roberta.py
class SelfAttention (line 10) | class SelfAttention(nn.Module):
method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
method forward (line 27) | def forward(self, x, mask):
class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
method forward (line 66) | def forward(self, x, mask):
class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
method __init__ (line 81) | def __init__(self,
method forward (line 118) | def forward(self, ids):
function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,
FILE: wan/text2video.py
class WanT2V (line 26) | class WanT2V:
method __init__ (line 28) | def __init__(
method generate (line 110) | def generate(self,
FILE: wan/utils/fm_solvers.py
function get_sampling_sigmas (line 22) | def get_sampling_sigmas(sampling_steps, shift):
function retrieve_timesteps (line 29) | def retrieve_timesteps(
class FlowDPMSolverMultistepScheduler (line 69) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 129) | def __init__(
method step_index (line 202) | def step_index(self):
method begin_index (line 209) | def begin_index(self):
method set_begin_index (line 216) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 226) | def set_timesteps(
method _threshold_sample (line 292) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 330) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 333) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 337) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 341) | def convert_model_output(
method dpm_solver_first_order_update (line 415) | def dpm_solver_first_order_update(
method multistep_dpm_solver_second_order_update (line 486) | def multistep_dpm_solver_second_order_update(
method multistep_dpm_solver_third_order_update (line 596) | def multistep_dpm_solver_third_order_update(
method index_for_timestep (line 679) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 693) | def _init_step_index(self, timestep):
method step (line 706) | def step(
method scale_model_input (line 800) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 815) | def add_noise(
method __len__ (line 856) | def __len__(self):
FILE: wan/utils/fm_solvers_unipc.py
class FlowUniPCMultistepScheduler (line 20) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 77) | def __init__(
method step_index (line 135) | def step_index(self):
method begin_index (line 142) | def begin_index(self):
method set_begin_index (line 149) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 160) | def set_timesteps(
method _threshold_sample (line 230) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 269) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 272) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 276) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 279) | def convert_model_output(
method multistep_uni_p_bh_update (line 350) | def multistep_uni_p_bh_update(
method multistep_uni_c_bh_update (line 486) | def multistep_uni_c_bh_update(
method index_for_timestep (line 628) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 643) | def _init_step_index(self, timestep):
method step (line 655) | def step(self,
method scale_model_input (line 741) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 758) | def add_noise(
method __len__ (line 799) | def __len__(self):
FILE: wan/utils/prompt_extend.py
class PromptOutput (line 101) | class PromptOutput(object):
method add_custom_field (line 108) | def add_custom_field(self, key: str, value) -> None:
class PromptExpander (line 112) | class PromptExpander:
method __init__ (line 114) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
method extend_with_img (line 119) | def extend_with_img(self,
method extend (line 128) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method decide_system_prompt (line 131) | def decide_system_prompt(self, tar_lang="ch"):
method __call__ (line 138) | def __call__(self,
class DashScopePromptExpander (line 157) | class DashScopePromptExpander(PromptExpander):
method __init__ (line 159) | def __init__(self,
method extend (line 196) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 232) | def extend_with_img(self,
class QwenPromptExpander (line 300) | class QwenPromptExpander(PromptExpander):
method __init__ (line 309) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
method extend (line 366) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 397) | def extend_with_img(self,
FILE: wan/utils/qwen_vl_utils.py
function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
function smart_resize (line 54) | def smart_resize(height: int,
function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
function smart_nframes (line 133) | def smart_nframes(
function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
function is_decord_available (line 215) | def is_decord_available() -> bool:
function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
function fetch_video (line 274) | def fetch_video(
function extract_vision_info (line 328) | def extract_vision_info(
function process_vision_info (line 344) | def process_vision_info(
FILE: wan/utils/utils.py
function rand_name (line 14) | def rand_name(length=8, suffix=''):
function cache_video (line 23) | def cache_video(tensor,
function cache_image (line 64) | def cache_image(tensor,
function str2bool (line 94) | def str2bool(v):
Condensed preview of 154 files, each showing path, character count, and a content snippet. The full structured content is 1,501K chars.
[
{
"path": ".gitignore",
"chars": 104,
"preview": "__pycache__\n*.egg-info\n\nwan_models\ncheckpoints\noutput\ndataset\nprompts/vidprom_filtered_extended.txt\nlogs"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 21236,
"preview": "<div align=\"center\">\n\n## Causal Forcing & Causal Forcing++\n### Autoregressive Diffusion Distillation Done Right for High"
},
{
"path": "configs/ar_diffusion_tf_chunkwise.yaml",
"chars": 1113,
"preview": "generator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: size\nfake_score_fsdp_wrap_strategy: size\nreal_name: Wa"
},
{
"path": "configs/ar_diffusion_tf_framewise.yaml",
"chars": 1113,
"preview": "generator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: size\nfake_score_fsdp_wrap_strategy: size\nreal_name: Wa"
},
{
"path": "configs/causal_cd_chunkwise.yaml",
"chars": 1364,
"preview": "generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: "
},
{
"path": "configs/causal_cd_framewise.yaml",
"chars": 1364,
"preview": "generator_ckpt: checkpoints/framewise/ar_diffusion.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: "
},
{
"path": "configs/causal_forcing_dmd_chunkwise.yaml",
"chars": 1315,
"preview": "generator_ckpt: checkpoints/chunkwise/causal_ode.pt # 🔥 or checkpoints/chunkwise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
},
{
"path": "configs/causal_forcing_dmd_framewise.yaml",
"chars": 1314,
"preview": "generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
},
{
"path": "configs/causal_forcing_dmd_framewise_1step.yaml",
"chars": 1522,
"preview": "generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
},
{
"path": "configs/causal_forcing_dmd_framewise_2step.yaml",
"chars": 1528,
"preview": "generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
},
{
"path": "configs/causal_ode_chunkwise.yaml",
"chars": 655,
"preview": "generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt\ngenerator_grad:\n model: true\ndenoising_step_list:\n- 1000\n- 750\n- "
},
{
"path": "configs/causal_ode_framewise.yaml",
"chars": 657,
"preview": "generator_ckpt: checkpoints/framewise/ar_diffusion.pt\ngenerator_grad:\n model: true\ndenoising_step_list:\n- 1000\n- 750\n- "
},
{
"path": "configs/default_config.yaml",
"chars": 403,
"preview": "independent_first_frame: false\nwarp_denoising_step: false\nweight_decay: 0.01\nsame_step_across_blocks: true\ndiscriminator"
},
{
"path": "demo.py",
"chars": 24414,
"preview": "\"\"\"\nDemo for Causal-Forcing.\n\"\"\"\n\nimport os\nimport re\nimport random\nimport time\nimport base64\nimport argparse\nimport has"
},
{
"path": "demo_utils/constant.py",
"chars": 1352,
"preview": "\nimport torch\n\n\nZERO_VAE_CACHE = [\n torch.zeros(1, 16, 2, 60, 104),\n torch.zeros(1, 384, 2, 60, 104),\n torch.ze"
},
{
"path": "demo_utils/memory.py",
"chars": 4417,
"preview": "# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils\n# Apache-2.0 License\n# By lllyasviel\n\nimport "
},
{
"path": "demo_utils/taehv.py",
"chars": 14157,
"preview": "#!/usr/bin/env python3\n\"\"\"\nTiny AutoEncoder for Hunyuan Video\n(DNN for encoding / decoding videos to Hunyuan Video's lat"
},
{
"path": "demo_utils/utils.py",
"chars": 17547,
"preview": "# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils\n# Apache-2.0 License\n# By lllyasviel\n\nimport "
},
{
"path": "demo_utils/vae.py",
"chars": 15414,
"preview": "from typing import List\nfrom einops import rearrange\nimport tensorrt as trt\nimport torch\nimport torch.nn as nn\n\nfrom dem"
},
{
"path": "demo_utils/vae_block3.py",
"chars": 11058,
"preview": "from typing import List\nfrom einops import rearrange\nimport torch\nimport torch.nn as nn\n\nfrom wan.modules.vae import Att"
},
{
"path": "demo_utils/vae_torch2trt.py",
"chars": 11889,
"preview": "# ---- INT8 (optional) ----\nfrom demo_utils.vae import (\n VAEDecoderWrapperSingle, # main nn."
},
{
"path": "get_causal_ode_data_chunkwise.py",
"chars": 4371,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\nfrom utils.scheduler import FlowMatchSc"
},
{
"path": "get_causal_ode_data_framewise.py",
"chars": 4371,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\nfrom utils.scheduler import FlowMatchSc"
},
{
"path": "get_causal_ode_data_kv_optimized.py",
"chars": 5832,
"preview": "import argparse\nimport math\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom tqdm import tqdm\n\nfrom utils.d"
},
{
"path": "inference.py",
"chars": 8317,
"preview": "import argparse\nimport argparse\nimport torch\nimport os\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom torchv"
},
{
"path": "long_video/LICENSE",
"chars": 12242,
"preview": "Tencent is pleased to support the community by making RollingForcing available.\n\nCopyright (C) 2025 Tencent. All right"
},
{
"path": "long_video/README.md",
"chars": 1279,
"preview": "Builing on [Rolling Forcing](https://github.com/TencentARC/RollingForcing), we implemented minute-level long video gener"
},
{
"path": "long_video/app.py",
"chars": 6831,
"preview": "import os\nimport argparse\nimport time\nfrom typing import Optional\n\nimport torch\nfrom torchvision.io import write_video\nf"
},
{
"path": "long_video/configs/default_config.yaml",
"chars": 403,
"preview": "independent_first_frame: false\nwarp_denoising_step: false\nweight_decay: 0.01\nsame_step_across_blocks: true\ndiscriminator"
},
{
"path": "long_video/configs/rolling_forcing_dmd.yaml",
"chars": 1168,
"preview": "generator_ckpt: ../checkpoints/chunkwise/causal_ode.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy:"
},
{
"path": "long_video/inference.py",
"chars": 7755,
"preview": "import argparse\nimport torch\nimport os\nfrom omegaconf import OmegaConf\nfrom collections import OrderedDict\nfrom tqdm imp"
},
{
"path": "long_video/model/__init__.py",
"chars": 278,
"preview": "from .diffusion import CausalDiffusion\nfrom .causvid import CausVid\nfrom .dmd import DMD\nfrom .gan import GAN\nfrom .sid "
},
{
"path": "long_video/model/base.py",
"chars": 11031,
"preview": "from typing import Tuple\nfrom einops import rearrange\nfrom torch import nn\nimport torch.distributed as dist\nimport torch"
},
{
"path": "long_video/model/causvid.py",
"chars": 17252,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\n\n\nclass CausVid("
},
{
"path": "long_video/model/diffusion.py",
"chars": 5641,
"preview": "from typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_wrapper import WanDiffusionWrappe"
},
{
"path": "long_video/model/dmd.py",
"chars": 15496,
"preview": "from pipeline import RollingForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\ni"
},
{
"path": "long_video/model/gan.py",
"chars": 14244,
"preview": "import copy\nfrom pipeline import RollingForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Tuple"
},
{
"path": "long_video/model/ode_regression.py",
"chars": 5763,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_w"
},
{
"path": "long_video/model/sid.py",
"chars": 12650,
"preview": "from pipeline import RollingForcingTrainingPipeline\nfrom typing import Optional, Tuple\nimport torch\n\nfrom model.base imp"
},
{
"path": "long_video/pipeline/__init__.py",
"chars": 568,
"preview": "from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline\nfrom .bidirectional_inference imp"
},
{
"path": "long_video/pipeline/bidirectional_diffusion_inference.py",
"chars": 4146,
"preview": "from tqdm import tqdm\nfrom typing import List\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMultistepSched"
},
{
"path": "long_video/pipeline/bidirectional_inference.py",
"chars": 3109,
"preview": "from typing import List\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\n\n"
},
{
"path": "long_video/pipeline/causal_diffusion_inference.py",
"chars": 16465,
"preview": "from tqdm import tqdm\nfrom typing import List, Optional\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMult"
},
{
"path": "long_video/pipeline/rolling_forcing_inference.py",
"chars": 17659,
"preview": "from typing import List, Optional\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVA"
},
{
"path": "long_video/pipeline/rolling_forcing_training.py",
"chars": 22725,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
},
{
"path": "long_video/prompts/example_prompts.txt",
"chars": 6517,
"preview": "A cinematic scene from a classic western movie, featuring a rugged man riding a powerful horse through the vast Gobi Des"
},
{
"path": "long_video/requirements.txt",
"chars": 551,
"preview": "torch==2.5.1\ntorchvision==0.20.1\ntorchaudio==2.5.1\nopencv-python>=4.9.0.80\ndiffusers==0.31.0\ntransformers>=4.49.0\ntokeni"
},
{
"path": "long_video/train.py",
"chars": 1612,
"preview": "import argparse\nimport os\nfrom omegaconf import OmegaConf\n\nfrom trainer import DiffusionTrainer, GANTrainer, ODETrainer,"
},
{
"path": "long_video/trainer/__init__.py",
"chars": 297,
"preview": "from .diffusion import Trainer as DiffusionTrainer\nfrom .gan import Trainer as GANTrainer\nfrom .ode import Trainer as OD"
},
{
"path": "long_video/trainer/diffusion.py",
"chars": 10384,
"preview": "import gc\nimport logging\n\nfrom model import CausalDiffusion\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom ut"
},
{
"path": "long_video/trainer/distillation.py",
"chars": 15969,
"preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.dataset import TextDataset\nfro"
},
{
"path": "long_video/trainer/gan.py",
"chars": 19999,
"preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.distributed import EMA_FSDP, f"
},
{
"path": "long_video/trainer/ode.py",
"chars": 9791,
"preview": "import gc\nimport logging\nfrom utils.dataset import ODERegressionLMDBDataset, cycle\nfrom model import ODERegression\nfrom "
},
{
"path": "long_video/utils/dataset.py",
"chars": 7346,
"preview": "from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb\nfrom torch.utils.data import Dataset\nimport num"
},
{
"path": "long_video/utils/distributed.py",
"chars": 4571,
"preview": "from datetime import timedelta\nfrom functools import partial\nimport os\nimport torch\nimport torch.distributed as dist\nfro"
},
{
"path": "long_video/utils/lmdb.py",
"chars": 2045,
"preview": "import numpy as np\n\n\ndef get_array_shape_from_lmdb(env, array_name):\n with env.begin() as txn:\n image_shape = "
},
{
"path": "long_video/utils/loss.py",
"chars": 2467,
"preview": "from abc import ABC, abstractmethod\nimport torch\n\n\nclass DenoisingLoss(ABC):\n @abstractmethod\n def __call__(\n "
},
{
"path": "long_video/utils/misc.py",
"chars": 1155,
"preview": "import numpy as np\nimport random\nimport torch\n\n\ndef set_seed(seed: int, deterministic: bool = False):\n \"\"\"\n Helper"
},
{
"path": "long_video/utils/scheduler.py",
"chars": 7979,
"preview": "from abc import abstractmethod, ABC\nimport torch\n\n\nclass SchedulerInterface(ABC):\n \"\"\"\n Base class for diffusion n"
},
{
"path": "long_video/utils/wan_wrapper.py",
"chars": 12465,
"preview": "import types\nfrom typing import List, Optional\nimport torch\nfrom torch import nn\n\nfrom utils.scheduler import SchedulerI"
},
{
"path": "long_video/wan/README.md",
"chars": 92,
"preview": "Code in this folder is modified from https://github.com/Wan-Video/Wan2.1\nApache-2.0 License "
},
{
"path": "long_video/wan/__init__.py",
"chars": 107,
"preview": "from . import configs, distributed, modules\nfrom .image2video import WanI2V\nfrom .text2video import WanT2V\n"
},
{
"path": "long_video/wan/configs/__init__.py",
"chars": 1011,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .wan_t2v_14B import t2v_14B\nfrom .wan_t2v_"
},
{
"path": "long_video/wan/configs/shared_config.py",
"chars": 650,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n# -"
},
{
"path": "long_video/wan/configs/wan_i2v_14B.py",
"chars": 972,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
},
{
"path": "long_video/wan/configs/wan_t2v_14B.py",
"chars": 743,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "long_video/wan/configs/wan_t2v_1_3B.py",
"chars": 760,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "long_video/wan/distributed/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "long_video/wan/distributed/fsdp.py",
"chars": 1077,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom functools import partial\n\nimport torch\nfro"
},
{
"path": "long_video/wan/distributed/xdit_context_parallel.py",
"chars": 5899,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
},
{
"path": "long_video/wan/image2video.py",
"chars": 13203,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "long_video/wan/modules/__init__.py",
"chars": 365,
"preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
},
{
"path": "long_video/wan/modules/attention.py",
"chars": 5641,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n import flash_attn_interf"
},
{
"path": "long_video/wan/modules/causal_model.py",
"chars": 45180,
"preview": "from wan.modules.attention import attention\nfrom wan.modules.model import (\n WanRMSNorm,\n rope_apply,\n WanLayer"
},
{
"path": "long_video/wan/modules/clip.py",
"chars": 16835,
"preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
},
{
"path": "long_video/wan/modules/model.py",
"chars": 30757,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn"
},
{
"path": "long_video/wan/modules/t5.py",
"chars": 16910,
"preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
},
{
"path": "long_video/wan/modules/tokenizers.py",
"chars": 2431,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
},
{
"path": "long_video/wan/modules/vae.py",
"chars": 23735,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
},
{
"path": "long_video/wan/modules/xlm_roberta.py",
"chars": 4865,
"preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
},
{
"path": "long_video/wan/text2video.py",
"chars": 10241,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "long_video/wan/utils/__init__.py",
"chars": 339,
"preview": "from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,\n retrieve_timeste"
},
{
"path": "long_video/wan/utils/fm_solvers.py",
"chars": 40232,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
},
{
"path": "long_video/wan/utils/fm_solvers_unipc.py",
"chars": 32645,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
},
{
"path": "long_video/wan/utils/prompt_extend.py",
"chars": 30344,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
},
{
"path": "long_video/wan/utils/qwen_vl_utils.py",
"chars": 13044,
"preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
},
{
"path": "long_video/wan/utils/utils.py",
"chars": 3239,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
},
{
"path": "model/__init__.py",
"chars": 351,
"preview": "from .diffusion import CausalDiffusion\nfrom .causvid import CausVid\nfrom .dmd import DMD\nfrom .gan import GAN\nfrom .sid "
},
{
"path": "model/base.py",
"chars": 26975,
"preview": "from typing import Tuple\nfrom einops import rearrange\nfrom torch import nn\nimport torch.distributed as dist\nimport torch"
},
{
"path": "model/causvid.py",
"chars": 17342,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\n\n\nclass CausVid("
},
{
"path": "model/diffusion.py",
"chars": 5814,
"preview": "from typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_wrapper import WanDiffusionWrappe"
},
{
"path": "model/dmd.py",
"chars": 17367,
"preview": "from pipeline import SelfForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\nimpo"
},
{
"path": "model/gan.py",
"chars": 14232,
"preview": "import copy\nfrom pipeline import SelfForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Tuple\nim"
},
{
"path": "model/naive_consistency.py",
"chars": 6240,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\nimport random\nfrom model.base import BaseModel\nfro"
},
{
"path": "model/ode_regression.py",
"chars": 5763,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_w"
},
{
"path": "model/sid.py",
"chars": 12638,
"preview": "from pipeline import SelfForcingTrainingPipeline\nfrom typing import Optional, Tuple\nimport torch\n\nfrom model.base import"
},
{
"path": "pipeline/__init__.py",
"chars": 759,
"preview": "from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline\nfrom .bidirectional_inference imp"
},
{
"path": "pipeline/bidirectional_diffusion_inference.py",
"chars": 4146,
"preview": "from tqdm import tqdm\nfrom typing import List\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMultistepSched"
},
{
"path": "pipeline/bidirectional_inference.py",
"chars": 3094,
"preview": "from typing import List\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\n\n"
},
{
"path": "pipeline/bidirectional_training.py",
"chars": 7801,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
},
{
"path": "pipeline/causal_diffusion_inference.py",
"chars": 32056,
"preview": "from tqdm import tqdm\nfrom typing import List, Optional\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMult"
},
{
"path": "pipeline/causal_inference.py",
"chars": 17437,
"preview": "from typing import List, Optional\nimport time\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEn"
},
{
"path": "pipeline/self_forcing_training.py",
"chars": 16256,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
},
{
"path": "pipeline/teacher_forcing_training.py",
"chars": 10927,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
},
{
"path": "prompts/demos.txt",
"chars": 50302,
"preview": "Across a snowfield under pink dawn, a skier in a crimson down jacket, fur-lined hood, and mirrored goggles dominates the"
},
{
"path": "prompts/i2v/target_crop_info_26-15.json",
"chars": 870,
"preview": "[\n {\n \"file_name\": \"000001.png\",\n \"caption\": \"A cinematic closeup and detailed portrait of a reindeer standing in"
},
{
"path": "requirements.txt",
"chars": 445,
"preview": "torch>=2.4.0\ntorchvision>=0.19.0\nopencv-python>=4.9.0.80\ndiffusers==0.31.0\ntransformers>=4.49.0\ntokenizers>=0.20.3\naccel"
},
{
"path": "setup.py",
"chars": 131,
"preview": "from setuptools import setup, find_packages\nsetup(\n name=\"causal_forcing\",\n version=\"0.0.2\",\n packages=find_pac"
},
{
"path": "train.py",
"chars": 1768,
"preview": "import argparse\nimport os\nfrom omegaconf import OmegaConf\nimport wandb\n\nfrom trainer import DiffusionTrainer, ODETrainer"
},
{
"path": "trainer/__init__.py",
"chars": 399,
"preview": "from .diffusion import Trainer as DiffusionTrainer\nfrom .gan import Trainer as GANTrainer\nfrom .ode import Trainer as OD"
},
{
"path": "trainer/diffusion.py",
"chars": 11454,
"preview": "import gc\nimport logging\n\nfrom model import CausalDiffusion\nfrom utils.dataset import cycle, LatentLMDBDataset\nfrom util"
},
{
"path": "trainer/distillation.py",
"chars": 15443,
"preview": "import gc\nimport logging\nfrom utils.dataset import cycle\nfrom utils.dataset import TextDataset\nfrom utils.distributed im"
},
{
"path": "trainer/gan.py",
"chars": 20024,
"preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.distributed import EMA_FSDP, f"
},
{
"path": "trainer/naive_cd.py",
"chars": 11578,
"preview": "import gc\nimport logging\nfrom utils.dataset import cycle\nfrom utils.dataset import LatentLMDBDataset\nfrom utils.distribu"
},
{
"path": "trainer/ode.py",
"chars": 9204,
"preview": "import gc\nimport logging\nfrom utils.dataset import ODERegressionLMDBDataset, cycle\nfrom model import ODERegression\nfrom "
},
{
"path": "utils/create_lmdb_iterative.py",
"chars": 3693,
"preview": "from tqdm import tqdm\nimport numpy as np\nimport argparse\nimport torch\nimport lmdb\nimport glob\nimport os\n\n\ndef store_arra"
},
{
"path": "utils/dataset.py",
"chars": 8497,
"preview": "from utils.lmdb_ import get_array_shape_from_lmdb, retrieve_row_from_lmdb\nfrom torch.utils.data import Dataset\nimport nu"
},
{
"path": "utils/distributed.py",
"chars": 4994,
"preview": "from datetime import timedelta\nfrom functools import partial\nimport os\nimport torch\nimport torch.distributed as dist\nfro"
},
{
"path": "utils/lmdb_.py",
"chars": 2045,
"preview": "import numpy as np\n\n\ndef get_array_shape_from_lmdb(env, array_name):\n with env.begin() as txn:\n image_shape = "
},
{
"path": "utils/loss.py",
"chars": 2467,
"preview": "from abc import ABC, abstractmethod\nimport torch\n\n\nclass DenoisingLoss(ABC):\n @abstractmethod\n def __call__(\n "
},
{
"path": "utils/merge_and_get_clean.py",
"chars": 6069,
"preview": "import os, shutil, lmdb, numpy as np\nfrom tqdm import tqdm\n\nBASE = \"dataset\"\nBATCH = 512\nMAP_MULT = 2.2\n\ndef read_shape("
},
{
"path": "utils/merge_lmdb.py",
"chars": 5973,
"preview": "import os, shutil, lmdb, numpy as np\nfrom tqdm import tqdm\n\nBASE = \"dataset\"\nBATCH = 512\nMAP_MULT = 2.2\n\ndef read_shape("
},
{
"path": "utils/misc.py",
"chars": 1155,
"preview": "import numpy as np\nimport random\nimport torch\n\n\ndef set_seed(seed: int, deterministic: bool = False):\n \"\"\"\n Helper"
},
{
"path": "utils/ode_generation.py",
"chars": 11719,
"preview": "from typing import Dict, Iterable, Optional\n\nimport torch\n\n\ndef merge_cfg_prompt_embeds(\n conditional_dict: dict,\n "
},
{
"path": "utils/scheduler.py",
"chars": 7979,
"preview": "from abc import abstractmethod, ABC\nimport torch\n\n\nclass SchedulerInterface(ABC):\n \"\"\"\n Base class for diffusion n"
},
{
"path": "utils/wan_wrapper.py",
"chars": 12591,
"preview": "import types\nfrom typing import List, Optional\nimport torch\nfrom torch import nn\n\nfrom utils.scheduler import SchedulerI"
},
{
"path": "wan/README.md",
"chars": 92,
"preview": "Code in this folder is modified from https://github.com/Wan-Video/Wan2.1\nApache-2.0 License "
},
{
"path": "wan/__init__.py",
"chars": 107,
"preview": "from . import configs, distributed, modules\nfrom .image2video import WanI2V\nfrom .text2video import WanT2V\n"
},
{
"path": "wan/configs/__init__.py",
"chars": 1011,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .wan_t2v_14B import t2v_14B\nfrom .wan_t2v_"
},
{
"path": "wan/configs/shared_config.py",
"chars": 650,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n# -"
},
{
"path": "wan/configs/wan_i2v_14B.py",
"chars": 972,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
},
{
"path": "wan/configs/wan_t2v_14B.py",
"chars": 743,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/configs/wan_t2v_1_3B.py",
"chars": 760,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/distributed/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wan/distributed/fsdp.py",
"chars": 1077,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom functools import partial\n\nimport torch\nfro"
},
{
"path": "wan/distributed/xdit_context_parallel.py",
"chars": 5899,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
},
{
"path": "wan/image2video.py",
"chars": 13203,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/modules/__init__.py",
"chars": 365,
"preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
},
{
"path": "wan/modules/attention.py",
"chars": 5641,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n import flash_attn_interf"
},
{
"path": "wan/modules/causal_model.py",
"chars": 42290,
"preview": "from wan.modules.attention import attention\nfrom wan.modules.model import (\n WanRMSNorm,\n rope_apply,\n WanLayer"
},
{
"path": "wan/modules/clip.py",
"chars": 16835,
"preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
},
{
"path": "wan/modules/model.py",
"chars": 30757,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn"
},
{
"path": "wan/modules/t5.py",
"chars": 16910,
"preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
},
{
"path": "wan/modules/tokenizers.py",
"chars": 2431,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
},
{
"path": "wan/modules/vae.py",
"chars": 23735,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
},
{
"path": "wan/modules/xlm_roberta.py",
"chars": 4865,
"preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
},
{
"path": "wan/text2video.py",
"chars": 10241,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/utils/__init__.py",
"chars": 339,
"preview": "from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,\n retrieve_timeste"
},
{
"path": "wan/utils/fm_solvers.py",
"chars": 40232,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
},
{
"path": "wan/utils/fm_solvers_unipc.py",
"chars": 32645,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
},
{
"path": "wan/utils/prompt_extend.py",
"chars": 30344,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
},
{
"path": "wan/utils/qwen_vl_utils.py",
"chars": 13044,
"preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
},
{
"path": "wan/utils/utils.py",
"chars": 3239,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
}
]
About this extraction
This page contains the full source code of the thu-ml/Causal-Forcing GitHub repository, extracted and formatted as plain text. The extraction includes 154 files (1.4 MB), approximately 339.4k tokens, and a symbol index with 1104 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.