Repository: GoatWu/Self-Forcing-Plus
Branch: main
Commit: 67b8016498b5
Files: 83
Total size: 2.0 MB
Directory structure:
gitextract_tdko9ms3/
├── LICENSE.md
├── README.md
├── configs/
│   ├── default_config.yaml
│   ├── self_forcing_14b_dmd.yaml
│   ├── self_forcing_14b_i2v_dmd.yaml
│   ├── self_forcing_dmd.yaml
│   └── self_forcing_sid.yaml
├── convert_checkpoint.py
├── demo.py
├── demo_utils/
│   ├── constant.py
│   ├── memory.py
│   ├── taehv.py
│   ├── utils.py
│   ├── vae.py
│   ├── vae_block3.py
│   └── vae_torch2trt.py
├── inference.py
├── model/
│   ├── __init__.py
│   ├── base.py
│   ├── causvid.py
│   ├── diffusion.py
│   ├── dmd.py
│   ├── gan.py
│   ├── ode_regression.py
│   └── sid.py
├── pipeline/
│   ├── __init__.py
│   ├── bidirectional_diffusion_inference.py
│   ├── bidirectional_inference.py
│   ├── bidirectional_training.py
│   ├── causal_diffusion_inference.py
│   ├── causal_inference.py
│   └── self_forcing_training.py
├── prompts/
│   ├── MovieGenVideoBench.txt
│   ├── MovieGenVideoBench_extended.txt
│   └── vbench/
│       ├── all_dimension.txt
│       └── all_dimension_longer.txt
├── requirements.txt
├── scripts/
│   ├── compute_vae_latent.py
│   ├── create_lmdb_14b_shards.py
│   ├── create_lmdb_iterative.py
│   └── generate_ode_pairs.py
├── setup.py
├── templates/
│   └── demo.html
├── train.py
├── trainer/
│   ├── __init__.py
│   ├── diffusion.py
│   ├── distillation.py
│   ├── gan.py
│   └── ode.py
├── utils/
│   ├── dataset.py
│   ├── distributed.py
│   ├── lmdb.py
│   ├── loss.py
│   ├── misc.py
│   ├── scheduler.py
│   └── wan_wrapper.py
└── wan/
    ├── README.md
    ├── __init__.py
    ├── configs/
    │   ├── __init__.py
    │   ├── shared_config.py
    │   ├── wan_i2v_14B.py
    │   ├── wan_t2v_14B.py
    │   └── wan_t2v_1_3B.py
    ├── distributed/
    │   ├── __init__.py
    │   ├── fsdp.py
    │   └── xdit_context_parallel.py
    ├── image2video.py
    ├── modules/
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── causal_model.py
    │   ├── clip.py
    │   ├── model.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    ├── text2video.py
    └── utils/
        ├── __init__.py
        ├── fm_solvers.py
        ├── fm_solvers_unipc.py
        ├── prompt_extend.py
        ├── qwen_vl_utils.py
        └── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE.md
================================================
# Attribution-NonCommercial-ShareAlike 4.0 International
Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
### Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
### Section 1 – Definitions.
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
### Section 2 – Scope.
a. ___License grant.___
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
3. __Term.__ The term of this Public License is specified in Section 6(a).
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
5. __Downstream recipients.__
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
b. ___Other rights.___
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this Public License.
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
### Section 3 – License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
a. ___Attribution.___
1. If You Share the Licensed Material (including in modified form), You must:
A. retain the following if it is supplied by the Licensor with the Licensed Material:
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of warranties;
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
b. ___ShareAlike.___
In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
### Section 4 – Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
### Section 6 – Term and Termination.
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
### Section 7 – Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
### Section 8 – Interpretation.
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
>
> Creative Commons may be contacted at creativecommons.org
================================================
FILE: README.md
================================================
<p align="center">
<h1 align="center">Self Forcing Plus</h1>
Self-Forcing-Plus focuses on step distillation and CFG distillation for bidirectional models. Building upon Self-Forcing, it supports training a 4-step T2V-14B model and a higher-quality 4-step I2V-14B model.
## 🔥 News
- (2025/09) Wan2.2-MoE distillation is now supported on the [wan22](https://github.com/GoatWu/Self-Forcing-Plus/tree/wan22) branch.
| Model Type | Model Link |
|------------|---------------|
| T2V-14B | [Huggingface](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill) |
| I2V-14B-480P | [Huggingface](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v) |
## Installation
Create a conda environment and install dependencies:
```
conda create -n self_forcing python=3.10 -y
conda activate self_forcing
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
python setup.py develop
```
## Quick Start
### Download checkpoints
```
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir wan_models/Wan2.1-T2V-14B
huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir wan_models/Wan2.1-I2V-14B-480P
```
## T2V Training
DMD training for bidirectional models does not need ODE initialization.
### Dataset Preparation
We build the dataset as a folder of text files, where each file contains a single prompt:
```
data_folder
|__1.txt
|__2.txt
...
|__xxx.txt
```
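A folder in this layout can be produced from an in-memory list of prompts with a few lines of Python. This is a minimal sketch rather than a script shipped with the repository: the helper name `write_prompt_folder` and the 1-based file numbering are assumptions made for illustration.

```python
from pathlib import Path


def write_prompt_folder(prompts, data_folder):
    """Write one prompt per file (1.txt, 2.txt, ...) and return the file paths.

    Hypothetical helper: the name and numbering scheme are illustrative only.
    """
    folder = Path(data_folder)
    folder.mkdir(parents=True, exist_ok=True)
    paths = []
    for index, prompt in enumerate(prompts, start=1):
        path = folder / f"{index}.txt"
        # Strip surrounding whitespace so each file holds exactly one prompt.
        path.write_text(prompt.strip(), encoding="utf-8")
        paths.append(path)
    return paths
```

Calling `write_prompt_folder(["a cat on a sofa", "a dog in the rain"], "data_folder")` would create `data_folder/1.txt` and `data_folder/2.txt`; `data_path` in the training config can then point at that folder.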
### DMD Training
```
torchrun --nnodes=8 --nproc_per_node=8 \
--rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
train.py \
--config_path configs/self_forcing_14b_dmd.yaml \
--logdir logs/self_forcing_14b_dmd \
--no_visualize \
--disable-wandb
```
Our training run uses 3000 iterations and completes in under 3 days using 64 H100 GPUs.
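The launch command starts 8 nodes with 8 processes each, and `configs/self_forcing_14b_dmd.yaml` sets `batch_size: 1` and `total_batch_size: 64`. The sketch below shows how these numbers line up, assuming one process per GPU and no gradient accumulation. The function name and the `grad_accum_steps` parameter are hypothetical and are not taken from the training code.

```python
def effective_batch_size(nnodes, nproc_per_node, per_gpu_batch_size, grad_accum_steps=1):
    """Return the number of samples consumed per optimizer step.

    Illustrative only: assumes one process per GPU; grad_accum_steps is a
    hypothetical knob included to make the assumption explicit.
    """
    world_size = nnodes * nproc_per_node
    return world_size * per_gpu_batch_size * grad_accum_steps


# 8 nodes x 8 GPUs x batch_size 1 matches total_batch_size: 64 in the config.
print(effective_batch_size(nnodes=8, nproc_per_node=8, per_gpu_batch_size=1))  # 64
```

If you launch on fewer GPUs, this is a quick way to check whether the per-GPU batch size still matches the `total_batch_size` you intend to use.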
## I2V-480P Training
### Dataset Preparation
1. Generate a series of videos using the original Wan2.1 model.
2. Generate the VAE latents.
```bash
python scripts/compute_vae_latent.py \
--input_video_folder {video_folder} \
--output_latent_folder {latent_folder} \
--model_name Wan2.1-T2V-14B \
--prompt_folder {prompt_folder}
```
3. Separate the first frame of the videos and create an lmdb dataset.
```bash
python scripts/create_lmdb_14b_shards.py \
--data_path {latent_folder} \
--prompt_path {prompt_folder} \
--lmdb_path {lmdb_folder}
```
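The latent shape expected by training (`image_or_video_shape` in `configs/self_forcing_14b_i2v_dmd.yaml`) can be related to the pixel-space defaults in `configs/default_config.yaml` (`num_frames: 81`, `height: 480`, `width: 832`). The sketch below assumes the Wan2.1 VAE compresses time by 4x (keeping the first frame), compresses each spatial dimension by 8x, and produces 16 latent channels. The function name and parameters are hypothetical, so verify the strides against the VAE you actually use.

```python
def wan_latent_shape(num_frames, height, width, batch_size=1,
                     latent_channels=16, temporal_stride=4, spatial_stride=8):
    """Return [batch, latent_frames, channels, latent_height, latent_width].

    Illustrative helper: the strides and channel count are assumptions about
    the Wan2.1 VAE, not values read from this repository's code.
    """
    if (num_frames - 1) % temporal_stride != 0:
        raise ValueError("num_frames must have the form temporal_stride * k + 1")
    if height % spatial_stride != 0 or width % spatial_stride != 0:
        raise ValueError("height and width must be divisible by spatial_stride")
    # The first frame is encoded on its own; the rest are grouped by stride.
    latent_frames = (num_frames - 1) // temporal_stride + 1
    return [batch_size, latent_frames, latent_channels,
            height // spatial_stride, width // spatial_stride]


print(wan_latent_shape(num_frames=81, height=480, width=832))  # [1, 21, 16, 60, 104]
```

Under these assumptions, 81-frame 480x832 videos yield the `[1, 21, 16, 60, 104]` shape listed in the config, which is a useful sanity check before building the LMDB dataset.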
### DMD Training
```
torchrun --nnodes=8 --nproc_per_node=8 \
--rdzv_id=5235 \
--rdzv_backend=c10d \
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
train.py \
--config_path configs/self_forcing_14b_i2v_dmd.yaml \
--logdir logs/self_forcing_14b_i2v_dmd \
--no_visualize \
--disable-wandb
```
Our training run uses 1000 iterations and completes in under 12 hours using 64 H100 GPUs.
## Acknowledgements
This codebase is built on top of the open-source implementations of [CausVid](https://github.com/tianweiy/CausVid) and [Self-Forcing](https://github.com/guandeh17/Self-Forcing), and on the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
================================================
FILE: configs/default_config.yaml
================================================
independent_first_frame: false
warp_denoising_step: false
weight_decay: 0.01
same_step_across_blocks: true
discriminator_lr_multiplier: 1.0
last_step_only: false
i2v: false
num_training_frames: 21
gc_interval: 100
context_noise: 0
causal: true
ckpt_step: 0
prompt_name: MovieGenVideoBench
prompt_path: prompts/MovieGenVideoBench.txt
eval_first_n: 64
num_samples: 1
height: 480
width: 832
num_frames: 81
================================================
FILE: configs/self_forcing_14b_dmd.yaml
================================================
# generator_ckpt: checkpoints/ode_init.pt
i2v: false
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
fake_name: Wan2.1-T2V-14B
generator_type: bidirectional
generator_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
text_encoder_cpu_offload: false
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # when warp_denoising_step is true, denoising_step_list must not contain a "- 0" entry
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 4.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: WANDB_HOST
wandb_key: WANDB_KEY
wandb_entity: WANDB_ENTITY
wandb_project: WANDB_PROJECT
sharding_strategy: full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
data_type: text_folder
data_path: prompts/good_prompts/
data_max_count: 30000
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 64
log_iters: 200
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: configs/self_forcing_14b_i2v_dmd.yaml
================================================
# generator_ckpt: checkpoints/ode_init.pt
i2v: true
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-I2V-14B-480P
fake_name: Wan2.1-I2V-14B-480P
generator_type: bidirectional
generator_name: Wan2.1-I2V-14B-480P
text_encoder_fsdp_wrap_strategy: size
image_encoder_fsdp_wrap_strategy: size
text_encoder_cpu_offload: true
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # when warp_denoising_step is true, denoising_step_list must not contain a "- 0" entry
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 6.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: WANDB_HOST
wandb_key: WANDB_KEY
wandb_entity: WANDB_ENTITY
wandb_project: WANDB_PROJECT
sharding_strategy: full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
data_type: text_folder
data_path: /data/mydataset/output_lmdb/
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 64
log_iters: 100
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: configs/self_forcing_dmd.yaml
================================================
generator_ckpt: checkpoints/ode_init.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
fake_name: Wan2.1-T2V-1.3B
generator_type: causal
generator_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # when warp_denoising_step is true, denoising_step_list must not contain a "- 0" entry
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: WANDB_HOST
wandb_key: WANDB_KEY
wandb_entity: WANDB_ENTITY
wandb_project: WANDB_PROJECT
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
data_type: text_folder
data_path: prompts/vidprom_filtered_extended.txt
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 64
log_iters: 50
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: configs/self_forcing_sid.yaml
================================================
generator_ckpt: checkpoints/ode_init.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # when warp_denoising_step is true, denoising_step_list must not contain a "- 0" entry
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: WANDB_HOST
wandb_key: WANDB_KEY
wandb_entity: WANDB_ENTITY
wandb_project: WANDB_PROJECT
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 2.0e-06
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
weight_decay: 0.0
data_path: prompts/vidprom_filtered_extended.txt
batch_size: 1
sid_alpha: 1.0
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 64
log_iters: 50
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
================================================
FILE: convert_checkpoint.py
================================================
import torch
import argparse
import os
import gc
from safetensors.torch import save_file
def main():
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Extract and save the generator part from a checkpoint.')
    parser.add_argument('--input-checkpoint', type=str, required=True, help='Path to the input checkpoint file')
    parser.add_argument('--output-checkpoint', type=str, required=True, help='Path to save the output checkpoint file')
    parser.add_argument('--remove-prefix', type=str, nargs='?', const="model.", default="model.", help='Prefix to remove from keys (default: "model.")')
    parser.add_argument('--to-bf16', action='store_true', help='Convert model weights to bfloat16')
    parser.add_argument('--ema', action='store_true', help='Use EMA weights')
    args = parser.parse_args()
    # Extract arguments
    input_path = args.input_checkpoint
    output_path = args.output_checkpoint
    prefix_to_remove = args.remove_prefix
    convert_to_bf16 = args.to_bf16
    use_ema = args.ema
    # Check if input file exists
    if not os.path.exists(input_path):
        print(f"Error: Input checkpoint file not found: {input_path}")
        return
    # Load the input checkpoint
    print(f"Loading checkpoint from {input_path}...")
    checkpoint = torch.load(input_path, map_location=torch.device('cpu'))
    model_type = "generator_ema" if use_ema else "generator"
    # Check that the selected key ('generator' or 'generator_ema') exists
    if model_type not in checkpoint:
        print(f"Error: The '{model_type}' key does not exist in the input checkpoint")
        return
    # Extract the generator
    generator = checkpoint[model_type]
    print(f"Successfully extracted '{model_type}' from input checkpoint")
    # Remove the specified prefix from keys
    new_generator = {}
    prefix_count = 0
    tensor_count = 0
    for key, value in generator.items():
        # Process key - remove prefix if needed
        if key.startswith(prefix_to_remove):
            new_key = key[len(prefix_to_remove):]  # Remove the prefix
            prefix_count += 1
        else:
            new_key = key
        # Drop wrapper names left behind by FSDP, activation checkpointing, and torch.compile
        new_key = new_key.replace("_fsdp_wrapped_module.", "").replace("_checkpoint_wrapped_module.", "").replace("_orig_mod.", "")
        print(f"{key} -> {new_key}")
        # Convert tensor to bf16 if requested
        if convert_to_bf16 and isinstance(value, torch.Tensor) and value.is_floating_point():
            value = value.to(torch.bfloat16)
            tensor_count += 1
        new_generator[new_key] = value
    # Print processing summary
    print(f"Removed prefix '{prefix_to_remove}' from {prefix_count} keys")
    if convert_to_bf16:
        print(f"Converted {tensor_count} tensors to bfloat16")
    del checkpoint
    gc.collect()
    # Save the new checkpoint
    print(f"Saving generator to {output_path}...")
    # Choose save method based on file extension
    if output_path.endswith('.safetensors'):
        save_file(new_generator, output_path)
        print(f"Successfully saved generator to {output_path} (safetensors format)")
    elif output_path.endswith('.pt') or output_path.endswith('.pth'):
        torch.save(new_generator, output_path)
        print(f"Successfully saved generator to {output_path} (PyTorch format)")
    else:
        # Default to PyTorch format
        torch.save(new_generator, output_path)
        print(f"Successfully saved generator to {output_path} (PyTorch format - default)")
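

# Illustrative sketch, not part of the original script: the key clean-up that
# main() performs inline, pulled into a standalone helper so it can be tried
# on plain strings. The name `clean_state_dict_key` and the sample key in the
# example are hypothetical.
def clean_state_dict_key(key, prefix="model."):
    """Strip a leading prefix, then drop FSDP / checkpoint / torch.compile wrapper names."""
    if key.startswith(prefix):
        key = key[len(prefix):]
    for wrapper in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module.", "_orig_mod."):
        key = key.replace(wrapper, "")
    return key


# Example with a made-up key:
#   clean_state_dict_key("model._fsdp_wrapped_module.blocks.0._checkpoint_wrapped_module.attn.weight")
#   returns "blocks.0.attn.weight"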
if __name__ == "__main__":
    main()
================================================
FILE: demo.py
================================================
"""
Demo for Self-Forcing.
"""
import os
import time
import base64
import argparse
import urllib.request
from io import BytesIO
from PIL import Image
import numpy as np
import torch
from omegaconf import OmegaConf
from flask import Flask, render_template, jsonify
from flask_socketio import SocketIO, emit
import queue
from threading import Thread, Event
from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
from demo_utils.utils import generate_timestamp
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=5001)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
parser.add_argument('--trt', action='store_true')
args = parser.parse_args()
print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40
# Load models
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
text_encoder = WanTextEncoder()
# Global variables for dynamic model switching
current_vae_decoder = None
current_use_taehv = False
fp8_applied = False
torch_compile_applied = False
def initialize_vae_decoder(use_taehv=False, use_trt=False):
"""Initialize VAE decoder based on the selected option"""
global current_vae_decoder, current_use_taehv
if use_trt:
from demo_utils.vae import VAETRTWrapper
current_vae_decoder = VAETRTWrapper()
return current_vae_decoder
if use_taehv:
from demo_utils.taehv import TAEHV
# Check if taew2_1.pth exists in checkpoints folder, download if missing
taehv_checkpoint_path = "checkpoints/taew2_1.pth"
if not os.path.exists(taehv_checkpoint_path):
print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
os.makedirs("checkpoints", exist_ok=True)
download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
try:
urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
except Exception as e:
print(f"Failed to download taew2_1.pth: {e}")
raise
class DotDict(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
class TAEHVDiffusersWrapper(torch.nn.Module):
def __init__(self):
super().__init__()
self.dtype = torch.float16
self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
self.config = DotDict(scaling_factor=1.0)
def decode(self, latents, return_dict=None):
# n, c, t, h, w = latents.shape
# low-memory, set parallel=True for faster + higher memory
return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)
current_vae_decoder = TAEHVDiffusersWrapper()
else:
current_vae_decoder = VAEDecoderWrapper()
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
decoder_state_dict = {}
for key, value in vae_state_dict.items():
if 'decoder.' in key or 'conv2' in key:
decoder_state_dict[key] = value
current_vae_decoder.load_state_dict(decoder_state_dict)
current_vae_decoder.eval()
current_vae_decoder.to(dtype=torch.float16)
current_vae_decoder.requires_grad_(False)
current_vae_decoder.to(gpu)
current_use_taehv = use_taehv
print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
return current_vae_decoder
# Initialize with default VAE
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
transformer = WanDiffusionWrapper(is_causal=True)
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
transformer.load_state_dict(state_dict['generator_ema'])
text_encoder.eval()
transformer.eval()
transformer.to(dtype=torch.float16)
text_encoder.to(dtype=torch.bfloat16)
text_encoder.requires_grad_(False)
transformer.requires_grad_(False)
pipeline = CausalInferencePipeline(
config,
device=gpu,
generator=transformer,
text_encoder=text_encoder,
vae=vae_decoder
)
if low_memory:
DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
text_encoder.to(gpu)
transformer.to(gpu)
# Flask and SocketIO setup
app = Flask(__name__)
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
socketio = SocketIO(app, cors_allowed_origins="*")
generation_active = False
stop_event = Event()
frame_send_queue = queue.Queue()
sender_thread = None
models_compiled = False
def tensor_to_base64_frame(frame_tensor):
"""Convert a single frame tensor to base64 image string."""
# Clamp and normalize to 0-255
frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
frame = frame.to(torch.uint8).cpu().numpy()
# CHW -> HWC
if len(frame.shape) == 3:
frame = np.transpose(frame, (1, 2, 0))
# Convert to PIL Image
if frame.shape[2] == 3: # RGB
image = Image.fromarray(frame, 'RGB')
else: # Handle other formats
image = Image.fromarray(frame)
# Convert to base64
buffer = BytesIO()
image.save(buffer, format='JPEG', quality=85)
img_str = base64.b64encode(buffer.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
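

# Illustrative sketch, not part of the original demo: how a receiver could turn
# the data URI built by tensor_to_base64_frame back into raw JPEG bytes using
# only the standard library. `decode_data_uri` is a hypothetical helper name.
def decode_data_uri(data_uri):
    """Return (mime_type, payload_bytes) for a base64 data URI."""
    import base64 as _b64  # local import keeps the sketch self-contained
    header, encoded = data_uri.split(",", 1)
    if not header.startswith("data:") or not header.endswith(";base64"):
        raise ValueError("expected a 'data:<mime>;base64,<payload>' URI")
    mime_type = header[len("data:"):-len(";base64")]
    return mime_type, _b64.b64decode(encoded)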
def frame_sender_worker():
    """Background thread that drains the frame send queue so generation never blocks on sending."""
global frame_send_queue, generation_active, stop_event
print("📡 Frame sender thread started")
while True:
frame_data = None
try:
# Get frame data from queue
frame_data = frame_send_queue.get(timeout=1.0)
if frame_data is None: # Shutdown signal
frame_send_queue.task_done() # Mark shutdown signal as done
break
frame_tensor, frame_index, block_index, job_id = frame_data
# Convert tensor to base64
base64_frame = tensor_to_base64_frame(frame_tensor)
# Send via SocketIO
try:
socketio.emit('frame_ready', {
'data': base64_frame,
'frame_index': frame_index,
'block_index': block_index,
'job_id': job_id
})
except Exception as e:
print(f"⚠️ Failed to send frame {frame_index}: {e}")
frame_send_queue.task_done()
except queue.Empty:
# Check if we should continue running
if not generation_active and frame_send_queue.empty():
break
except Exception as e:
print(f"❌ Frame sender error: {e}")
# Make sure to mark task as done even if there's an error
if frame_data is not None:
try:
frame_send_queue.task_done()
except Exception as e:
print(f"❌ Failed to mark frame task as done: {e}")
break
print("📡 Frame sender thread stopped")
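

# Illustrative sketch, not part of the original demo: the producer/consumer
# pattern used by frame_sender_worker, reduced to the standard library. A
# `None` sentinel stops the worker, and task_done() is called once per get()
# so that Queue.join() can return. All names here are hypothetical.
def _sentinel_queue_sketch(items):
    import queue as _queue
    import threading as _threading

    q = _queue.Queue()
    sent = []

    def worker():
        while True:
            item = q.get()
            try:
                if item is None:  # shutdown signal
                    return
                sent.append(item)  # stand-in for socketio.emit(...)
            finally:
                q.task_done()  # runs for real items and for the sentinel

    t = _threading.Thread(target=worker, daemon=True)
    t.start()
    for item in items:
        q.put(item)
    q.join()     # blocks until every queued item has been marked done
    q.put(None)  # then ask the worker to exit
    t.join(timeout=5)
    return sent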
@torch.no_grad()
def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
"""Generate video and push frames immediately to frontend."""
global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv
try:
generation_active = True
stop_event.clear()
job_id = generate_timestamp()
# Start frame sender thread if not already running
if sender_thread is None or not sender_thread.is_alive():
sender_thread = Thread(target=frame_sender_worker, daemon=True)
sender_thread.start()
# Emit progress updates
def emit_progress(message, progress):
try:
socketio.emit('progress', {
'message': message,
'progress': progress,
'job_id': job_id
})
except Exception as e:
print(f"❌ Failed to emit progress: {e}")
emit_progress('Starting generation...', 0)
# Handle VAE decoder switching
if use_taehv != current_use_taehv:
emit_progress('Switching VAE decoder...', 2)
print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
# Update pipeline with new VAE decoder
pipeline.vae = current_vae_decoder
# Handle FP8 quantization
if enable_fp8 and not fp8_applied:
emit_progress('Applying FP8 quantization...', 3)
print("🔧 Applying FP8 quantization to transformer")
from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
fp8_applied = True
# Text encoding
emit_progress('Encoding text prompt...', 8)
conditional_dict = text_encoder(text_prompts=[prompt])
for key, value in conditional_dict.items():
conditional_dict[key] = value.to(dtype=torch.float16)
if low_memory:
gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
move_model_to_device_with_memory_preservation(
text_encoder, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
# Handle torch.compile if enabled
torch_compile_applied = enable_torch_compile
if enable_torch_compile and not models_compiled:
# Compile transformer and decoder
transformer.compile(mode="max-autotune-no-cudagraphs")
if not current_use_taehv and not low_memory and not args.trt:
current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
# Initialize generation
emit_progress('Initializing generation...', 12)
rnd = torch.Generator(gpu).manual_seed(seed)
# all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
# Generation parameters
num_blocks = 7
current_start_frame = 0
num_input_frames = 0
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
if current_use_taehv:
vae_cache = None
else:
vae_cache = ZERO_VAE_CACHE
for i in range(len(vae_cache)):
vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)
total_frames_sent = 0
generation_start_time = time.time()
emit_progress('Generating frames... (frontend handles timing)', 15)
for idx, current_num_frames in enumerate(all_num_frames):
if not generation_active or stop_event.is_set():
break
progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15
# Special message for first block with torch.compile
if idx == 0 and torch_compile_applied and not models_compiled:
emit_progress(
f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
models_compiled = True
else:
emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")
block_start_time = time.time()
noisy_input = noise[:, current_start_frame -
num_input_frames:current_start_frame + current_num_frames - num_input_frames]
# Denoising loop
denoising_start = time.time()
for index, current_timestep in enumerate(pipeline.denoising_step_list):
if not generation_active or stop_event.is_set():
break
timestep = torch.ones([1, current_num_frames], device=noise.device,
dtype=torch.int64) * current_timestep
if index < len(pipeline.denoising_step_list) - 1:
_, denoised_pred = transformer(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=pipeline.kv_cache1,
crossattn_cache=pipeline.crossattn_cache,
current_start=current_start_frame * pipeline.frame_seq_length
)
next_timestep = pipeline.denoising_step_list[index + 1]
noisy_input = pipeline.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
).unflatten(0, denoised_pred.shape[:2])
else:
_, denoised_pred = transformer(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
kv_cache=pipeline.kv_cache1,
crossattn_cache=pipeline.crossattn_cache,
current_start=current_start_frame * pipeline.frame_seq_length
)
if not generation_active or stop_event.is_set():
break
denoising_time = time.time() - denoising_start
print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
# Record output
# all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
# Update KV cache for next block
if idx != len(all_num_frames) - 1:
transformer(
noisy_image_or_video=denoised_pred,
conditional_dict=conditional_dict,
timestep=torch.zeros_like(timestep),
kv_cache=pipeline.kv_cache1,
crossattn_cache=pipeline.crossattn_cache,
current_start=current_start_frame * pipeline.frame_seq_length,
)
# Decode to pixels and send frames immediately
print(f"🎨 Decoding block {idx+1} to pixels...")
decode_start = time.time()
if args.trt:
all_current_pixels = []
for i in range(denoised_pred.shape[1]):
is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
torch.tensor(0.0).cuda().half()
outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
# outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
current_pixels, vae_cache = outputs[0], outputs[1:]
print(current_pixels.max(), current_pixels.min())
all_current_pixels.append(current_pixels.clone())
pixels = torch.cat(all_current_pixels, dim=1)
if idx == 0:
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
else:
if current_use_taehv:
if vae_cache is None:
vae_cache = denoised_pred
else:
denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
vae_cache = denoised_pred[:, -3:, :, :, :]
pixels = current_vae_decoder.decode(denoised_pred)
print(f"denoised_pred shape: {denoised_pred.shape}")
print(f"pixels shape: {pixels.shape}")
if idx == 0:
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
else:
pixels = pixels[:, 12:, :, :, :]
else:
pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
if idx == 0:
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
decode_time = time.time() - decode_start
print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")
# Queue frames for non-blocking sending
block_frames = pixels.shape[1]
print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
queue_start = time.time()
for frame_idx in range(block_frames):
if not generation_active or stop_event.is_set():
break
frame_tensor = pixels[0, frame_idx].cpu()
# Queue frame data in non-blocking way
frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
total_frames_sent += 1
queue_time = time.time() - queue_start
block_time = time.time() - block_start_time
print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")
current_start_frame += current_num_frames
generation_time = time.time() - generation_start_time
print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")
# Wait for all frames to be sent before completing
emit_progress('Waiting for all frames to be sent...', 97)
print("⏳ Waiting for all frames to be sent...")
frame_send_queue.join() # Wait for all queued frames to be processed
print("✅ All frames sent successfully!")
# Final progress update
emit_progress('Generation complete!', 100)
try:
socketio.emit('generation_complete', {
'message': 'Video generation completed!',
'total_frames': total_frames_sent,
'generation_time': f"{generation_time:.2f}s",
'job_id': job_id
})
except Exception as e:
print(f"❌ Failed to emit generation complete: {e}")
except Exception as e:
print(f"❌ Generation failed: {e}")
try:
socketio.emit('error', {
'message': f'Generation failed: {str(e)}',
'job_id': job_id
})
except Exception as e:
print(f"❌ Failed to emit error: {e}")
finally:
generation_active = False
stop_event.set()
# Clean up sender thread
try:
frame_send_queue.put(None)
except Exception as e:
print(f"❌ Failed to put None in frame_send_queue: {e}")
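

# Illustrative sketch, not part of the original demo: the bookkeeping of the
# block loop in generate_video_stream, in plain Python. Assuming 3 latent
# frames per block (num_frame_per_block: 3 in the config), 7 blocks tile the
# 21 latent frames of the noise tensor, and the progress value reaches 95
# after the last block. `block_schedule` is a hypothetical helper name.
def block_schedule(num_blocks=7, frames_per_block=3):
    """Return (start_frame, end_frame, progress_percent) for each block."""
    schedule = []
    start = 0
    for idx in range(num_blocks):
        end = start + frames_per_block
        # Same formula as the loop: 80% of the bar is spread over the blocks,
        # offset by the 15% already used for setup.
        progress = int(((idx + 1) / num_blocks) * 80) + 15
        schedule.append((start, end, progress))
        start = end
    return schedule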
# Socket.IO event handlers
@socketio.on('connect')
def handle_connect():
print('Client connected')
emit('status', {'message': 'Connected to frontend-buffered demo server'})
@socketio.on('disconnect')
def handle_disconnect():
print('Client disconnected')
@socketio.on('start_generation')
def handle_start_generation(data):
global generation_active
if generation_active:
emit('error', {'message': 'Generation already in progress'})
return
prompt = data.get('prompt', '')
seed = data.get('seed', 31337)
enable_torch_compile = data.get('enable_torch_compile', False)
enable_fp8 = data.get('enable_fp8', False)
use_taehv = data.get('use_taehv', False)
if not prompt:
emit('error', {'message': 'Prompt is required'})
return
# Start generation in background thread
socketio.start_background_task(generate_video_stream, prompt, seed,
enable_torch_compile, enable_fp8, use_taehv)
emit('status', {'message': 'Generation started - frames will be sent immediately'})
@socketio.on('stop_generation')
def handle_stop_generation():
global generation_active, stop_event, frame_send_queue
generation_active = False
stop_event.set()
# Signal sender thread to stop (will be processed after current frames)
try:
frame_send_queue.put(None)
except Exception as e:
print(f"❌ Failed to put None in frame_send_queue: {e}")
emit('status', {'message': 'Generation stopped'})
# Web routes
@app.route('/')
def index():
return render_template('demo.html')
@app.route('/api/status')
def api_status():
return jsonify({
'generation_active': generation_active,
'free_vram_gb': get_cuda_free_memory_gb(gpu),
'fp8_applied': fp8_applied,
'torch_compile_applied': torch_compile_applied,
'current_use_taehv': current_use_taehv
})
if __name__ == '__main__':
print(f"🚀 Starting demo on http://{args.host}:{args.port}")
socketio.run(app, host=args.host, port=args.port, debug=False)
================================================
FILE: demo_utils/constant.py
================================================
import torch
ZERO_VAE_CACHE = [
torch.zeros(1, 16, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 384, 2, 60, 104),
torch.zeros(1, 192, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 384, 2, 120, 208),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 192, 2, 240, 416),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832),
torch.zeros(1, 96, 2, 480, 832)
]
feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
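

# Illustrative sketch, not part of the original module: a rough estimate of how
# much memory the zero cache occupies once demo.py casts it to float16
# (2 bytes per element). `cache_size_gib` is a hypothetical helper name.
def cache_size_gib(shapes, bytes_per_element=2):
    total_elements = 0
    for shape in shapes:
        n = 1
        for dim in shape:
            n *= dim
        total_elements += n
    return total_elements * bytes_per_element / (1024 ** 3)


# Usage, given the list above:
#   cache_size_gib([tuple(t.shape) for t in ZERO_VAE_CACHE])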
================================================
FILE: demo_utils/memory.py
================================================
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel
import torch
cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
gpu_complete_modules = []
class DynamicSwapInstaller:
@staticmethod
def _install_module(module: torch.nn.Module, **kwargs):
original_class = module.__class__
module.__dict__['forge_backup_original_class'] = original_class
def hacked_get_attr(self, name: str):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
p = _parameters[name]
if p is None:
return None
if p.__class__ == torch.nn.Parameter:
return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
else:
return p.to(**kwargs)
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name].to(**kwargs)
return super(original_class, self).__getattr__(name)
module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
'__getattr__': hacked_get_attr,
})
return
@staticmethod
def _uninstall_module(module: torch.nn.Module):
if 'forge_backup_original_class' in module.__dict__:
module.__class__ = module.__dict__.pop('forge_backup_original_class')
return
@staticmethod
def install_model(model: torch.nn.Module, **kwargs):
for m in model.modules():
DynamicSwapInstaller._install_module(m, **kwargs)
return
@staticmethod
def uninstall_model(model: torch.nn.Module):
for m in model.modules():
DynamicSwapInstaller._uninstall_module(m)
return
def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
if hasattr(model, 'scale_shift_table'):
model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
return
for k, p in model.named_modules():
if hasattr(p, 'weight'):
p.to(target_device)
return
def get_cuda_free_memory_gb(device=None):
if device is None:
device = gpu
memory_stats = torch.cuda.memory_stats(device)
bytes_active = memory_stats['active_bytes.all.current']
bytes_reserved = memory_stats['reserved_bytes.all.current']
bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
bytes_inactive_reserved = bytes_reserved - bytes_active
bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
return bytes_total_available / (1024 ** 3)
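

# Illustrative sketch, not part of the original module: the arithmetic of
# get_cuda_free_memory_gb on plain integers, so it can be checked without a
# GPU. Memory that the allocator has reserved but is not actively using is
# counted as available, on top of what the driver reports as free.
# `_available_gb_sketch` is a hypothetical helper name.
def _available_gb_sketch(bytes_free_cuda, bytes_reserved, bytes_active):
    bytes_inactive_reserved = bytes_reserved - bytes_active
    return (bytes_free_cuda + bytes_inactive_reserved) / (1024 ** 3)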
def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
for m in model.modules():
if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
torch.cuda.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=target_device)
model.to(device=target_device)
torch.cuda.empty_cache()
return
def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
for m in model.modules():
if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
torch.cuda.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=cpu)
model.to(device=cpu)
torch.cuda.empty_cache()
return
def unload_complete_models(*args):
for m in gpu_complete_modules + list(args):
m.to(device=cpu)
print(f'Unloaded {m.__class__.__name__} as complete.')
gpu_complete_modules.clear()
torch.cuda.empty_cache()
return
def load_model_as_complete(model, target_device, unload=True):
if unload:
unload_complete_models()
model.to(device=target_device)
print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
gpu_complete_modules.append(model)
return
================================================
FILE: demo_utils/taehv.py
================================================
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N * T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
# need to fix :(
out = []
# iterate over input timesteps and also iterate over blocks.
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
# it's actually a ***graph traversal*** problem! so let's make a queue
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
# in addition to manually managing our queue, we also need to manually manage our progressbar.
# we'll update it for every source node that we consume.
progress_bar = tqdm(range(T), disable=not show_progress_bar)
# we'll also need a separate addressable memory per node as well
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
# new source node consumed
progress_bar.update(1)
if i == len(model):
# reached end of the graph, append result to output list
out.append(xt)
else:
# fetch the block to process
b = model[i]
if isinstance(b, MemBlock):
# mem blocks are simple since we're visiting the graph in causal order
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
# pool blocks are miserable
if mem[i] is None:
mem[i] = [] # pool memory is itself a queue of inputs to pool
mem[i].append(xt)
if len(mem[i]) > b.stride:
                        # pool mem is in invalid state, we should have pooled before this
                        raise ValueError(f"TPool memory holds {len(mem[i])} inputs, exceeding stride {b.stride}")
elif len(mem[i]) < b.stride:
# pool mem is not yet full, go back to processing the work queue
pass
else:
# pool mem is ready, run the pool block
N, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
# reset the pool mem
mem[i] = []
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
# each tgrow has multiple successor nodes
for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_next, i + 1))
else:
# normal block with no funny business
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
        Args:
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(
scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(
scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(
scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
if checkpoint_path is not None:
self.load_state_dict(self.patch_tgrow_layers(torch.load(
checkpoint_path, map_location="cpu", weights_only=True)))
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed.
Args:
sd: state dict to patch
"""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if sd[key].shape[0] > new_sd[key].shape[0]:
# take the last-timestep output channels
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def encode_video(self, x, parallel=True, show_progress_bar=True):
"""Encode a sequence of frames.
Args:
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=False):
"""Decode a sequence of frames.
Args:
x: input NTCHW latent (C=12) tensor with ~Gaussian values.
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
# return x[:, self.frames_to_trim:]
return x
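def forward(self, x):
# NOTE: `self.c` does not appear to be defined in __init__, so calling forward()
# would raise AttributeError; use encode_video() / decode_video() instead.
return self.c(x)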
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import os
import sys
import cv2 # no highly esteemed deed is commemorated here
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, 'writer'):
self.writer.release()
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
print(
f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None]
# convert to device tensor and rescale from [0, 255] to [0, 1]
vid_dev = video.to(dev, dtype).div_(255.0)
if video.numel() < 100_000_000:
print(f" {video_path} seems small enough, will process all frames in parallel")
vid_enc = taehv.encode_video(vid_dev)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc)
print(f" Decoded {video_path} -> {vid_dec.shape}")
else:
print(f" {video_path} seems large, will process each frame sequentially")
vid_enc = taehv.encode_video(vid_dev, parallel=False)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
video_out = VideoTensorWriter(
video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
print(f" Saved to {video_out_path}")
if __name__ == "__main__":
main()
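The expression `frames_to_trim = 2**sum(decoder_time_upscale) - 1` in `TAEHV.__init__` can be checked with a little arithmetic. The sketch below is plain Python and is not part of taehv.py: the helper names `temporal_upscale_factor` and `decoded_frame_count` are hypothetical, the `(True, True)` default is an assumption, and it assumes each enabled `TGrow` stage doubles the frame count. Note that `decode_video` above returns the untrimmed tensor, because the trimming line is commented out.

```python
def temporal_upscale_factor(decoder_time_upscale=(True, True)):
    """Total temporal expansion, assuming each enabled stage doubles the frames."""
    # sum() over booleans counts the enabled stages.
    return 2 ** sum(decoder_time_upscale)


def frames_to_trim(decoder_time_upscale=(True, True)):
    """Same expression as the frames_to_trim attribute set in TAEHV.__init__."""
    return temporal_upscale_factor(decoder_time_upscale) - 1


def decoded_frame_count(n_latent_frames, decoder_time_upscale=(True, True), trim=False):
    """Number of decoded frames for a given number of latent frames."""
    total = n_latent_frames * temporal_upscale_factor(decoder_time_upscale)
    if trim:
        # Drop the leading frames that the commented-out slice would remove.
        total -= frames_to_trim(decoder_time_upscale)
    return total


if __name__ == "__main__":
    print(decoded_frame_count(21))             # untrimmed frame count
    print(decoded_frame_count(21, trim=True))  # frame count after trimming
```

With both temporal stages enabled, 21 latent frames expand to 84 decoded frames, or 81 once the 3 leading frames are trimmed.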
================================================
FILE: demo_utils/utils.py
================================================
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel
import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision
from PIL import Image
def min_resize(x, m):
if x.shape[0] < x.shape[1]:
s0 = m
s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
else:
s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
s1 = m
new_max = max(s1, s0)
raw_max = max(x.shape[0], x.shape[1])
if new_max < raw_max:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_LANCZOS4
y = cv2.resize(x, (s1, s0), interpolation=interpolation)
return y
def d_resize(x, y):
H, W, C = y.shape
new_min = min(H, W)
raw_min = min(x.shape[0], x.shape[1])
if new_min < raw_min:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_LANCZOS4
y = cv2.resize(x, (W, H), interpolation=interpolation)
return y
def resize_and_center_crop(image, target_width, target_height):
if target_height == image.shape[0] and target_width == image.shape[1]:
return image
pil_image = Image.fromarray(image)
original_width, original_height = pil_image.size
scale_factor = max(target_width / original_width, target_height / original_height)
resized_width = int(round(original_width * scale_factor))
resized_height = int(round(original_height * scale_factor))
resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
left = (resized_width - target_width) / 2
top = (resized_height - target_height) / 2
right = (resized_width + target_width) / 2
bottom = (resized_height + target_height) / 2
cropped_image = resized_image.crop((left, top, right, bottom))
return np.array(cropped_image)
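The geometry used by `resize_and_center_crop` can be followed without any image library. The sketch below repeats only the arithmetic: scale so the image covers the target, then take a centered box. The function name `cover_resize_geometry` is hypothetical and not part of utils.py.

```python
def cover_resize_geometry(original_width, original_height, target_width, target_height):
    """Return the resized (width, height) and the centered crop box (left, top, right, bottom)."""
    # Use the larger of the two ratios so the resized image covers the target.
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    # Center the crop box; the overflow is split evenly between both sides.
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    return (resized_width, resized_height), (left, top, right, bottom)


if __name__ == "__main__":
    size, box = cover_resize_geometry(1920, 1080, 832, 480)
    print(size, box)
```

For a 1920x1080 input and an 832x480 target, the height ratio is the larger one, so the height matches the target and the extra width is cropped equally from both sides.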
def resize_and_center_crop_pytorch(image, target_width, target_height):
B, C, H, W = image.shape
if H == target_height and W == target_width:
return image
scale_factor = max(target_width / W, target_height / H)
resized_width = int(round(W * scale_factor))
resized_height = int(round(H * scale_factor))
resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
top = (resized_height - target_height) // 2
left = (resized_width - target_width) // 2
cropped = resized[:, :, top:top + target_height, left:left + target_width]
return cropped
def resize_without_crop(image, target_width, target_height):
if target_height == image.shape[0] and target_width == image.shape[1]:
return image
pil_image = Image.fromarray(image)
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
return np.array(resized_image)
def just_crop(image, w, h):
if h == image.shape[0] and w == image.shape[1]:
return image
original_height, original_width = image.shape[:2]
k = min(original_height / h, original_width / w)
new_width = int(round(w * k))
new_height = int(round(h * k))
x_start = (original_width - new_width) // 2
y_start = (original_height - new_height) // 2
cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
return cropped_image
def write_to_json(data, file_path):
temp_file_path = file_path + ".tmp"
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
json.dump(data, temp_file, indent=4)
os.replace(temp_file_path, file_path)
return
def read_from_json(file_path):
with open(file_path, 'rt', encoding='utf-8') as file:
data = json.load(file)
return data
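`write_to_json` writes to a temporary file and then swaps it into place with `os.replace`, so a reader should not observe a half-written file. A minimal round-trip sketch follows; it repeats the two helpers so that it runs on its own, and `roundtrip_demo` and the file name `settings.json` are hypothetical.

```python
import json
import os
import tempfile


def write_to_json(data, file_path):
    # Write to a sibling temp file first, then replace the target in one step.
    temp_file_path = file_path + ".tmp"
    with open(temp_file_path, "wt", encoding="utf-8") as temp_file:
        json.dump(data, temp_file, indent=4)
    os.replace(temp_file_path, file_path)


def read_from_json(file_path):
    with open(file_path, "rt", encoding="utf-8") as file:
        return json.load(file)


def roundtrip_demo():
    """Write a small dict, read it back, and report whether the temp file lingers."""
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "settings.json")
        write_to_json({"fps": 16, "prompt": "a cat"}, path)
        leftover_tmp = os.path.exists(path + ".tmp")
        return read_from_json(path), leftover_tmp


if __name__ == "__main__":
    print(roundtrip_demo())
```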
def get_active_parameters(m):
return {k: v for k, v in m.named_parameters() if v.requires_grad}
def cast_training_params(m, dtype=torch.float32):
result = {}
for n, param in m.named_parameters():
if param.requires_grad:
param.data = param.to(dtype)
result[n] = param
return result
def separate_lora_AB(parameters, B_patterns=None):
parameters_normal = {}
parameters_B = {}
if B_patterns is None:
B_patterns = ['.lora_B.', '__zero__']
for k, v in parameters.items():
if any(B_pattern in k for B_pattern in B_patterns):
parameters_B[k] = v
else:
parameters_normal[k] = v
return parameters_normal, parameters_B
def set_attr_recursive(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], value)
return
def print_tensor_list_size(tensors):
total_size = 0
total_elements = 0
if isinstance(tensors, dict):
tensors = tensors.values()
for tensor in tensors:
total_size += tensor.nelement() * tensor.element_size()
total_elements += tensor.nelement()
total_size_MB = total_size / (1024 ** 2)
total_elements_B = total_elements / 1e9
print(f"Total number of tensors: {len(tensors)}")
print(f"Total size of tensors: {total_size_MB:.2f} MB")
print(f"Total number of parameters: {total_elements_B:.3f} billion")
return
@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
batch_size = a.size(0)
if b is None:
b = torch.zeros_like(a)
if mask_a is None:
mask_a = torch.rand(batch_size) < probability_a
mask_a = mask_a.to(a.device)
mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
result = torch.where(mask_a, a, b)
return result
@torch.no_grad()
def zero_module(module):
for p in module.parameters():
p.detach().zero_()
return module
@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
data = m.weight.data.clone()
assert int(data.shape[1]) >= k
data[:, :k] = data[:, :k] * alpha
m.weight.data = data.contiguous().clone()
return m
def freeze_module(m):
if not hasattr(m, '_forward_inside_frozen_module'):
m._forward_inside_frozen_module = m.forward
m.requires_grad_(False)
m.forward = torch.no_grad()(m.forward)
return m
def get_latest_safetensors(folder_path):
safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
if not safetensors_files:
raise ValueError('No file to resume!')
latest_file = max(safetensors_files, key=os.path.getmtime)
latest_file = os.path.abspath(os.path.realpath(latest_file))
return latest_file
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
tags = tags_str.split(', ')
tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
prompt = ', '.join(tags)
return prompt
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
if round_to_int:
numbers = np.round(numbers).astype(int)
return numbers.tolist()
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
edges = np.linspace(0, 1, n + 1)
points = np.random.uniform(edges[:-1], edges[1:])
numbers = inclusive + (exclusive - inclusive) * points
if round_to_int:
numbers = np.round(numbers).astype(int)
return numbers.tolist()
def soft_append_bcthw(history, current, overlap=0):
if overlap <= 0:
return torch.cat([history, current], dim=2)
assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
return output.to(history)
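`soft_append_bcthw` crossfades the last `overlap` frames of `history` into the first `overlap` frames of `current` with weights that fall linearly from 1 to 0. The sketch below is a plain-Python analogue on flat lists, where one number stands in for one frame. The names `linspace` and `soft_append_1d` are hypothetical and it does not use torch.

```python
def linspace(start, stop, steps):
    """Evenly spaced values from start to stop inclusive (steps >= 1)."""
    if steps == 1:
        return [float(start)]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def soft_append_1d(history, current, overlap=0):
    """Concatenate two sequences, linearly crossfading `overlap` elements."""
    if overlap <= 0:
        return list(history) + list(current)
    assert len(history) >= overlap, "history shorter than overlap"
    assert len(current) >= overlap, "current shorter than overlap"
    # Weight 1 keeps the history value, weight 0 keeps the current value.
    weights = linspace(1.0, 0.0, overlap)
    blended = [
        w * h + (1.0 - w) * c
        for w, h, c in zip(weights, history[-overlap:], current[:overlap])
    ]
    return list(history[:-overlap]) + blended + list(current[overlap:])


if __name__ == "__main__":
    print(soft_append_1d([0.0] * 4, [10.0] * 4, overlap=3))
```

The result has `len(history) + len(current) - overlap` elements, matching the length along the time axis in the tensor version.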
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
b, c, t, h, w = x.shape
per_row = b
for p in [6, 5, 4, 3, 2]:
if b % p == 0:
per_row = p
break
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
x = x.detach().cpu().to(torch.uint8)
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
return x
def save_bcthw_as_png(x, output_filename):
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
x = x.detach().cpu().to(torch.uint8)
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
torchvision.io.write_png(x, output_filename)
return output_filename
def save_bchw_as_png(x, output_filename):
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
x = x.detach().cpu().to(torch.uint8)
x = einops.rearrange(x, 'b c h w -> c h (b w)')
torchvision.io.write_png(x, output_filename)
return output_filename
def add_tensors_with_padding(tensor1, tensor2):
if tensor1.shape == tensor2.shape:
return tensor1 + tensor2
shape1 = tensor1.shape
shape2 = tensor2.shape
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
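# allocate the padded buffers with each input's own dtype and device
padded_tensor1 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)
padded_tensor2 = torch.zeros(new_shape, dtype=tensor2.dtype, device=tensor2.device)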
padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
result = padded_tensor1 + padded_tensor2
return result
def print_free_mem():
torch.cuda.empty_cache()
free_mem, total_mem = torch.cuda.mem_get_info(0)
free_mem_mb = free_mem / (1024 ** 2)
total_mem_mb = total_mem / (1024 ** 2)
print(f"Free memory: {free_mem_mb:.2f} MB")
print(f"Total memory: {total_mem_mb:.2f} MB")
return
def print_gpu_parameters(device, state_dict, log_count=1):
summary = {"device": device, "keys_count": len(state_dict)}
logged_params = {}
for i, (key, tensor) in enumerate(state_dict.items()):
if i >= log_count:
break
logged_params[key] = tensor.flatten()[:3].tolist()
summary["params"] = logged_params
print(str(summary))
return
def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
from PIL import Image, ImageDraw, ImageFont
txt = Image.new("RGB", (width, height), color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype(font_path, size=size)
# empty or whitespace-only text has no words to lay out
if not text.strip():
return np.array(txt)
# Split text into lines that fit within the image width
lines = []
words = text.split()
current_line = words[0]
for word in words[1:]:
line_with_word = f"{current_line} {word}"
if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
current_line = line_with_word
else:
lines.append(current_line)
current_line = word
lines.append(current_line)
# Draw the text line by line
y = 0
line_height = draw.textbbox((0, 0), "A", font=font)[3]
for line in lines:
if y + line_height > height:
break # stop drawing if the next line will be outside the image
draw.text((0, y), line, fill="black", font=font)
y += line_height
return np.array(txt)
def blue_mark(x):
x = x.copy()
c = x[:, :, 2]
b = cv2.blur(c, (9, 9))
x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
return x
def green_mark(x):
x = x.copy()
x[:, :, 2] = -1
x[:, :, 0] = -1
return x
def frame_mark(x):
x = x.copy()
x[:64] = -1
x[-64:] = -1
x[:, :8] = 1
x[:, -8:] = 1
return x
@torch.inference_mode()
def pytorch2numpy(imgs):
results = []
for x in imgs:
y = x.movedim(0, -1)
y = y * 127.5 + 127.5
y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
results.append(y)
return results
@torch.inference_mode()
def numpy2pytorch(imgs):
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
h = h.movedim(-1, 1)
return h
@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
if zero_out:
return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
else:
return torch.cat([x, x[:count]], dim=0)
def weighted_mse(a, b, weight):
return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
x = (x - x_min) / (x_max - x_min)
x = max(0.0, min(x, 1.0))
x = x ** sigma
return y_min + x * (y_max - y_min)
def expand_to_dims(x, target_dims):
return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
if tensor is None:
return None
first_dim = tensor.shape[0]
if first_dim == batch_size:
return tensor
if batch_size % first_dim != 0:
raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
repeat_times = batch_size // first_dim
return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
def dim5(x):
return expand_to_dims(x, 5)
def dim4(x):
return expand_to_dims(x, 4)
def dim3(x):
return expand_to_dims(x, 3)
def crop_or_pad_yield_mask(x, length):
B, F, C = x.shape
device = x.device
dtype = x.dtype
if F < length:
y = torch.zeros((B, length, C), dtype=dtype, device=device)
mask = torch.zeros((B, length), dtype=torch.bool, device=device)
y[:, :F, :] = x
mask[:, :F] = True
return y, mask
return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
def extend_dim(x, dim, minimal_length, zero_pad=False):
original_length = int(x.shape[dim])
if original_length >= minimal_length:
return x
if zero_pad:
padding_shape = list(x.shape)
padding_shape[dim] = minimal_length - original_length
padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
else:
idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
last_element = x[idx]
padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
return torch.cat([x, padding], dim=dim)
def lazy_positional_encoding(t, repeats=None):
if not isinstance(t, list):
t = [t]
from diffusers.models.embeddings import get_timestep_embedding
te = torch.tensor(t)
te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
if repeats is None:
return te
te = te[:, None, :].expand(-1, repeats, -1)
return te
def state_dict_offset_merge(A, B, C=None):
result = {}
keys = A.keys()
for key in keys:
A_value = A[key]
B_value = B[key].to(A_value)
if C is None:
result[key] = A_value + B_value
else:
C_value = C[key].to(A_value)
result[key] = A_value + B_value - C_value
return result
def state_dict_weighted_merge(state_dicts, weights):
if len(state_dicts) != len(weights):
raise ValueError("Number of state dictionaries must match number of weights")
if not state_dicts:
return {}
total_weight = sum(weights)
if total_weight == 0:
raise ValueError("Sum of weights cannot be zero")
normalized_weights = [w / total_weight for w in weights]
keys = state_dicts[0].keys()
result = {}
for key in keys:
result[key] = state_dicts[0][key] * normalized_weights[0]
for i in range(1, len(state_dicts)):
state_dict_value = state_dicts[i][key].to(result[key])
result[key] += state_dict_value * normalized_weights[i]
return result
def group_files_by_folder(all_files):
grouped_files = {}
for file in all_files:
folder_name = os.path.basename(os.path.dirname(file))
if folder_name not in grouped_files:
grouped_files[folder_name] = []
grouped_files[folder_name].append(file)
list_of_lists = list(grouped_files.values())
return list_of_lists
def generate_timestamp():
now = datetime.datetime.now()
timestamp = now.strftime('%y%m%d_%H%M%S')
milliseconds = f"{int(now.microsecond / 1000):03d}"
random_number = random.randint(0, 9999)
return f"{timestamp}_{milliseconds}_{random_number}"
def write_PIL_image_with_png_info(image, metadata, path):
from PIL.PngImagePlugin import PngInfo
png_info = PngInfo()
for key, value in metadata.items():
png_info.add_text(key, value)
image.save(path, "PNG", pnginfo=png_info)
return image
def torch_safe_save(content, path):
torch.save(content, path + '_tmp')
os.replace(path + '_tmp', path)
return path
def move_optimizer_to_device(optimizer, device):
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
================================================
FILE: demo_utils/vae.py
================================================
from typing import List
from einops import rearrange
import tensorrt as trt
import torch
import torch.nn as nn
from demo_utils.constant import ALL_INPUTS_NAMES, ZERO_VAE_CACHE
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, Upsample
CACHE_T = 2
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache_1, feat_cache_2):
h = self.shortcut(x)
feat_cache = feat_cache_1
out_feat_cache = []
for layer in self.residual:
if isinstance(layer, CausalConv3d):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache is not None:
# prepend the last cached frame so the cache always holds two frames
cache_x = torch.cat([
feat_cache[:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache)
out_feat_cache.append(cache_x)
feat_cache = feat_cache_2
else:
x = layer(x)
return x + h, *out_feat_cache
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, is_first_frame, feat_cache):
if self.mode == 'upsample3d':
b, c, t, h, w = x.size()
# x, out_feat_cache = torch.cond(
# is_first_frame,
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
# lambda: self.temporal_conv(x, feat_cache),
# )
x, out_feat_cache = self.temporal_conv(x, is_first_frame, feat_cache)
out_feat_cache = torch.cond(
is_first_frame,
lambda: feat_cache.clone().contiguous(),
lambda: out_feat_cache.clone().contiguous(),
)
# if is_first_frame:
# x = torch.cat([torch.zeros_like(x), x], dim=2)
# out_feat_cache = feat_cache.clone()
# else:
# x, out_feat_cache = self.temporal_conv(x, feat_cache)
else:
out_feat_cache = None
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x, out_feat_cache
def temporal_conv(self, x, is_first_frame, feat_cache):
b, c, t, h, w = x.size()
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache is not None:
cache_x = torch.cat([
torch.zeros_like(cache_x),
cache_x
], dim=2)
x = torch.cond(
is_first_frame,
lambda: torch.cat([torch.zeros_like(x), x], dim=1).contiguous(),
lambda: self.time_conv(x, feat_cache).contiguous(),
)
# x = self.time_conv(x, feat_cache)
out_feat_cache = cache_x
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
return x.contiguous(), out_feat_cache.contiguous()
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class VAEDecoderWrapperSingle(nn.Module):
def __init__(self):
super().__init__()
self.decoder = VAEDecoder3d()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=torch.float32)
self.std = torch.tensor(std, dtype=torch.float32)
self.z_dim = 16
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
def forward(
self,
z: torch.Tensor,
is_first_frame: torch.Tensor,
*feat_cache: List[torch.Tensor]
):
# from [batch_size, num_frames, num_channels, height, width]
# to [batch_size, num_channels, num_frames, height, width]
z = z.permute(0, 2, 1, 3, 4)
assert z.shape[2] == 1
feat_cache = list(feat_cache)
is_first_frame = is_first_frame.bool()
device, dtype = z.device, z.dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
x = self.conv2(z)
out, feat_cache = self.decoder(x, is_first_frame, feat_cache=feat_cache)
out = out.clamp_(-1, 1)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
out = out.permute(0, 2, 1, 3, 4)
return out, feat_cache
class VAEDecoder3d(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.cache_t = 2
self.decoder_conv_num = 32
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(
self,
x: torch.Tensor,
is_first_frame: torch.Tensor,
feat_cache: List[torch.Tensor]
):
idx = 0
out_feat_cache = []
# conv1
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame so the cache always holds two frames
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
out_feat_cache.append(cache_x)
idx += 1
# middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
idx += 2
out_feat_cache.append(out_feat_cache_1)
out_feat_cache.append(out_feat_cache_2)
else:
x = layer(x)
# upsamples
for layer in self.upsamples:
if isinstance(layer, Resample):
x, cache_x = layer(x, is_first_frame, feat_cache[idx])
if cache_x is not None:
out_feat_cache.append(cache_x)
idx += 1
else:
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
idx += 2
out_feat_cache.append(out_feat_cache_1)
out_feat_cache.append(out_feat_cache_2)
# head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame so the cache always holds two frames
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
out_feat_cache.append(cache_x)
idx += 1
else:
x = layer(x)
return x, out_feat_cache
class VAETRTWrapper():
def __init__(self):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
self.engine: trt.ICudaEngine = rt.deserialize_cuda_engine(f.read())
self.context: trt.IExecutionContext = self.engine.create_execution_context()
self.stream = torch.cuda.current_stream().cuda_stream
# Feed the engine with tensors (name-based API in TRT ≥10)
self.dtype_map = {
trt.float32: torch.float32,
trt.float16: torch.float16,
trt.int8: torch.int8,
trt.int32: torch.int32,
}
test_input = torch.zeros(1, 16, 1, 60, 104).cuda().half()
is_first_frame = torch.tensor(1.0).cuda().half()
test_cache_inputs = [c.cuda().half() for c in ZERO_VAE_CACHE]
test_inputs = [test_input, is_first_frame] + test_cache_inputs
# keep references so buffers stay alive
self.device_buffers, self.outputs = {}, []
# ---- inputs ----
for i, name in enumerate(ALL_INPUTS_NAMES):
tensor, scale = test_inputs[i], 1 / 127
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
# dynamic shapes
if -1 in self.engine.get_tensor_shape(name):
# new API
self.context.set_input_shape(name, tuple(tensor.shape))
# replaces bindings[]
self.context.set_tensor_address(name, int(tensor.data_ptr()))
self.device_buffers[name] = tensor # keep pointer alive
# ---- (after all input shapes are known) infer output shapes ----
# propagates shapes
self.context.infer_shapes()
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
# replaces binding_is_input
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = tuple(self.context.get_tensor_shape(name))
dtype = self.dtype_map[self.engine.get_tensor_dtype(name)]
out = torch.empty(shape, dtype=dtype, device="cuda").contiguous()
self.context.set_tensor_address(name, int(out.data_ptr()))
self.outputs.append(out)
self.device_buffers[name] = out
# helper to quant-convert on the fly
def quantize_if_needed(self, t, expected_dtype, scale):
if expected_dtype == trt.int8 and t.dtype != torch.int8:
t = torch.clamp((t / scale).round(), -128, 127).to(torch.int8).contiguous()
return t # keep pointer alive
def forward(self, *test_inputs):
for i, name in enumerate(ALL_INPUTS_NAMES):
tensor, scale = test_inputs[i], 1 / 127
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
self.context.set_tensor_address(name, int(tensor.data_ptr()))
self.device_buffers[name] = tensor
self.context.execute_async_v3(stream_handle=self.stream)
torch.cuda.current_stream().synchronize()
return self.outputs
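`VAETRTWrapper.quantize_if_needed` converts a floating-point input to int8 by dividing by a fixed scale of `1 / 127`, rounding, and clamping to `[-128, 127]`. The scalar sketch below illustrates that mapping and its inverse in plain Python. The names `quantize_int8` and `dequantize_int8` are hypothetical, and the dequantization step is an assumption added for illustration; it is not shown in the wrapper.

```python
def quantize_int8(value, scale=1 / 127):
    """Map a float to an int8 code: divide by scale, round, clamp to [-128, 127]."""
    q = round(value / scale)
    return max(-128, min(127, q))


def dequantize_int8(q, scale=1 / 127):
    """Approximate inverse of quantize_int8 for values inside the representable range."""
    return q * scale


if __name__ == "__main__":
    for v in (-2.0, -1.0, 0.0, 0.3, 1.0, 2.0):
        q = quantize_int8(v)
        print(v, q, dequantize_int8(q))
```

Values outside roughly `[-1, 1]` saturate at the int8 limits, so inputs are expected to be on that scale before they are passed in.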
================================================
FILE: demo_utils/vae_block3.py
================================================
from typing import List
from einops import rearrange
import torch
import torch.nn as nn
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, ResidualBlock, Upsample
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
self.cache_t = 2
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# prepend the last cached frame from the previous chunk so the cache holds two frames
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class VAEDecoderWrapper(nn.Module):
def __init__(self):
super().__init__()
self.decoder = VAEDecoder3d()
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=torch.float32)
self.std = torch.tensor(std, dtype=torch.float32)
self.z_dim = 16
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
def forward(
self,
z: torch.Tensor,
*feat_cache: List[torch.Tensor]
):
# from [batch_size, num_frames, num_channels, height, width]
# to [batch_size, num_channels, num_frames, height, width]
z = z.permute(0, 2, 1, 3, 4)
feat_cache = list(feat_cache)
print("Length of feat_cache: ", len(feat_cache))
device, dtype = z.device, z.dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
if i == 0:
out, feat_cache = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_cache)
else:
out_, feat_cache = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=feat_cache)
out = torch.cat([out, out_], 2)
out = out.float().clamp_(-1, 1)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
out = out.permute(0, 2, 1, 3, 4)
return out, feat_cache
class VAEDecoder3d(nn.Module):
def __init__(self,
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.cache_t = 2
self.decoder_conv_num = 32
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(
self,
x: torch.Tensor,
feat_cache: List[torch.Tensor]
):
feat_idx = [0]
# conv1
idx = feat_idx[0]
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame from the previous chunk so the cache holds two frames
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
# middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
# upsamples
for layer in self.upsamples:
x = layer(x, feat_cache, feat_idx)
# head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -self.cache_t:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# prepend the last cached frame from the previous chunk so the cache holds two frames
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x, feat_cache
================================================
FILE: demo_utils/vae_torch2trt.py
================================================
from demo_utils.vae import (
VAEDecoderWrapperSingle, # main nn.Module
ZERO_VAE_CACHE # helper constants shipped with your code base
)
import pycuda.driver as cuda  # CUDA driver API, used by the INT8 calibrator below
import pycuda.autoinit # noqa
import sys
from pathlib import Path
import torch
import tensorrt as trt
from utils.dataset import ShardingLMDBDataset
data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=0
)
# ─────────────────────────────────────────────────────────
# 1️⃣ Bring the PyTorch model into scope
# (the decoder wrapper is imported from `demo_utils/vae.py`)
# ─────────────────────────────────────────────────────────
# --- dummy tensors used to trace the model for export ---
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
dummy_cache_input = [
torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
for s in ZERO_VAE_CACHE # keep exactly the same ordering
]
inputs = [dummy_input, is_first_frame, *dummy_cache_input]
# ─────────────────────────────────────────────────────────
# 2️⃣ Export → ONNX
# ─────────────────────────────────────────────────────────
model = VAEDecoderWrapperSingle().half().cuda().eval()
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
decoder_state_dict = {}
for key, value in vae_state_dict.items():
if 'decoder.' in key or 'conv2' in key:
decoder_state_dict[key] = value
model.load_state_dict(decoder_state_dict)
model = model.half().cuda().eval() # only batch dim dynamic
onnx_path = Path("vae_decoder.onnx")
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
all_inputs_names = ["z", "use_cache"] + feat_names
with torch.inference_mode():
torch.onnx.export(
model,
tuple(inputs), # must be a tuple
onnx_path.as_posix(),
input_names=all_inputs_names,
output_names=["rgb_out", "cache_out"],
opset_version=17,
do_constant_folding=True,
dynamo=True
)
print(f"✅ ONNX graph saved to {onnx_path.resolve()}")
# (Optional) quick sanity-check with ONNX-Runtime
try:
import onnxruntime as ort
sess = ort.InferenceSession(onnx_path.as_posix(),
providers=["CUDAExecutionProvider"])
ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
_ = sess.run(None, ort_inputs)
print("✅ ONNX graph is executable")
except Exception as e:
print("⚠️ ONNX check failed:", e)
# ─────────────────────────────────────────────────────────
# 3️⃣ Build the TensorRT engine
# ─────────────────────────────────────────────────────────
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(parser.get_error(i))
sys.exit("❌ ONNX → TRT parsing failed")
config = builder.create_builder_config()
def set_workspace(config, bytes_):
"""Version-agnostic workspace limit."""
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
config.max_workspace_size = bytes_
else: # TRT 10+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, bytes_)
# …
config = builder.create_builder_config()
set_workspace(config, 4 << 30) # 4 GB
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
# ---- INT8 (optional) ----
# provide a calibrator if you need an INT8 engine; comment this
# block if you only care about FP16.
# ─────────────────────────────────────────────────────────
# helper: version-agnostic workspace limit
# ─────────────────────────────────────────────────────────
def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
"""
TRT < 10.x → config.max_workspace_size
TRT ≥ 10.x → config.set_memory_pool_limit(...)
"""
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
config.max_workspace_size = bytes_
else: # TRT 10+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
bytes_)
# ─────────────────────────────────────────────────────────
# (optional) INT-8 calibrator
# ─────────────────────────────────────────────────────────
# Only keep this block if you really need INT-8 (it requires PyCUDA).
class VAECalibrator(trt.IInt8EntropyCalibrator2):
def __init__(self, loader, cache="calibration.cache", max_batches=10):
super().__init__()
self.loader = iter(loader)
self.batch_size = loader.batch_size or 1
self.max_batches = max_batches
self.count = 0
self.cache_file = cache
self.stream = cuda.Stream()
self.dev_ptrs = {}
# --- TRT 10 needs BOTH spellings ---
def get_batch_size(self):
return self.batch_size
def getBatchSize(self):
return self.batch_size
def get_batch(self, names):
if self.count >= self.max_batches:
return None
# Randomly pick how many warm-up decode steps to run (0 to 10, inclusive)
import random
vae_idx = random.randint(0, 10)
data = next(self.loader)
latent = data['ode_latent'][0][:, :1]
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
feat_cache = ZERO_VAE_CACHE
for i in range(vae_idx):
inputs = [latent, is_first_frame, *feat_cache]
with torch.inference_mode():
outputs = model(*inputs)
latent = data['ode_latent'][0][:, i + 1:i + 2]
is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
feat_cache = outputs[1:]
# -------- ensure context is current --------
z_np = latent.cpu().numpy().astype('float32')
ptrs = [] # list[int] – one entry per name
for name in names: # <-- match TRT's binding order
if name == "z":
arr = z_np
elif name == "use_cache":
arr = is_first_frame.cpu().numpy().astype('float32')
else:
idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17
arr = feat_cache[idx].cpu().numpy().astype('float32')
if name not in self.dev_ptrs:
self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)
cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
ptrs.append(int(self.dev_ptrs[name])) # ***int() is required***
self.stream.synchronize()
self.count += 1
print(f"Calibration batch {self.count}/{self.max_batches}")
return ptrs
# --- calibration-cache helpers (both spellings) ---
def read_calibration_cache(self):
try:
with open(self.cache_file, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def readCalibrationCache(self):
return self.read_calibration_cache()
def write_calibration_cache(self, cache):
with open(self.cache_file, "wb") as f:
f.write(cache)
def writeCalibrationCache(self, cache):
self.write_calibration_cache(cache)
# ─────────────────────────────────────────────────────────
# Builder-config + optimisation profile
# ─────────────────────────────────────────────────────────
config = builder.create_builder_config()
set_workspace(config, 4 << 30) # 4 GB
# ► enable FP16 if possible
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
# ► enable INT-8 (delete this block if you don’t need it)
if cuda is not None:
config.set_flag(trt.BuilderFlag.INT8)
# calibrate on representative latents drawn from the LMDB dataloader defined above
calib = VAECalibrator(dataloader)
# TRT-10 renamed the setter:
if hasattr(config, "set_int8_calibrator"): # TRT 10+
config.set_int8_calibrator(calib)
else: # TRT ≤ 9
config.int8_calibrator = calib
# ---- optimisation profile ----
profile = builder.create_optimization_profile()
profile.set_shape(all_inputs_names[0], # latent z
min=(1, 1, 16, 60, 104),
opt=(1, 1, 16, 60, 104),
max=(1, 1, 16, 60, 104))
profile.set_shape("use_cache", # scalar flag
min=(1,), opt=(1,), max=(1,))
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)
config.add_optimization_profile(profile)
# ─────────────────────────────────────────────────────────
# Build the engine (API changed in TRT-10)
# ─────────────────────────────────────────────────────────
print("⚙️ Building engine … (can take a minute)")
if hasattr(builder, "build_serialized_network"): # TRT 10+
serialized_engine = builder.build_serialized_network(network, config)
assert serialized_engine is not None, "build_serialized_network() failed"
plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.write_bytes(serialized_engine)
engine_bytes = serialized_engine # keep for smoke-test
else: # TRT ≤ 9
engine = builder.build_engine(network, config)
assert engine is not None, "build_engine() returned None"
plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.write_bytes(engine.serialize())
engine_bytes = engine.serialize()
print(f"✅ TensorRT engine written to {plan_path.resolve()}")
# ─────────────────────────────────────────────────────────
# 4️⃣ Quick smoke-test with the brand-new engine
# ─────────────────────────────────────────────────────────
with trt.Runtime(TRT_LOGGER) as rt:
engine = rt.deserialize_cuda_engine(engine_bytes)
context = engine.create_execution_context()
stream = torch.cuda.current_stream().cuda_stream
# pre-allocate device buffers once
device_buffers, outputs = {}, []
dtype_map = {trt.float32: torch.float32,
trt.float16: torch.float16,
trt.int8: torch.int8,
trt.int32: torch.int32}
for name, tensor in zip(all_inputs_names, inputs):
if -1 in engine.get_tensor_shape(name): # dynamic input
context.set_input_shape(name, tensor.shape)
context.set_tensor_address(name, int(tensor.data_ptr()))
device_buffers[name] = tensor
context.infer_shapes() # propagate ⇢ outputs
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = tuple(context.get_tensor_shape(name))
dtype = dtype_map[engine.get_tensor_dtype(name)]
out = torch.empty(shape, dtype=dtype, device="cuda")
context.set_tensor_address(name, int(out.data_ptr()))
outputs.append(out)
print(f"output {name} shape: {shape}")
context.execute_async_v3(stream_handle=stream)
torch.cuda.current_stream().synchronize()
print("✅ TRT execution OK – first output shape:", outputs[0].shape)
================================================
FILE: inference.py
================================================
import argparse
import torch
import os
from omegaconf import OmegaConf
from tqdm import tqdm
from torchvision import transforms
from torchvision.io import write_video
from einops import rearrange
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pipeline import (
CausalDiffusionInferencePipeline,
CausalInferencePipeline
)
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to the config file")
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21,
help="Number of latent frames to generate")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
parser.add_argument("--save_with_index", action="store_true",
help="Whether to save the video using the index or prompt as the filename")
args = parser.parse_args()
# Initialize distributed inference
if "LOCAL_RANK" in os.environ:
dist.init_process_group(backend='nccl')
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
world_size = dist.get_world_size()
set_seed(args.seed + local_rank)
else:
device = torch.device("cuda")
local_rank = 0
world_size = 1
set_seed(args.seed)
torch.set_grad_enabled(False)
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
# Initialize pipeline
if hasattr(config, 'denoising_step_list'):
# Few-step inference
pipeline = CausalInferencePipeline(config, device=device)
else:
# Multi-step diffusion inference
pipeline = CausalDiffusionInferencePipeline(config, device=device)
if args.checkpoint_path:
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
pipeline.generator.load_state_dict(state_dict['generator' if not args.use_ema else 'generator_ema'])
pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
# Create dataset
if args.i2v:
assert not dist.is_initialized(), "I2V does not support distributed inference yet"
transform = transforms.Compose([
transforms.Resize((480, 832)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")
if dist.is_initialized():
sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
# Create output directory (only on main process to avoid race conditions)
if local_rank == 0:
os.makedirs(args.output_folder, exist_ok=True)
if dist.is_initialized():
dist.barrier()
def encode(self, videos: torch.Tensor) -> torch.Tensor:
device, dtype = videos[0].device, videos[0].dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
output = [
self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
for u in videos
]
output = torch.stack(output, dim=0)
return output
for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
idx = batch_data['idx'].item()
# For DataLoader batch_size=1, the batch_data is already a single item, but in a batch container
# Unpack the batch data for convenience
if isinstance(batch_data, dict):
batch = batch_data
elif isinstance(batch_data, list):
batch = batch_data[0] # First (and only) item in the batch
all_video = []
num_generated_frames = 0 # Number of generated (latent) frames
if args.i2v:
# For image-to-video, batch contains image and caption
prompt = batch['prompts'][0] # Get caption from batch
prompts = [prompt] * args.num_samples
# Process the image
image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
# Encode the input image as the first latent
initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)
sampled_noise = torch.randn(
[args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
)
else:
# For text-to-video, batch is just the text prompt
prompt = batch['prompts'][0]
extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
if extended_prompt is not None:
prompts = [extended_prompt] * args.num_samples
else:
prompts = [prompt] * args.num_samples
initial_latent = None
sampled_noise = torch.randn(
[args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
)
# Generate 81 frames
video, latents = pipeline.inference(
noise=sampled_noise,
text_prompts=prompts,
return_latents=True,
initial_latent=initial_latent,
)
current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
all_video.append(current_video)
num_generated_frames += latents.shape[1]
# Final output video
video = 255.0 * torch.cat(all_video, dim=1)
# Clear VAE cache
pipeline.vae.model.clear_cache()
# Save the video if the current prompt is not a dummy prompt
if idx < num_prompts:
model = "regular" if not args.use_ema else "ema"
for seed_idx in range(args.num_samples):
# All processes save their videos
if args.save_with_index:
output_path = os.path.join(args.output_folder, f'{idx}-{seed_idx}_{model}.mp4')
else:
output_path = os.path.join(args.output_folder, f'{prompt[:100]}-{seed_idx}.mp4')
write_video(output_path, video[seed_idx], fps=16)
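# A minimal sketch relating --num_output_frames (latent frames) to the
# "Generate 81 frames" comment above. It assumes the VAE keeps the first frame as is
# and expands every later latent frame into 4 pixel frames; the stride of 4 and the
# helper name latent_to_pixel_frames are assumptions, not taken from this file.

```python
def latent_to_pixel_frames(num_latent_frames, temporal_stride=4):
    """Pixel-frame count for a causal VAE whose first latent frame maps to one pixel frame."""
    if num_latent_frames < 1:
        raise ValueError("num_latent_frames must be at least 1")
    return 1 + temporal_stride * (num_latent_frames - 1)


if __name__ == "__main__":
    # Default --num_output_frames is 21 latent frames -> 81 pixel frames
    print(latent_to_pixel_frames(21))
```

# At the fps=16 passed to write_video, 81 frames is roughly five seconds of video.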
================================================
FILE: model/__init__.py
================================================
from .diffusion import CausalDiffusion
from .causvid import CausVid
from .dmd import DMD
from .gan import GAN
from .sid import SiD
from .ode_regression import ODERegression
__all__ = [
"CausalDiffusion",
"CausVid",
"DMD",
"GAN",
"SiD",
"ODERegression"
]
================================================
FILE: model/base.py
================================================
from typing import Tuple
from einops import rearrange
from torch import nn
import torch.distributed as dist
import torch
from pipeline import SelfForcingTrainingPipeline, BidirectionalTrainingPipeline
from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper, WanCLIPEncoder
class BaseModel(nn.Module):
def __init__(self, args, device):
super().__init__()
self.is_causal = args.generator_type == "causal"
self.i2v = args.i2v
self._initialize_models(args, device)
self.device = device
self.args = args
self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32
if hasattr(args, "denoising_step_list"):
self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
if args.warp_denoising_step:
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
def _initialize_models(self, args, device):
self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-14B")
self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-14B")
self.generator_name = getattr(args, "generator_name", "Wan2.1-T2V-14B")
self.generator = WanDiffusionWrapper(
**getattr(args, "model_kwargs", {}),
model_name=self.generator_name,
is_causal=self.is_causal
)
self.generator.model.requires_grad_(True)
self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
self.real_score.model.requires_grad_(False)
self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
self.fake_score.model.requires_grad_(True)
self.text_encoder = WanTextEncoder(model_name=self.generator_name)
self.text_encoder.requires_grad_(False)
self.vae = WanVAEWrapper(model_name=self.generator_name)
self.vae.requires_grad_(False)
if self.i2v:
self.image_encoder = WanCLIPEncoder(model_name=self.generator_name)
self.image_encoder.requires_grad_(False)
self.scheduler = self.generator.get_scheduler()
self.scheduler.timesteps = self.scheduler.timesteps.to(device)
def _get_timestep(
self,
min_timestep: int,
max_timestep: int,
batch_size: int,
num_frame: int,
num_frame_per_block: int,
uniform_timestep: bool = False
) -> torch.Tensor:
"""
Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
from the range [min_timestep, max_timestep) (the upper bound is exclusive, as in torch.randint), and returns a tensor of shape [batch_size, num_frame].
- If uniform_timestep, it will use the same timestep for all frames.
- If not uniform_timestep, it will use a different timestep for each block.
"""
if uniform_timestep:
timestep = torch.randint(
min_timestep,
max_timestep,
[batch_size, 1],
device=self.device,
dtype=torch.long
).repeat(1, num_frame)
return timestep
else:
timestep = torch.randint(
min_timestep,
max_timestep,
[batch_size, num_frame],
device=self.device,
dtype=torch.long
)
# make the noise level the same within every block
if self.independent_first_frame:
# the first frame is always kept the same
timestep_from_second = timestep[:, 1:]
timestep_from_second = timestep_from_second.reshape(
timestep_from_second.shape[0], -1, num_frame_per_block)
timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
timestep_from_second = timestep_from_second.reshape(
timestep_from_second.shape[0], -1)
timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
else:
timestep = timestep.reshape(
timestep.shape[0], -1, num_frame_per_block)
timestep[:, :, 1:] = timestep[:, :, 0:1]
timestep = timestep.reshape(timestep.shape[0], -1)
return timestep
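# A minimal plain-Python sketch of the block-wise sharing done in _get_timestep:
# one timestep is drawn per frame, then every frame in a block is overwritten with
# the timestep of the block's first frame. With independent_first_frame, frame 0
# keeps its own draw and blocks start at frame 1. The function name
# blockwise_timesteps and the rng parameter are hypothetical; the real method works
# on [batch_size, num_frame] tensors.

```python
import random


def blockwise_timesteps(num_frame, num_frame_per_block, min_timestep, max_timestep,
                        independent_first_frame=False, rng=random):
    """Return one timestep per frame, constant within each block of frames."""
    # One draw per frame from [min_timestep, max_timestep), like torch.randint
    timesteps = [rng.randrange(min_timestep, max_timestep) for _ in range(num_frame)]
    start = 1 if independent_first_frame else 0
    assert (num_frame - start) % num_frame_per_block == 0, "frames must fill whole blocks"
    for block_start in range(start, num_frame, num_frame_per_block):
        for j in range(block_start + 1, block_start + num_frame_per_block):
            timesteps[j] = timesteps[block_start]  # copy the block's first timestep
    return timesteps


if __name__ == "__main__":
    print(blockwise_timesteps(21, 3, 0, 1000))
    print(blockwise_timesteps(7, 3, 0, 1000, independent_first_frame=True))
```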
class SelfForcingModel(BaseModel):
def __init__(self, args, device):
super().__init__(args, device)
self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
def _run_generator(
self,
image_or_video_shape,
conditional_dict: dict,
initial_latent: torch.Tensor = None,
clip_fea: torch.Tensor = None,
y: torch.Tensor = None
) -> Tuple:
"""
Simulate the generator's input from noise using backward simulation
and then run the generator for one step.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- initial_latent: a tensor containing the initial latents [B, F, C, H, W].
- clip_fea, y: additional conditioning inputs, forwarded unchanged to the backward simulation.
Output:
- pred_image: a tensor with shape [B, F, C, H, W], reduced to the last 21 latent frames when more were generated.
- gradient_mask: a boolean mask with the same shape as pred_image, or None when the minimum number of frames was generated.
- denoised_timestep_from, denoised_timestep_to: the timestep values returned by the backward simulation.
"""
# Step 1: Sample noise and backward simulate the generator's input
assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
if initial_latent is not None:
conditional_dict["initial_latent"] = initial_latent
if self.args.i2v:
noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
else:
noise_shape = image_or_video_shape.copy()
# During training, the number of generated frames should be uniformly sampled from
# [21, self.num_training_frames], but still being a multiple of self.num_frame_per_block
min_num_frames = 20 if self.args.independent_first_frame else 21
max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
assert max_num_frames % self.num_frame_per_block == 0
assert min_num_frames % self.num_frame_per_block == 0
max_num_blocks = max_num_frames // self.num_frame_per_block
min_num_blocks = min_num_frames // self.num_frame_per_block
num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
dist.broadcast(num_generated_blocks, src=0)
num_generated_blocks = num_generated_blocks.item()
num_generated_frames = num_generated_blocks * self.num_frame_per_block
if self.args.independent_first_frame and initial_latent is None:
num_generated_frames += 1
min_num_frames += 1
# Sync num_generated_frames across all processes
noise_shape[1] = num_generated_frames
pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
noise=torch.randn(noise_shape,
device=self.device, dtype=self.dtype),
clip_fea=clip_fea,
y=y,
**conditional_dict
)
# Slice last 21 frames
if pred_image_or_video.shape[1] > 21:
with torch.no_grad():
# Re-encode to get the image latent
latent_to_decode = pred_image_or_video[:, :-20, ...]
# Decode to video
pixels = self.vae.decode_to_pixel(latent_to_decode)
frame = pixels[:, -1:, ...].to(self.dtype)
frame = rearrange(frame, "b t c h w -> b c t h w")
# Encode frame to get image latent
image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
else:
pred_image_or_video_last_21 = pred_image_or_video
if num_generated_frames != min_num_frames:
# Currently, we do not use gradient for the first chunk, since it contains image latents
gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
if self.args.independent_first_frame:
gradient_mask[:, :1] = False
else:
gradient_mask[:, :self.num_frame_per_block] = False
else:
gradient_mask = None
pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to
def _consistency_backward_simulation(
self,
noise: torch.Tensor,
clip_fea: torch.Tensor,
y: torch.Tensor,
**conditional_dict: dict
) -> torch.Tensor:
"""
Simulate the generator's input from noise to avoid training/inference mismatch.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
Here we use the consistency sampler (https://arxiv.org/abs/2303.01469)
Input:
- noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W], where the number of frames is 1 for images.
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
Output:
- output: a tensor with shape [B, T, F, C, H, W].
T is the total number of timesteps. output[0] is pure noise, and output[i] for i > 0
is the x0 prediction at each timestep.
"""
if self.inference_pipeline is None:
self._initialize_inference_pipeline()
return self.inference_pipeline.inference_with_trajectory(
noise=noise, clip_fea=clip_fea, y=y, **conditional_dict
)
def _initialize_inference_pipeline(self):
"""
Lazily initialize the inference pipeline during the first backward simulation run.
Here we encapsulate the inference code with a model-dependent outside function.
We pass our FSDP-wrapped modules into the pipeline to save memory.
"""
if self.is_causal:
self.inference_pipeline = SelfForcingTrainingPipeline(
model_name=self.generator_name,
denoising_step_list=self.denoising_step_list,
scheduler=self.scheduler,
generator=self.generator,
num_frame_per_block=self.num_frame_per_block,
independent_first_frame=self.args.independent_first_frame,
same_step_across_blocks=self.args.same_step_across_blocks,
last_step_only=self.args.last_step_only,
num_max_frames=self.num_training_frames,
context_noise=self.args.context_noise
)
else:
self.inference_pipeline = BidirectionalTrainingPipeline(
model_name=self.generator_name,
denoising_step_list=self.denoising_step_list,
scheduler=self.scheduler,
generator=self.generator,
)
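# Illustrative sketch (not part of the repository): how a boolean gradient mask
# like the one built in the generator rollout above restricts an MSE-style loss
# to the frames after the first block. The tensor sizes and the helper name
# `masked_mse` are assumptions chosen for this example only.

```python
import torch
import torch.nn.functional as F


def masked_mse(pred, target, num_masked_frames):
    """Mean squared error over all frames except the first `num_masked_frames`."""
    # pred/target: [B, F, C, H, W]; the mask has the same shape as pred.
    mask = torch.ones_like(pred, dtype=torch.bool)
    mask[:, :num_masked_frames] = False
    # Boolean indexing flattens the selected elements into a 1-D tensor.
    return F.mse_loss(pred[mask], target[mask], reduction="mean")


# Hypothetical sizes: 2 videos, 6 latent frames, first block of 3 frames masked out.
pred = torch.zeros(2, 6, 4, 8, 8, requires_grad=True)
target = torch.ones(2, 6, 4, 8, 8)
loss = masked_mse(pred, target, num_masked_frames=3)
loss.backward()
```

# After `backward()`, `pred.grad` is zero for the masked frames, so only the
# remaining frames receive a training signal.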
================================================
FILE: model/causvid.py
================================================
import torch.nn.functional as F
from typing import Tuple
import torch
from model.base import BaseModel
class CausVid(BaseModel):
def __init__(self, args, device):
"""
Initialize the CausVid module, which trains the generator with the DMD
(Distribution Matching Distillation) objective.
This class is self-contained and computes the generator and fake score losses
in the forward pass.
"""
super().__init__(args, device)
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
self.num_training_frames = getattr(args, "num_training_frames", 21)
if self.num_frame_per_block > 1:
self.generator.model.num_frame_per_block = self.num_frame_per_block
self.independent_first_frame = getattr(args, "independent_first_frame", False)
if self.independent_first_frame:
self.generator.model.independent_first_frame = True
if args.gradient_checkpointing:
self.generator.enable_gradient_checkpointing()
self.fake_score.enable_gradient_checkpointing()
# Step 2: Initialize all dmd hyperparameters
self.num_train_timestep = args.num_train_timestep
self.min_step = int(0.02 * self.num_train_timestep)
self.max_step = int(0.98 * self.num_train_timestep)
if hasattr(args, "real_guidance_scale"):
self.real_guidance_scale = args.real_guidance_scale
self.fake_guidance_scale = args.fake_guidance_scale
else:
self.real_guidance_scale = args.guidance_scale
self.fake_guidance_scale = 0.0
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
self.teacher_forcing = getattr(args, "teacher_forcing", False)
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
else:
self.scheduler.alphas_cumprod = None
def _compute_kl_grad(
self, noisy_image_or_video: torch.Tensor,
estimated_clean_image_or_video: torch.Tensor,
timestep: torch.Tensor,
conditional_dict: dict, unconditional_dict: dict,
normalization: bool = True
) -> Tuple[torch.Tensor, dict]:
"""
Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
Input:
- noisy_image_or_video: a tensor with shape [B, F, C, H, W], where the number of frames is 1 for images.
- estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
- timestep: a tensor with shape [B, F] containing the randomly generated timestep.
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- normalization: a boolean indicating whether to normalize the gradient.
Output:
- kl_grad: a tensor representing the KL grad.
- kl_log_dict: a dictionary containing the intermediate tensors for logging.
"""
# Step 1: Compute the fake score
_, pred_fake_image_cond = self.fake_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=conditional_dict,
timestep=timestep
)
if self.fake_guidance_scale != 0.0:
_, pred_fake_image_uncond = self.fake_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=unconditional_dict,
timestep=timestep
)
pred_fake_image = pred_fake_image_cond + (
pred_fake_image_cond - pred_fake_image_uncond
) * self.fake_guidance_scale
else:
pred_fake_image = pred_fake_image_cond
# Step 2: Compute the real score
# We compute the conditional and unconditional prediction
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
_, pred_real_image_cond = self.real_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=conditional_dict,
timestep=timestep
)
_, pred_real_image_uncond = self.real_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=unconditional_dict,
timestep=timestep
)
pred_real_image = pred_real_image_cond + (
pred_real_image_cond - pred_real_image_uncond
) * self.real_guidance_scale
# Step 3: Compute the DMD gradient (DMD paper eq. 7).
grad = (pred_fake_image - pred_real_image)
# TODO: Change the normalizer for causal teacher
if normalization:
# Step 4: Gradient normalization (DMD paper eq. 8).
p_real = (estimated_clean_image_or_video - pred_real_image)
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
grad = grad / normalizer
grad = torch.nan_to_num(grad)
return grad, {
"dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
"timestep": timestep.detach()
}
def compute_distribution_matching_loss(
self,
image_or_video: torch.Tensor,
conditional_dict: dict,
unconditional_dict: dict,
gradient_mask: torch.Tensor = None,
) -> Tuple[torch.Tensor, dict]:
"""
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
Input:
- image_or_video: a tensor with shape [B, F, C, H, W], where the number of frames is 1 for images.
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which elements contribute to the loss.
Output:
- dmd_loss: a scalar tensor representing the DMD loss.
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
"""
original_latent = image_or_video
batch_size, num_frame = image_or_video.shape[:2]
with torch.no_grad():
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
timestep = self._get_timestep(
0,
self.num_train_timestep,
batch_size,
num_frame,
self.num_frame_per_block,
uniform_timestep=True
)
if self.timestep_shift > 1:
timestep = self.timestep_shift * \
(timestep / 1000) / \
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
timestep = timestep.clamp(self.min_step, self.max_step)
noise = torch.randn_like(image_or_video)
noisy_latent = self.scheduler.add_noise(
image_or_video.flatten(0, 1),
noise.flatten(0, 1),
timestep.flatten(0, 1)
).detach().unflatten(0, (batch_size, num_frame))
# Step 2: Compute the KL grad
grad, dmd_log_dict = self._compute_kl_grad(
noisy_image_or_video=noisy_latent,
estimated_clean_image_or_video=original_latent,
timestep=timestep,
conditional_dict=conditional_dict,
unconditional_dict=unconditional_dict
)
if gradient_mask is not None:
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
)[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
else:
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
), (original_latent.double() - grad.double()).detach(), reduction="mean")
return dmd_loss, dmd_log_dict
def _run_generator(
self,
image_or_video_shape,
conditional_dict: dict,
clean_latent: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Noise the clean latent at every step in the denoising step list, randomly pick
a noise level, and run the generator for one step on the selected input.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W].
Output:
- pred_image_or_video: a tensor with shape [B, F, C, H, W].
- gradient_mask: always None in this implementation.
"""
simulated_noisy_input = []
for timestep in self.denoising_step_list:
noise = torch.randn(
image_or_video_shape, device=self.device, dtype=self.dtype)
noisy_timestep = timestep * torch.ones(
image_or_video_shape[:2], device=self.device, dtype=torch.long)
if timestep != 0:
noisy_image = self.scheduler.add_noise(
clean_latent.flatten(0, 1),
noise.flatten(0, 1),
noisy_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
else:
noisy_image = clean_latent
simulated_noisy_input.append(noisy_image)
simulated_noisy_input = torch.stack(simulated_noisy_input, dim=1)
# Step 2: Randomly sample a timestep and pick the corresponding input
index = self._get_timestep(
0,
len(self.denoising_step_list),
image_or_video_shape[0],
image_or_video_shape[1],
self.num_frame_per_block,
uniform_timestep=False
)
# select the corresponding timestep's noisy input from the stacked tensor [B, T, F, C, H, W]
noisy_input = torch.gather(
simulated_noisy_input, dim=1,
index=index.reshape(index.shape[0], 1, index.shape[1], 1, 1, 1).expand(
-1, -1, -1, *image_or_video_shape[2:]).to(self.device)
).squeeze(1)
timestep = self.denoising_step_list[index].to(self.device)
_, pred_image_or_video = self.generator(
noisy_image_or_video=noisy_input,
conditional_dict=conditional_dict,
timestep=timestep,
clean_x=clean_latent if self.teacher_forcing else None,
)
gradient_mask = None # timestep != 0
pred_image_or_video = pred_image_or_video.type_as(noisy_input)
return pred_image_or_video, gradient_mask
def generator_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
initial_latent: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Generate image/videos from noise and compute the DMD loss.
The noisy input to the generator is backward simulated.
This removes the need of any datasets during distillation.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
Output:
- loss: a scalar tensor representing the generator loss.
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
"""
# Step 1: Run generator on backward simulated noisy input
pred_image, gradient_mask = self._run_generator(
image_or_video_shape=image_or_video_shape,
conditional_dict=conditional_dict,
clean_latent=clean_latent
)
# Step 2: Compute the DMD loss
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
image_or_video=pred_image,
conditional_dict=conditional_dict,
unconditional_dict=unconditional_dict,
gradient_mask=gradient_mask
)
# Step 3: TODO: Implement the GAN loss
return dmd_loss, dmd_log_dict
def critic_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
initial_latent: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Generate image/videos from noise and train the critic with generated samples.
The noisy input to the generator is backward simulated.
This removes the need of any datasets during distillation.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
Output:
- loss: a scalar tensor representing the critic (fake score) denoising loss.
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
"""
# Step 1: Run generator on backward simulated noisy input
with torch.no_grad():
generated_image, _ = self._run_generator(
image_or_video_shape=image_or_video_shape,
conditional_dict=conditional_dict,
clean_latent=clean_latent
)
# Step 2: Compute the fake prediction
critic_timestep = self._get_timestep(
0,
self.num_train_timestep,
image_or_video_shape[0],
image_or_video_shape[1],
self.num_frame_per_block,
uniform_timestep=True
)
if self.timestep_shift > 1:
critic_timestep = self.timestep_shift * \
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
critic_noise = torch.randn_like(generated_image)
noisy_generated_image = self.scheduler.add_noise(
generated_image.flatten(0, 1),
critic_noise.flatten(0, 1),
critic_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
_, pred_fake_image = self.fake_score(
noisy_image_or_video=noisy_generated_image,
conditional_dict=conditional_dict,
timestep=critic_timestep
)
# Step 3: Compute the denoising loss for the fake critic
if self.args.denoising_loss_type == "flow":
from utils.wan_wrapper import WanDiffusionWrapper
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
scheduler=self.scheduler,
x0_pred=pred_fake_image.flatten(0, 1),
xt=noisy_generated_image.flatten(0, 1),
timestep=critic_timestep.flatten(0, 1)
)
pred_fake_noise = None
else:
flow_pred = None
pred_fake_noise = self.scheduler.convert_x0_to_noise(
x0=pred_fake_image.flatten(0, 1),
xt=noisy_generated_image.flatten(0, 1),
timestep=critic_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
denoising_loss = self.denoising_loss_func(
x=generated_image.flatten(0, 1),
x_pred=pred_fake_image.flatten(0, 1),
noise=critic_noise.flatten(0, 1),
noise_pred=pred_fake_noise,
alphas_cumprod=self.scheduler.alphas_cumprod,
timestep=critic_timestep.flatten(0, 1),
flow_pred=flow_pred
)
# Step 4: TODO: Compute the GAN loss
# Step 5: Debugging Log
critic_log_dict = {
"critic_timestep": critic_timestep.detach()
}
return denoising_loss, critic_log_dict
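# Illustrative sketch (not part of the repository): why the surrogate loss
# `0.5 * mse_loss(x, (x - grad).detach())` used in
# compute_distribution_matching_loss backpropagates `grad` (scaled by 1/N)
# into `x`. The random tensors below are stand-ins; in the real code `x` is the
# generator output and `grad` comes from `_compute_kl_grad`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 8, 8, dtype=torch.double, requires_grad=True)
grad = torch.randn_like(x)  # stand-in for the output of _compute_kl_grad

# The target is detached, so it acts as a constant during backpropagation.
target = (x - grad).detach()
loss = 0.5 * F.mse_loss(x, target, reduction="mean")
loss.backward()

# d(loss)/dx = (x - target) / N = grad / N, where N is the number of elements.
expected = grad / x.numel()
max_error = (x.grad - expected).abs().max().item()
```

# `max_error` should be at floating-point noise level, which confirms that the
# MSE form is only a vehicle for injecting the precomputed gradient.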
================================================
FILE: model/diffusion.py
================================================
from typing import Tuple
import torch
from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
class CausalDiffusion(BaseModel):
def __init__(self, args, device):
"""
Initialize the Diffusion loss module.
"""
super().__init__(args, device)
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
if self.num_frame_per_block > 1:
self.generator.model.num_frame_per_block = self.num_frame_per_block
self.independent_first_frame = getattr(args, "independent_first_frame", False)
if self.independent_first_frame:
self.generator.model.independent_first_frame = True
if args.gradient_checkpointing:
self.generator.enable_gradient_checkpointing()
# Step 2: Initialize all hyperparameters
self.num_train_timestep = args.num_train_timestep
self.min_step = int(0.02 * self.num_train_timestep)
self.max_step = int(0.98 * self.num_train_timestep)
self.guidance_scale = args.guidance_scale
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
self.teacher_forcing = getattr(args, "teacher_forcing", False)
# Noise augmentation in teacher forcing, we add small noise to clean context latents
self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)
def _initialize_models(self, args):
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
self.generator.model.requires_grad_(True)
self.text_encoder = WanTextEncoder()
self.text_encoder.requires_grad_(False)
self.vae = WanVAEWrapper()
self.vae.requires_grad_(False)
def generator_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
initial_latent: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Add noise to the clean latents at randomly sampled timesteps and compute the
timestep-weighted loss between the generator's flow prediction and the scheduler's training target.
When teacher forcing is enabled, the (optionally noise-augmented) clean latents are passed as context.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
Output:
- loss: a scalar tensor representing the generator loss.
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
"""
noise = torch.randn_like(clean_latent)
batch_size, num_frame = image_or_video_shape[:2]
# Step 2: Randomly sample a timestep and add noise to denoiser inputs
index = self._get_timestep(
0,
self.scheduler.num_train_timesteps,
image_or_video_shape[0],
image_or_video_shape[1],
self.num_frame_per_block,
uniform_timestep=False
)
timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)
noisy_latents = self.scheduler.add_noise(
clean_latent.flatten(0, 1),
noise.flatten(0, 1),
timestep.flatten(0, 1)
).unflatten(0, (batch_size, num_frame))
training_target = self.scheduler.training_target(clean_latent, noise, timestep)
# Step 3: Noise augmentation, also add small noise to clean context latents
if self.noise_augmentation_max_timestep > 0:
index_clean_aug = self._get_timestep(
0,
self.noise_augmentation_max_timestep,
image_or_video_shape[0],
image_or_video_shape[1],
self.num_frame_per_block,
uniform_timestep=False
)
timestep_clean_aug = self.scheduler.timesteps[index_clean_aug].to(dtype=self.dtype, device=self.device)
clean_latent_aug = self.scheduler.add_noise(
clean_latent.flatten(0, 1),
noise.flatten(0, 1),
timestep_clean_aug.flatten(0, 1)
).unflatten(0, (batch_size, num_frame))
else:
clean_latent_aug = clean_latent
timestep_clean_aug = None
# Compute loss
flow_pred, x0_pred = self.generator(
noisy_image_or_video=noisy_latents,
conditional_dict=conditional_dict,
timestep=timestep,
clean_x=clean_latent_aug if self.teacher_forcing else None,
aug_t=timestep_clean_aug if self.teacher_forcing else None
)
# loss = torch.nn.functional.mse_loss(flow_pred.float(), training_target.float())
loss = torch.nn.functional.mse_loss(
flow_pred.float(), training_target.float(), reduction='none'
).mean(dim=(2, 3, 4))
loss = loss * self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
loss = loss.mean()
log_dict = {
"x0": clean_latent.detach(),
"x0_pred": x0_pred.detach()
}
return loss, log_dict
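# Illustrative sketch (not part of the repository): the per-frame weighted loss
# reduction used in generator_loss above, isolated from the scheduler. The helper
# name `weighted_per_frame_loss` and the weight values are assumptions; in the
# real code the weights come from `self.scheduler.training_weight(timestep)`.

```python
import torch
import torch.nn.functional as F


def weighted_per_frame_loss(pred, target, weight):
    """Average the squared error per frame, then weight each frame and average."""
    # pred/target: [B, F, C, H, W]; weight: [B, F]
    per_frame = F.mse_loss(
        pred.float(), target.float(), reduction="none"
    ).mean(dim=(2, 3, 4))  # -> [B, F]
    return (per_frame * weight).mean()


pred = torch.zeros(2, 3, 4, 8, 8)
target = torch.ones(2, 3, 4, 8, 8)
weight = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])  # hypothetical weights
loss = weighted_per_frame_loss(pred, target, weight)
```

# Every frame has a squared error of 1 here, so the result equals the mean of the
# weights. Reducing over (C, H, W) first keeps one loss value per frame, which is
# what allows a different timestep weight for each frame.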
================================================
FILE: model/dmd.py
================================================
from pipeline import SelfForcingTrainingPipeline
import torch.nn.functional as F
from typing import Optional, Tuple
import torch
from model.base import SelfForcingModel
class DMD(SelfForcingModel):
def __init__(self, args, device):
"""
Initialize the DMD (Distribution Matching Distillation) module.
This class is self-contained and computes the generator and fake score losses
in the forward pass.
"""
super().__init__(args, device)
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
self.num_training_frames = getattr(args, "num_training_frames", 21)
if self.num_frame_per_block > 1:
self.generator.model.num_frame_per_block = self.num_frame_per_block
self.independent_first_frame = getattr(args, "independent_first_frame", False)
if self.independent_first_frame:
self.generator.model.independent_first_frame = True
if args.gradient_checkpointing:
self.generator.enable_gradient_checkpointing()
self.fake_score.enable_gradient_checkpointing()
# this will be init later with fsdp-wrapped modules
self.inference_pipeline: SelfForcingTrainingPipeline = None
# Step 2: Initialize all dmd hyperparameters
self.num_train_timestep = args.num_train_timestep
self.min_step = int(0.02 * self.num_train_timestep)
self.max_step = int(0.98 * self.num_train_timestep)
if hasattr(args, "real_guidance_scale"):
self.real_guidance_scale = args.real_guidance_scale
self.fake_guidance_scale = args.fake_guidance_scale
else:
self.real_guidance_scale = args.guidance_scale
self.fake_guidance_scale = 0.0
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
self.ts_schedule = getattr(args, "ts_schedule", True)
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
self.min_score_timestep = getattr(args, "min_score_timestep", 0)
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
else:
self.scheduler.alphas_cumprod = None
def _compute_kl_grad(
self, noisy_image_or_video: torch.Tensor,
estimated_clean_image_or_video: torch.Tensor,
timestep: torch.Tensor,
conditional_dict: dict, unconditional_dict: dict,
normalization: bool = True,
clip_fea = None,
y = None
) -> Tuple[torch.Tensor, dict]:
"""
Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
Input:
- noisy_image_or_video: a tensor with shape [B, F, C, H, W], where the number of frames is 1 for images.
- estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
- timestep: a tensor with shape [B, F] containing the randomly generated timestep.
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- normalization: a boolean indicating whether to normalize the gradient.
Output:
- kl_grad: a tensor representing the KL grad.
- kl_log_dict: a dictionary containing the intermediate tensors for logging.
"""
# Step 1: Compute the fake score
_, pred_fake_image_cond = self.fake_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=conditional_dict,
timestep=timestep,
clip_fea=clip_fea,
y=y
)
if self.fake_guidance_scale != 0.0:
_, pred_fake_image_uncond = self.fake_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=unconditional_dict,
timestep=timestep,
clip_fea=clip_fea,
y=y
)
pred_fake_image = pred_fake_image_cond + (
pred_fake_image_cond - pred_fake_image_uncond
) * self.fake_guidance_scale
else:
pred_fake_image = pred_fake_image_cond
# Step 2: Compute the real score
# We compute the conditional and unconditional prediction
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
_, pred_real_image_cond = self.real_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=conditional_dict,
timestep=timestep,
clip_fea=clip_fea,
y=y
)
_, pred_real_image_uncond = self.real_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=unconditional_dict,
timestep=timestep,
clip_fea=clip_fea,
y=y
)
pred_real_image = pred_real_image_cond + (
pred_real_image_cond - pred_real_image_uncond
) * self.real_guidance_scale
# Step 3: Compute the DMD gradient (DMD paper eq. 7).
grad = (pred_fake_image - pred_real_image)
# TODO: Change the normalizer for causal teacher
if normalization:
# Step 4: Gradient normalization (DMD paper eq. 8).
p_real = (estimated_clean_image_or_video - pred_real_image)
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
grad = grad / normalizer
grad = torch.nan_to_num(grad)
return grad, {
"dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
"timestep": timestep.detach()
}
def compute_distribution_matching_loss(
self,
image_or_video: torch.Tensor,
conditional_dict: dict,
unconditional_dict: dict,
gradient_mask: Optional[torch.Tensor] = None,
denoised_timestep_from: int = 0,
denoised_timestep_to: int = 0,
clip_fea: torch.Tensor = None,
y: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
Input:
- image_or_video: a tensor with shape [B, F, C, H, W], where the number of frames is 1 for images.
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which elements contribute to the loss.
Output:
- dmd_loss: a scalar tensor representing the DMD loss.
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
"""
original_latent = image_or_video
batch_size, num_frame = image_or_video.shape[:2]
with torch.no_grad():
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
timestep = self._get_timestep(
min_timestep,
max_timestep,
batch_size,
num_frame,
self.num_frame_per_block,
uniform_timestep=True
)
# TODO: should we change it to `timestep = self.scheduler.timesteps[timestep]`?
if self.timestep_shift > 1:
timestep = self.timestep_shift * \
(timestep / 1000) / \
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
timestep = timestep.clamp(self.min_step, self.max_step)
noise = torch.randn_like(image_or_video)
noisy_latent = self.scheduler.add_noise(
image_or_video.flatten(0, 1),
noise.flatten(0, 1),
timestep.flatten(0, 1)
).detach().unflatten(0, (batch_size, num_frame))
# Step 2: Compute the KL grad
grad, dmd_log_dict = self._compute_kl_grad(
noisy_image_or_video=noisy_latent,
estimated_clean_image_or_video=original_latent,
timestep=timestep,
conditional_dict=conditional_dict,
unconditional_dict=unconditional_dict,
clip_fea=clip_fea,
y=y
)
if gradient_mask is not None:
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
)[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
else:
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
), (original_latent.double() - grad.double()).detach(), reduction="mean")
return dmd_loss, dmd_log_dict
def generator_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
initial_latent: torch.Tensor = None,
clip_fea: torch.Tensor = None,
y: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Generate image/videos from noise and compute the DMD loss.
The noisy input to the generator is backward simulated.
This removes the need of any datasets during distillation.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
Output:
- loss: a scalar tensor representing the generator loss.
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
"""
# Step 1: Unroll generator to obtain fake videos
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
image_or_video_shape=image_or_video_shape,
conditional_dict=conditional_dict,
initial_latent=initial_latent,
clip_fea=clip_fea,
y=y
)
# Step 2: Compute the DMD loss
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
image_or_video=pred_image,
conditional_dict=conditional_dict,
unconditional_dict=unconditional_dict,
gradient_mask=gradient_mask,
denoised_timestep_from=denoised_timestep_from,
denoised_timestep_to=denoised_timestep_to,
clip_fea=clip_fea,
y=y
)
del pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to
return dmd_loss, dmd_log_dict
def critic_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
initial_latent: torch.Tensor = None,
clip_fea: torch.Tensor = None,
y: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Generate image/videos from noise and train the critic with generated samples.
The noisy input to the generator is backward simulated.
This removes the need of any datasets during distillation.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
Output:
- loss: a scalar tensor representing the critic (fake score) denoising loss.
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
"""
# Step 1: Run generator on backward simulated noisy input
with torch.no_grad():
generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
image_or_video_shape=image_or_video_shape,
conditional_dict=conditional_dict,
initial_latent=initial_latent,
clip_fea=clip_fea,
y=y
)
# Step 2: Compute the fake prediction
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
critic_timestep = self._get_timestep(
min_timestep,
max_timestep,
image_or_video_shape[0],
image_or_video_shape[1],
self.num_frame_per_block,
uniform_timestep=True
)
if self.timestep_shift > 1:
critic_timestep = self.timestep_shift * \
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
critic_noise = torch.randn_like(generated_image)
noisy_generated_image = self.scheduler.add_noise(
generated_image.flatten(0, 1),
critic_noise.flatten(0, 1),
critic_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
_, pred_fake_image = self.fake_score(
noisy_image_or_video=noisy_generated_image,
conditional_dict=conditional_dict,
timestep=critic_timestep,
clip_fea=clip_fea,
y=y
)
# Step 3: Compute the denoising loss for the fake critic
if self.args.denoising_loss_type == "flow":
from utils.wan_wrapper import WanDiffusionWrapper
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
scheduler=self.scheduler,
x0_pred=pred_fake_image.flatten(0, 1),
xt=noisy_generated_image.flatten(0, 1),
timestep=critic_timestep.flatten(0, 1)
)
pred_fake_noise = None
else:
flow_pred = None
pred_fake_noise = self.scheduler.convert_x0_to_noise(
x0=pred_fake_image.flatten(0, 1),
xt=noisy_generated_image.flatten(0, 1),
timestep=critic_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
denoising_loss = self.denoising_loss_func(
x=generated_image.flatten(0, 1),
x_pred=pred_fake_image.flatten(0, 1),
noise=critic_noise.flatten(0, 1),
noise_pred=pred_fake_noise,
alphas_cumprod=self.scheduler.alphas_cumprod,
timestep=critic_timestep.flatten(0, 1),
flow_pred=flow_pred
)
# Step 4: Collect tensors for debug logging
critic_log_dict = {
"critic_timestep": critic_timestep.detach()
}
return denoising_loss, critic_log_dict
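The timestep warp applied to critic_timestep in Step 2 above can be checked in isolation. The sketch below is a scalar, pure-Python restatement of that expression followed by the clamp. The name shift_timestep and its arguments are illustrative and not part of the repository. The defaults of 20 and 980 assume min_step and max_step are 2% and 98% of 1000 training timesteps, and the clamp is assumed to run unconditionally because indentation is not visible in this dump.

```python
def shift_timestep(t, shift, min_step=20, max_step=980, num_train_timestep=1000):
    """Warp a timestep toward the high-noise end, then clamp it.

    Scalar restatement of:
        shift * (t / 1000) / (1 + (shift - 1) * (t / 1000)) * 1000
    """
    if shift > 1:
        u = t / num_train_timestep  # normalise to [0, 1]
        t = shift * u / (1 + (shift - 1) * u) * num_train_timestep
    # Assumption: the clamp is applied whether or not the shift branch ran.
    return min(max(t, min_step), max_step)


if __name__ == "__main__":
    for t in (0, 250, 500, 750, 1000):
        print(t, shift_timestep(t, shift=5.0))
```

With shift greater than 1 the mapping is monotone and concave, so uniformly sampled timesteps are pushed toward larger (noisier) values; with shift equal to 1 only the clamp has an effect.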
================================================
FILE: model/gan.py
================================================
import copy
from pipeline import SelfForcingTrainingPipeline
import torch.nn.functional as F
from typing import Tuple
import torch
from model.base import SelfForcingModel
class GAN(SelfForcingModel):
def __init__(self, args, device):
"""
Initialize the GAN module.
This class is self-contained and computes the generator and fake score losses
in the forward pass.
"""
super().__init__(args, device)
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
self.concat_time_embeddings = getattr(args, "concat_time_embeddings", False)
self.num_class = args.num_class
self.relativistic_discriminator = getattr(args, "relativistic_discriminator", False)
if self.num_frame_per_block > 1:
self.generator.model.num_frame_per_block = self.num_frame_per_block
self.fake_score.adding_cls_branch(
atten_dim=1536, num_class=args.num_class, time_embed_dim=1536 if self.concat_time_embeddings else 0)
self.fake_score.model.requires_grad_(True)
self.independent_first_frame = getattr(args, "independent_first_frame", False)
if self.independent_first_frame:
self.generator.model.independent_first_frame = True
if args.gradient_checkpointing:
self.generator.enable_gradient_checkpointing()
self.fake_score.enable_gradient_checkpointing()
# this will be init later with fsdp-wrapped modules
self.inference_pipeline: SelfForcingTrainingPipeline = None
# Initialize the timestep, guidance, and GAN hyperparameters
self.num_train_timestep = args.num_train_timestep
self.min_step = int(0.02 * self.num_train_timestep)
self.max_step = int(0.98 * self.num_train_timestep)
if hasattr(args, "real_guidance_scale"):
self.real_guidance_scale = args.real_guidance_scale
self.fake_guidance_scale = args.fake_guidance_scale
else:
self.real_guidance_scale = args.guidance_scale
self.fake_guidance_scale = 0.0
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
self.critic_timestep_shift = getattr(args, "critic_timestep_shift", self.timestep_shift)
self.ts_schedule = getattr(args, "ts_schedule", True)
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
self.min_score_timestep = getattr(args, "min_score_timestep", 0)
self.gan_g_weight = getattr(args, "gan_g_weight", 1e-2)
self.gan_d_weight = getattr(args, "gan_d_weight", 1e-2)
self.r1_weight = getattr(args, "r1_weight", 0.0)
self.r2_weight = getattr(args, "r2_weight", 0.0)
self.r1_sigma = getattr(args, "r1_sigma", 0.01)
self.r2_sigma = getattr(args, "r2_sigma", 0.01)
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
else:
self.scheduler.alphas_cumprod = None
def _run_cls_pred_branch(self,
noisy_image_or_video: torch.Tensor,
conditional_dict: dict,
timestep: torch.Tensor) -> torch.Tensor:
"""
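Run the classifier prediction branch of the fake score network on a noisy image or video.
Input:
- noisy_image_or_video: a tensor with shape [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings).
- timestep: a tensor containing the timesteps at which the input was noised.
Output:
- cls_pred: a tensor with shape [B, 1, 1, 1, 1] representing the feature map for classification.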
"""
_, _, noisy_logit = self.fake_score(
noisy_image_or_video=noisy_image_or_video,
conditional_dict=conditional_dict,
timestep=timestep,
classify_mode=True,
concat_time_embeddings=self.concat_time_embeddings
)
return noisy_logit
def generator_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
initial_latent: torch.Tensor = None
) -> torch.Tensor:
"""
Generate images/videos from noise and compute the GAN generator loss.
The noisy input to the generator is backward simulated.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
The generated latents and the real latents (clean_latent) are noised at the same
timesteps and scored by the classifier branch of the fake score network.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings). Unused in this method.
- clean_latent: a tensor containing the real clean latents [B, F, C, H, W]. Always required, since it supplies the real samples.
- initial_latent: an optional tensor forwarded to the generator.
Output:
- loss: a scalar tensor representing the generator loss, already scaled by gan_g_weight.
"""
# Step 1: Unroll generator to obtain fake videos
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
image_or_video_shape=image_or_video_shape,
conditional_dict=conditional_dict,
initial_latent=initial_latent
)
# Step 2: Get timestep and add noise to generated/real latents
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
critic_timestep = self._get_timestep(
min_timestep,
max_timestep,
image_or_video_shape[0],
image_or_video_shape[1],
self.num_frame_per_block,
uniform_timestep=True
)
if self.critic_timestep_shift > 1:
critic_timestep = self.critic_timestep_shift * \
(critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
critic_noise = torch.randn_like(pred_image)
noisy_fake_latent = self.scheduler.add_noise(
pred_image.flatten(0, 1),
critic_noise.flatten(0, 1),
critic_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
# Step 3: Add noise to the real latents at the same timesteps
real_image_or_video = clean_latent.clone()
critic_noise = torch.randn_like(real_image_or_video)
noisy_real_latent = self.scheduler.add_noise(
real_image_or_video.flatten(0, 1),
critic_noise.flatten(0, 1),
critic_timestep.flatten(0, 1)
).unflatten(0, image_or_video_shape[:2])
# Step 4: Score fake and real latents in one batched pass through the classifier branch.
# Build a shallow copy so that the caller's conditional_dict is not mutated in place.
conditional_dict = {
**conditional_dict,
"prompt_embeds": torch.concatenate(
(conditional_dict["prompt_embeds"], conditional_dict["prompt_embeds"]), dim=0)
}
critic_timestep = torch.concatenate((critic_timestep, critic_timestep), dim=0)
noisy_latent = torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0)
_, _, noisy_logit = self.fake_score(
noisy_image_or_video=noisy_latent,
conditional_dict=conditional_dict,
timestep=critic_timestep,
classify_mode=True,
concat_time_embeddings=self.concat_time_embeddings
)
noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
if not self.relativistic_discriminator:
gan_G_loss = F.softplus(-noisy_fake_logit.float()).mean() * self.gan_g_weight
else:
relative_fake_logit = noisy_fake_logit - noisy_real_logit
gan_G_loss = F.softplus(-relative_fake_logit.float()).mean() * self.gan_g_weight
return gan_G_loss
def critic_loss(
self,
image_or_video_shape,
conditional_dict: dict,
unconditional_dict: dict,
clean_latent: torch.Tensor,
real_image_or_video: torch.Tensor,
initial_latent: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Generate images/videos from noise and train the critic with generated samples.
The noisy input to the generator is backward simulated.
This removes the need for any dataset during distillation.
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
Input:
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Must be passed when no backward simulation is used.
Output:
- loss: a scalar tensor representing the generator lo
SYMBOL INDEX (598 symbols across 55 files)
FILE: convert_checkpoint.py
function main (line 7) | def main():
FILE: demo.py
function initialize_vae_decoder (line 53) | def initialize_vae_decoder(use_taehv=False, use_trt=False):
function tensor_to_base64_frame (line 155) | def tensor_to_base64_frame(frame_tensor):
function frame_sender_worker (line 178) | def frame_sender_worker():
function generate_video_stream (line 230) | def generate_video_stream(prompt, seed, enable_torch_compile=False, enab...
function handle_connect (line 499) | def handle_connect():
function handle_disconnect (line 505) | def handle_disconnect():
function handle_start_generation (line 510) | def handle_start_generation(data):
function handle_stop_generation (line 534) | def handle_stop_generation():
function index (line 551) | def index():
function api_status (line 556) | def api_status():
FILE: demo_utils/memory.py
class DynamicSwapInstaller (line 13) | class DynamicSwapInstaller:
method _install_module (line 15) | def _install_module(module: torch.nn.Module, **kwargs):
method _uninstall_module (line 43) | def _uninstall_module(module: torch.nn.Module):
method install_model (line 49) | def install_model(model: torch.nn.Module, **kwargs):
method uninstall_model (line 55) | def uninstall_model(model: torch.nn.Module):
function fake_diffusers_current_device (line 61) | def fake_diffusers_current_device(model: torch.nn.Module, target_device:...
function get_cuda_free_memory_gb (line 72) | def get_cuda_free_memory_gb(device=None):
function move_model_to_device_with_memory_preservation (line 85) | def move_model_to_device_with_memory_preservation(model, target_device, ...
function offload_model_from_device_for_memory_preservation (line 101) | def offload_model_from_device_for_memory_preservation(model, target_devi...
function unload_complete_models (line 117) | def unload_complete_models(*args):
function load_model_as_complete (line 127) | def load_model_as_complete(model, target_device, unload=True):
FILE: demo_utils/taehv.py
function conv (line 16) | def conv(n_in, n_out, **kwargs):
class Clamp (line 20) | class Clamp(nn.Module):
method forward (line 21) | def forward(self, x):
class MemBlock (line 25) | class MemBlock(nn.Module):
method __init__ (line 26) | def __init__(self, n_in, n_out):
method forward (line 33) | def forward(self, x, past):
class TPool (line 37) | class TPool(nn.Module):
method __init__ (line 38) | def __init__(self, n_f, stride):
method forward (line 43) | def forward(self, x):
class TGrow (line 48) | class TGrow(nn.Module):
method __init__ (line 49) | def __init__(self, n_f, stride):
method forward (line 54) | def forward(self, x):
function apply_model_with_memblocks (line 60) | def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
class TAEHV (line 159) | class TAEHV(nn.Module):
method __init__ (line 163) | def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(...
method patch_tgrow_layers (line 195) | def patch_tgrow_layers(self, sd):
method encode_video (line 210) | def encode_video(self, x, parallel=True, show_progress_bar=True):
method decode_video (line 222) | def decode_video(self, x, parallel=True, show_progress_bar=False):
method forward (line 236) | def forward(self, x):
function main (line 241) | def main():
FILE: demo_utils/utils.py
function min_resize (line 19) | def min_resize(x, m):
function d_resize (line 36) | def d_resize(x, y):
function resize_and_center_crop (line 48) | def resize_and_center_crop(image, target_width, target_height):
function resize_and_center_crop_pytorch (line 66) | def resize_and_center_crop_pytorch(image, target_width, target_height):
function resize_without_crop (line 85) | def resize_without_crop(image, target_width, target_height):
function just_crop (line 94) | def just_crop(image, w, h):
function write_to_json (line 108) | def write_to_json(data, file_path):
function read_from_json (line 116) | def read_from_json(file_path):
function get_active_parameters (line 122) | def get_active_parameters(m):
function cast_training_params (line 126) | def cast_training_params(m, dtype=torch.float32):
function separate_lora_AB (line 135) | def separate_lora_AB(parameters, B_patterns=None):
function set_attr_recursive (line 151) | def set_attr_recursive(obj, attr, value):
function print_tensor_list_size (line 159) | def print_tensor_list_size(tensors):
function batch_mixture (line 180) | def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
function zero_module (line 196) | def zero_module(module):
function supress_lower_channels (line 203) | def supress_lower_channels(m, k, alpha=0.01):
function freeze_module (line 213) | def freeze_module(m):
function get_latest_safetensors (line 221) | def get_latest_safetensors(folder_path):
function generate_random_prompt_from_tags (line 232) | def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=...
function interpolate_numbers (line 239) | def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
function uniform_random_by_intervals (line 246) | def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=Fa...
function soft_append_bcthw (line 255) | def soft_append_bcthw(history, current, overlap=0):
function save_bcthw_as_mp4 (line 269) | def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
function save_bcthw_as_png (line 286) | def save_bcthw_as_png(x, output_filename):
function save_bchw_as_png (line 295) | def save_bchw_as_png(x, output_filename):
function add_tensors_with_padding (line 304) | def add_tensors_with_padding(tensor1, tensor2):
function print_free_mem (line 323) | def print_free_mem():
function print_gpu_parameters (line 333) | def print_gpu_parameters(device, state_dict, log_count=1):
function visualize_txt_as_img (line 348) | def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans...
function blue_mark (line 386) | def blue_mark(x):
function green_mark (line 394) | def green_mark(x):
function frame_mark (line 401) | def frame_mark(x):
function pytorch2numpy (line 411) | def pytorch2numpy(imgs):
function numpy2pytorch (line 422) | def numpy2pytorch(imgs):
function duplicate_prefix_to_suffix (line 429) | def duplicate_prefix_to_suffix(x, count, zero_out=False):
function weighted_mse (line 436) | def weighted_mse(a, b, weight):
function clamped_linear_interpolation (line 440) | def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
function expand_to_dims (line 447) | def expand_to_dims(x, target_dims):
function repeat_to_batch_size (line 451) | def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
function dim5 (line 468) | def dim5(x):
function dim4 (line 472) | def dim4(x):
function dim3 (line 476) | def dim3(x):
function crop_or_pad_yield_mask (line 480) | def crop_or_pad_yield_mask(x, length):
function extend_dim (line 495) | def extend_dim(x, dim, minimal_length, zero_pad=False):
function lazy_positional_encoding (line 513) | def lazy_positional_encoding(t, repeats=None):
function state_dict_offset_merge (line 530) | def state_dict_offset_merge(A, B, C=None):
function state_dict_weighted_merge (line 547) | def state_dict_weighted_merge(state_dicts, weights):
function group_files_by_folder (line 574) | def group_files_by_folder(all_files):
function generate_timestamp (line 587) | def generate_timestamp():
function write_PIL_image_with_png_info (line 595) | def write_PIL_image_with_png_info(image, metadata, path):
function torch_safe_save (line 606) | def torch_safe_save(content, path):
function move_optimizer_to_device (line 612) | def move_optimizer_to_device(optimizer, device):
FILE: demo_utils/vae.py
class ResidualBlock (line 13) | class ResidualBlock(nn.Module):
method __init__ (line 15) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 29) | def forward(self, x, feat_cache_1, feat_cache_2):
class Resample (line 51) | class Resample(nn.Module):
method __init__ (line 53) | def __init__(self, dim, mode):
method forward (line 73) | def forward(self, x, is_first_frame, feat_cache):
method temporal_conv (line 105) | def temporal_conv(self, x, is_first_frame, feat_cache):
method init_weight (line 127) | def init_weight(self, conv):
method init_weight2 (line 139) | def init_weight2(self, conv):
class VAEDecoderWrapperSingle (line 151) | class VAEDecoderWrapperSingle(nn.Module):
method __init__ (line 152) | def __init__(self):
method forward (line 168) | def forward(
class VAEDecoder3d (line 199) | class VAEDecoder3d(nn.Module):
method __init__ (line 200) | def __init__(self,
method forward (line 254) | def forward(
class VAETRTWrapper (line 318) | class VAETRTWrapper():
method __init__ (line 319) | def __init__(self):
method quantize_if_needed (line 376) | def quantize_if_needed(self, t, expected_dtype, scale):
method forward (line 381) | def forward(self, *test_inputs):
FILE: demo_utils/vae_block3.py
class Resample (line 9) | class Resample(nn.Module):
method __init__ (line 11) | def __init__(self, dim, mode):
method forward (line 45) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 106) | def init_weight(self, conv):
method init_weight2 (line 118) | def init_weight2(self, conv):
class VAEDecoderWrapper (line 130) | class VAEDecoderWrapper(nn.Module):
method __init__ (line 131) | def __init__(self):
method forward (line 147) | def forward(
class VAEDecoder3d (line 187) | class VAEDecoder3d(nn.Module):
method __init__ (line 188) | def __init__(self,
method forward (line 242) | def forward(
FILE: demo_utils/vae_torch2trt.py
function set_workspace (line 98) | def set_workspace(config, bytes_):
function set_workspace (line 122) | def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
class VAECalibrator (line 139) | class VAECalibrator(trt.IInt8EntropyCalibrator2):
method __init__ (line 140) | def __init__(self, loader, cache="calibration.cache", max_batches=10):
method get_batch_size (line 151) | def get_batch_size(self):
method getBatchSize (line 154) | def getBatchSize(self):
method get_batch (line 157) | def get_batch(self, names):
method read_calibration_cache (line 202) | def read_calibration_cache(self):
method readCalibrationCache (line 209) | def readCalibrationCache(self):
method write_calibration_cache (line 212) | def write_calibration_cache(self, cache):
method writeCalibrationCache (line 216) | def writeCalibrationCache(self, cache):
FILE: inference.py
function encode (line 98) | def encode(self, videos: torch.Tensor) -> torch.Tensor:
FILE: model/base.py
class BaseModel (line 12) | class BaseModel(nn.Module):
method __init__ (line 13) | def __init__(self, args, device):
method _initialize_models (line 28) | def _initialize_models(self, args, device):
method _get_timestep (line 59) | def _get_timestep(
class SelfForcingModel (line 109) | class SelfForcingModel(BaseModel):
method __init__ (line 110) | def __init__(self, args, device):
method _run_generator (line 114) | def _run_generator(
method _consistency_backward_simulation (line 197) | def _consistency_backward_simulation(
method _initialize_inference_pipeline (line 223) | def _initialize_inference_pipeline(self):
FILE: model/causvid.py
class CausVid (line 8) | class CausVid(BaseModel):
method __init__ (line 9) | def __init__(self, args, device):
method _compute_kl_grad (line 47) | def _compute_kl_grad(
method compute_distribution_matching_loss (line 121) | def compute_distribution_matching_loss(
method _run_generator (line 184) | def _run_generator(
method generator_loss (line 255) | def generator_loss(
method critic_loss (line 296) | def critic_loss(
FILE: model/diffusion.py
class CausalDiffusion (line 8) | class CausalDiffusion(BaseModel):
method __init__ (line 9) | def __init__(self, args, device):
method _initialize_models (line 34) | def _initialize_models(self, args):
method generator_loss (line 44) | def generator_loss(
FILE: model/dmd.py
class DMD (line 9) | class DMD(SelfForcingModel):
method __init__ (line 10) | def __init__(self, args, device):
method _compute_kl_grad (line 54) | def _compute_kl_grad(
method compute_distribution_matching_loss (line 138) | def compute_distribution_matching_loss(
method generator_loss (line 210) | def generator_loss(
method critic_loss (line 259) | def critic_loss(
FILE: model/gan.py
class GAN (line 10) | class GAN(SelfForcingModel):
method __init__ (line 11) | def __init__(self, args, device):
method _run_cls_pred_branch (line 69) | def _run_cls_pred_branch(self,
method generator_loss (line 90) | def generator_loss(
method critic_loss (line 174) | def critic_loss(
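The generator objective in GAN.generator_loss reduces to a softplus over discriminator logits, optionally in relativistic form. The following is a minimal pure-Python sketch of that arithmetic on plain lists of floats; softplus and gan_generator_loss are illustrative helpers written for this note, not functions from the repository, and the default weight mirrors the gan_g_weight default of 1e-2.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gan_generator_loss(fake_logits, real_logits=None, weight=1e-2, relativistic=False):
    """Non-saturating generator loss: mean(softplus(-logit)) * weight.

    In relativistic mode the logit is the fake logit minus the paired real logit.
    """
    if relativistic:
        if real_logits is None:
            raise ValueError("real_logits are required in relativistic mode")
        logits = [f - r for f, r in zip(fake_logits, real_logits)]
    else:
        logits = list(fake_logits)
    return weight * sum(softplus(-z) for z in logits) / len(logits)


if __name__ == "__main__":
    print(gan_generator_loss([0.5, -1.0]))
    print(gan_generator_loss([0.5, -1.0], [0.2, 0.3], relativistic=True))
```

The loss shrinks as the discriminator assigns higher logits to generated samples, and in relativistic mode only the gap between fake and real logits matters.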
FILE: model/ode_regression.py
class ODERegression (line 9) | class ODERegression(BaseModel):
method __init__ (line 10) | def __init__(self, args, device):
method _initialize_models (line 46) | def _initialize_models(self, args):
method _prepare_generator_input (line 57) | def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[...
method generator_loss (line 102) | def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: d...
FILE: model/sid.py
class SiD (line 8) | class SiD(SelfForcingModel):
method __init__ (line 9) | def __init__(self, args, device):
method compute_distribution_matching_loss (line 47) | def compute_distribution_matching_loss(
method generator_loss (line 147) | def generator_loss(
method critic_loss (line 188) | def critic_loss(
FILE: pipeline/bidirectional_diffusion_inference.py
class BidirectionalDiffusionInferencePipeline (line 10) | class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 34) | def inference(
method _initialize_sample_scheduler (line 89) | def _initialize_sample_scheduler(self, noise):
FILE: pipeline/bidirectional_inference.py
class BidirectionalInferencePipeline (line 7) | class BidirectionalInferencePipeline(torch.nn.Module):
method __init__ (line 8) | def __init__(
method inference (line 33) | def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> t...
FILE: pipeline/bidirectional_training.py
class BidirectionalTrainingPipeline (line 9) | class BidirectionalTrainingPipeline(torch.nn.Module):
method __init__ (line 10) | def __init__(
method generate_and_sync_list (line 25) | def generate_and_sync_list(self, num_denoising_steps, device):
method inference_with_trajectory (line 42) | def inference_with_trajectory(self, noise: torch.Tensor, clip_fea, y, ...
FILE: pipeline/causal_diffusion_inference.py
class CausalDiffusionInferencePipeline (line 10) | class CausalDiffusionInferencePipeline(torch.nn.Module):
method __init__ (line 11) | def __init__(
method inference (line 49) | def inference(
method _initialize_kv_cache (line 270) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 300) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
method _initialize_sample_scheduler (line 321) | def _initialize_sample_scheduler(self, noise):
FILE: pipeline/causal_inference.py
class CausalInferencePipeline (line 7) | class CausalInferencePipeline(torch.nn.Module):
method __init__ (line 8) | def __init__(
method inference (line 45) | def inference(
method _initialize_kv_cache (line 271) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 293) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
FILE: pipeline/self_forcing_training.py
class SelfForcingTrainingPipeline (line 8) | class SelfForcingTrainingPipeline:
method __init__ (line 9) | def __init__(self,
method generate_and_sync_list (line 45) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
method inference_with_trajectory (line 64) | def inference_with_trajectory(
method _initialize_kv_cache (line 245) | def _initialize_kv_cache(self, batch_size, dtype, device):
method _initialize_crossattn_cache (line 261) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
FILE: scripts/compute_vae_latent.py
function launch_distributed_job (line 16) | def launch_distributed_job(backend: str = "nccl"):
function video_to_numpy (line 32) | def video_to_numpy(video_path):
function encode (line 42) | def encode(self, videos: torch.Tensor) -> torch.Tensor:
function main (line 55) | def main():
FILE: scripts/create_lmdb_14b_shards.py
function main (line 19) | def main():
FILE: scripts/create_lmdb_iterative.py
function main (line 12) | def main():
FILE: scripts/generate_ode_pairs.py
function init_model (line 13) | def init_model(device):
function main (line 32) | def main():
FILE: train.py
function main (line 9) | def main():
FILE: trainer/diffusion.py
class Trainer (line 17) | class Trainer:
method __init__ (line 18) | def __init__(self, config):
method save (line 140) | def save(self):
method train_one_step (line 163) | def train_one_step(self, batch):
method generate_video (line 235) | def generate_video(self, pipeline, prompts, image=None):
method train (line 248) | def train(self):
FILE: trainer/distillation.py
class Trainer (line 20) | class Trainer:
method __init__ (line 21) | def __init__(self, config):
method save (line 236) | def save(self):
method fwdbwd_one_step (line 263) | def fwdbwd_one_step(self, batch, train_generator):
method generate_video (line 350) | def generate_video(self, pipeline, prompts, image=None):
method train (line 380) | def train(self):
FILE: trainer/gan.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, config):
method save (line 208) | def save(self):
method fwdbwd_one_step (line 235) | def fwdbwd_one_step(self, batch, train_generator):
method generate_video (line 324) | def generate_video(self, pipeline, prompts, image=None):
method train (line 337) | def train(self):
method all_gather_dict (line 457) | def all_gather_dict(self, target_dict):
FILE: trainer/ode.py
class Trainer (line 19) | class Trainer:
method __init__ (line 20) | def __init__(self, config):
method save (line 118) | def save(self):
method train_one_step (line 134) | def train_one_step(self):
method train (line 225) | def train(self):
FILE: utils/dataset.py
class TextDataset (line 13) | class TextDataset(Dataset):
method __init__ (line 14) | def __init__(self, prompt_path, extended_prompt_path=None):
method __len__ (line 25) | def __len__(self):
method __getitem__ (line 28) | def __getitem__(self, idx):
class TextFolderDataset (line 38) | class TextFolderDataset(Dataset):
method __init__ (line 39) | def __init__(self, data_path, max_count=30000):
method __len__ (line 51) | def __len__(self):
method __getitem__ (line 54) | def __getitem__(self, idx):
class ODERegressionLMDBDataset (line 58) | class ODERegressionLMDBDataset(Dataset):
method __init__ (line 59) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 66) | def __len__(self):
method __getitem__ (line 69) | def __getitem__(self, idx):
class ShardingLMDBDataset (line 93) | class ShardingLMDBDataset(Dataset):
method __init__ (line 94) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
method __len__ (line 117) | def __len__(self):
method __getitem__ (line 120) | def __getitem__(self, idx):
class TextImagePairDataset (line 157) | class TextImagePairDataset(Dataset):
method __init__ (line 158) | def __init__(
method __len__ (line 212) | def __len__(self):
method __getitem__ (line 215) | def __getitem__(self, idx):
function cycle (line 247) | def cycle(dl):
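A helper named cycle(dl) conventionally turns a finite data loader into an endless stream, so that a step-based training loop never hits StopIteration. The sketch below shows that common pattern; it is an assumption about what the function does, since only its signature appears in this index.

```python
from itertools import islice


def cycle(iterable):
    """Yield items from `iterable` forever, restarting it after each pass."""
    while True:
        for item in iterable:
            yield item


# Usage: take seven items from a three-element "loader".
first_seven = list(islice(cycle([1, 2, 3]), 7))
print(first_seven)
```

Re-entering the for loop on every pass lets a real DataLoader reshuffle per epoch, which itertools.cycle would not do because it caches the first pass. An empty iterable makes this generator spin without yielding.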
FILE: utils/distributed.py
function fsdp_state_dict (line 11) | def fsdp_state_dict(model):
function fsdp_wrap (line 23) | def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, w...
function barrier (line 71) | def barrier():
function launch_distributed_job (line 76) | def launch_distributed_job(backend: str = "nccl"):
class EMA_FSDP (line 92) | class EMA_FSDP:
method __init__ (line 93) | def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
method _init_shadow (line 99) | def _init_shadow(self, fsdp_module):
method update (line 106) | def update(self, fsdp_module):
method state_dict (line 114) | def state_dict(self):
method load_state_dict (line 117) | def load_state_dict(self, sd):
method copy_to (line 120) | def copy_to(self, fsdp_module):
FILE: utils/lmdb.py
function get_array_shape_from_lmdb (line 4) | def get_array_shape_from_lmdb(env, array_name):
function store_arrays_to_lmdb (line 11) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
function process_data_dict (line 30) | def process_data_dict(data_dict, seen_prompts):
function retrieve_row_from_lmdb (line 57) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...
FILE: utils/loss.py
class DenoisingLoss (line 5) | class DenoisingLoss(ABC):
method __call__ (line 7) | def __call__(
class X0PredLoss (line 27) | class X0PredLoss(DenoisingLoss):
method __call__ (line 28) | def __call__(
class VPredLoss (line 38) | class VPredLoss(DenoisingLoss):
method __call__ (line 39) | def __call__(
class NoisePredLoss (line 50) | class NoisePredLoss(DenoisingLoss):
method __call__ (line 51) | def __call__(
class FlowPredLoss (line 61) | class FlowPredLoss(DenoisingLoss):
method __call__ (line 62) | def __call__(
function get_denoising_loss (line 80) | def get_denoising_loss(loss_type: str) -> DenoisingLoss:
FILE: utils/misc.py
function set_seed (line 6) | def set_seed(seed: int, deterministic: bool = False):
function merge_dict_list (line 25) | def merge_dict_list(dict_list):
FILE: utils/scheduler.py
class SchedulerInterface (line 5) | class SchedulerInterface(ABC):
method add_noise (line 12) | def add_noise(
method convert_x0_to_noise (line 26) | def convert_x0_to_noise(
method convert_noise_to_x0 (line 52) | def convert_noise_to_x0(
method convert_velocity_to_x0 (line 77) | def convert_velocity_to_x0(
class FlowMatchScheduler (line 106) | class FlowMatchScheduler():
method __init__ (line 108) | def __init__(self, num_inference_steps=100, num_train_timesteps=1000, ...
method set_timesteps (line 118) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
method step (line 143) | def step(self, model_output, timestep, sample, to_final=False):
method add_noise (line 159) | def add_noise(self, original_samples, noise, timestep):
method training_target (line 178) | def training_target(self, sample, noise, timestep):
method training_weight (line 182) | def training_weight(self, timestep):
FILE: utils/wan_wrapper.py
class WanTextEncoder (line 16) | class WanTextEncoder(torch.nn.Module):
method __init__ (line 17) | def __init__(self, model_name="Wan2.1-T2V-14B") -> None:
method device (line 36) | def device(self):
method forward (line 40) | def forward(self, text_prompts: List[str]) -> dict:
class WanCLIPEncoder (line 56) | class WanCLIPEncoder(torch.nn.Module):
method __init__ (line 57) | def __init__(self, model_name="Wan2.1-T2V-14B"):
method device (line 70) | def device(self):
method forward (line 74) | def forward(self, img):
class WanVAEWrapper (line 81) | class WanVAEWrapper(torch.nn.Module):
method __init__ (line 82) | def __init__(self, model_name="Wan2.1-T2V-14B"):
method encode (line 107) | def encode(self, pixel):
method run_vae_encoder (line 117) | def run_vae_encoder(self, img):
method encode_to_latent (line 149) | def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
method decode_to_pixel (line 165) | def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = Fals...
class WanDiffusionWrapper (line 191) | class WanDiffusionWrapper(torch.nn.Module):
method __init__ (line 192) | def __init__(
method enable_gradient_checkpointing (line 222) | def enable_gradient_checkpointing(self) -> None:
method adding_cls_branch (line 225) | def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_di...
method _convert_flow_pred_to_x0 (line 248) | def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch....
method _convert_x0_to_flow_pred (line 275) | def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: tor...
method forward (line 297) | def forward(
method get_scheduler (line 380) | def get_scheduler(self) -> SchedulerInterface:
method post_init (line 394) | def post_init(self):
FILE: wan/distributed/fsdp.py
function shard_model (line 10) | def shard_model(
FILE: wan/distributed/xdit_context_parallel.py
function pad_freqs (line 12) | def pad_freqs(original_tensor, target_len):
function rope_apply (line 26) | def rope_apply(x, grid_sizes, freqs):
function usp_dit_forward (line 66) | def usp_dit_forward(
function usp_attn_forward (line 149) | def usp_attn_forward(self,
FILE: wan/image2video.py
class WanI2V (line 29) | class WanI2V:
method __init__ (line 31) | def __init__(
method generate (line 129) | def generate(self,
FILE: wan/modules/attention.py
function is_hopper_gpu (line 7) | def is_hopper_gpu():
function flash_attention (line 32) | def flash_attention(
function attention (line 139) | def attention(
FILE: wan/modules/causal_model.py
function causal_rope_apply (line 27) | def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
class CausalWanSelfAttention (line 58) | class CausalWanSelfAttention(nn.Module):
method __init__ (line 60) | def __init__(self,
method forward (line 86) | def forward(
class CausalWanAttentionBlock (line 243) | class CausalWanAttentionBlock(nn.Module):
method __init__ (line 245) | def __init__(self,
method forward (line 283) | def forward(
class CausalHead (line 338) | class CausalHead(nn.Module):
method __init__ (line 340) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 355) | def forward(self, x, e):
class CausalWanModel (line 369) | class CausalWanModel(ModelMixin, ConfigMixin):
method __init__ (line 381) | def __init__(self,
method _set_gradient_checkpointing (line 502) | def _set_gradient_checkpointing(self, module, value=False):
method _prepare_blockwise_causal_attn_mask (line 506) | def _prepare_blockwise_causal_attn_mask(
method _prepare_teacher_forcing_mask (line 564) | def _prepare_teacher_forcing_mask(
method _prepare_blockwise_causal_attn_mask_i2v (line 652) | def _prepare_blockwise_causal_attn_mask_i2v(
method _forward_inference (line 712) | def _forward_inference(
method _forward_train (line 843) | def _forward_train(
method forward (line 1001) | def forward(
method unpatchify (line 1011) | def unpatchify(self, x, grid_sizes):
method init_weights (line 1036) | def init_weights(self):
FILE: wan/modules/clip.py
function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
class QuickGELU (line 41) | class QuickGELU(nn.Module):
method forward (line 43) | def forward(self, x):
class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
method forward (line 49) | def forward(self, x):
class SelfAttention (line 54) | class SelfAttention(nn.Module):
method __init__ (line 56) | def __init__(self,
method forward (line 75) | def forward(self, x):
class SwiGLU (line 95) | class SwiGLU(nn.Module):
method __init__ (line 97) | def __init__(self, dim, mid_dim):
method forward (line 107) | def forward(self, x):
class AttentionBlock (line 113) | class AttentionBlock(nn.Module):
method __init__ (line 115) | def __init__(self,
method forward (line 147) | def forward(self, x):
class AttentionPool (line 157) | class AttentionPool(nn.Module):
method __init__ (line 159) | def __init__(self,
method forward (line 187) | def forward(self, x):
class VisionTransformer (line 210) | class VisionTransformer(nn.Module):
method __init__ (line 212) | def __init__(self,
method forward (line 280) | def forward(self, x, interpolation=False, use_31_block=False):
class XLMRobertaWithHead (line 304) | class XLMRobertaWithHead(XLMRoberta):
method __init__ (line 306) | def __init__(self, **kwargs):
method forward (line 316) | def forward(self, ids):
class XLMRobertaCLIP (line 329) | class XLMRobertaCLIP(nn.Module):
method __init__ (line 331) | def __init__(self,
function _clip (line 408) | def _clip(pretrained=False,
function clip_xlm_roberta_vit_h_14 (line 445) | def clip_xlm_roberta_vit_h_14(
class CLIPModel (line 475) | class CLIPModel(nn.Module):
method __init__ (line 477) | def __init__(self, dtype, device, checkpoint_path):
method visual (line 498) | def visual(self, videos):
FILE: wan/modules/model.py
function sinusoidal_embedding_1d (line 15) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 29) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 40) | def rope_apply(x, grid_sizes, freqs):
class WanRMSNorm (line 70) | class WanRMSNorm(nn.Module):
method __init__ (line 72) | def __init__(self, dim, eps=1e-5):
method forward (line 78) | def forward(self, x):
method _norm (line 85) | def _norm(self, x):
class WanLayerNorm (line 89) | class WanLayerNorm(nn.LayerNorm):
method __init__ (line 91) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 94) | def forward(self, x):
class WanSelfAttention (line 102) | class WanSelfAttention(nn.Module):
method __init__ (line 104) | def __init__(self,
method forward (line 127) | def forward(self, x, seq_lens, grid_sizes, freqs):
class WanT2VCrossAttention (line 159) | class WanT2VCrossAttention(WanSelfAttention):
method forward (line 161) | def forward(self, x, context, context_lens, crossattn_cache=None):
class WanGanCrossAttention (line 197) | class WanGanCrossAttention(WanSelfAttention):
method forward (line 199) | def forward(self, x, context, crossattn_cache=None):
class WanI2VCrossAttention (line 224) | class WanI2VCrossAttention(WanSelfAttention):
method __init__ (line 226) | def __init__(self,
method forward (line 240) | def forward(self, x, context, context_lens):
class WanAttentionBlock (line 275) | class WanAttentionBlock(nn.Module):
method __init__ (line 277) | def __init__(self,
method forward (line 315) | def forward(
class GanAttentionBlock (line 357) | class GanAttentionBlock(nn.Module):
method __init__ (line 359) | def __init__(self,
method forward (line 397) | def forward(
class Head (line 439) | class Head(nn.Module):
method __init__ (line 441) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 456) | def forward(self, x, e):
class MLPProj (line 469) | class MLPProj(torch.nn.Module):
method __init__ (line 471) | def __init__(self, in_dim, out_dim):
method forward (line 479) | def forward(self, image_embeds):
class RegisterTokens (line 484) | class RegisterTokens(nn.Module):
method __init__ (line 485) | def __init__(self, num_registers: int, dim: int):
method forward (line 490) | def forward(self):
method reset_parameters (line 493) | def reset_parameters(self):
class WanModel (line 497) | class WanModel(ModelMixin, ConfigMixin):
method __init__ (line 509) | def __init__(self,
method _set_gradient_checkpointing (line 623) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 626) | def forward(
method _forward (line 637) | def _forward(
method _forward_classify (line 773) | def _forward_classify(
method unpatchify (line 876) | def unpatchify(self, x, grid_sizes, c=None):
method init_weights (line 901) | def init_weights(self):
FILE: wan/modules/t5.py
function fp16_clamp (line 20) | def fp16_clamp(x):
function init_weights (line 27) | def init_weights(m):
class GELU (line 46) | class GELU(nn.Module):
method forward (line 48) | def forward(self, x):
class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
method forward (line 61) | def forward(self, x):
class T5Attention (line 69) | class T5Attention(nn.Module):
method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
method forward (line 136) | def forward(self, x):
class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
method __init__ (line 146) | def __init__(self,
method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
method __init__ (line 180) | def __init__(self,
method forward (line 206) | def forward(self,
class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
method forward (line 233) | def forward(self, lq, lk):
method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
class T5Encoder (line 267) | class T5Encoder(nn.Module):
method __init__ (line 269) | def __init__(self,
method forward (line 303) | def forward(self, ids, mask=None):
class T5Decoder (line 315) | class T5Decoder(nn.Module):
method __init__ (line 317) | def __init__(self,
method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
class T5Model (line 372) | class T5Model(nn.Module):
method __init__ (line 374) | def __init__(self,
method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
function _t5 (line 415) | def _t5(name,
function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
class T5EncoderModel (line 472) | class T5EncoderModel:
method __init__ (line 474) | def __init__(
method __call__ (line 506) | def __call__(self, texts, device):
FILE: wan/modules/tokenizers.py
function basic_clean (line 12) | def basic_clean(text):
function whitespace_clean (line 18) | def whitespace_clean(text):
function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
method __call__ (line 49) | def __call__(self, sequence, **kwargs):
method _clean (line 75) | def _clean(self, text):
FILE: wan/modules/vae.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 28) | def forward(self, x, cache_x=None):
class RMS_norm (line 39) | class RMS_norm(nn.Module):
method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 51) | def forward(self, x):
class Upsample (line 57) | class Upsample(nn.Upsample):
method forward (line 59) | def forward(self, x):
class Resample (line 66) | class Resample(nn.Module):
method __init__ (line 68) | def __init__(self, dim, mode):
method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 162) | def init_weight(self, conv):
method init_weight2 (line 174) | def init_weight2(self, conv):
class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
method __init__ (line 228) | def __init__(self, dim):
method forward (line 240) | def forward(self, x):
class Encoder3d (line 265) | class Encoder3d(nn.Module):
method __init__ (line 267) | def __init__(self,
method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 369) | class Decoder3d(nn.Module):
method __init__ (line 371) | def __init__(self,
method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
function count_conv3d (line 475) | def count_conv3d(model):
class WanVAE_ (line 483) | class WanVAE_(nn.Module):
method __init__ (line 485) | def __init__(self,
method forward (line 511) | def forward(self, x):
method encode (line 517) | def encode(self, x, scale):
method decode (line 545) | def decode(self, z, scale):
method cached_decode (line 571) | def cached_decode(self, z, scale):
method sample (line 595) | def sample(self, imgs, deterministic=False):
method clear_cache (line 602) | def clear_cache(self):
function _video_vae (line 612) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
class WanVAE (line 639) | class WanVAE:
method __init__ (line 641) | def __init__(self,
method encode (line 667) | def encode(self, videos):
method decode (line 677) | def decode(self, zs):
FILE: wan/modules/xlm_roberta.py
class SelfAttention (line 10) | class SelfAttention(nn.Module):
method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
method forward (line 27) | def forward(self, x, mask):
class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
method forward (line 66) | def forward(self, x, mask):
class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
method __init__ (line 81) | def __init__(self,
method forward (line 118) | def forward(self, ids):
function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,
FILE: wan/text2video.py
class WanT2V (line 26) | class WanT2V:
method __init__ (line 28) | def __init__(
method generate (line 110) | def generate(self,
FILE: wan/utils/fm_solvers.py
function get_sampling_sigmas (line 22) | def get_sampling_sigmas(sampling_steps, shift):
function retrieve_timesteps (line 29) | def retrieve_timesteps(
class FlowDPMSolverMultistepScheduler (line 69) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 129) | def __init__(
method step_index (line 202) | def step_index(self):
method begin_index (line 209) | def begin_index(self):
method set_begin_index (line 216) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 226) | def set_timesteps(
method _threshold_sample (line 292) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 330) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 333) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 337) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 341) | def convert_model_output(
method dpm_solver_first_order_update (line 415) | def dpm_solver_first_order_update(
method multistep_dpm_solver_second_order_update (line 486) | def multistep_dpm_solver_second_order_update(
method multistep_dpm_solver_third_order_update (line 596) | def multistep_dpm_solver_third_order_update(
method index_for_timestep (line 679) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 693) | def _init_step_index(self, timestep):
method step (line 706) | def step(
method scale_model_input (line 800) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 815) | def add_noise(
method __len__ (line 856) | def __len__(self):
FILE: wan/utils/fm_solvers_unipc.py
class FlowUniPCMultistepScheduler (line 20) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 77) | def __init__(
method step_index (line 135) | def step_index(self):
method begin_index (line 142) | def begin_index(self):
method set_begin_index (line 149) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 160) | def set_timesteps(
method _threshold_sample (line 230) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 269) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 272) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 276) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 279) | def convert_model_output(
method multistep_uni_p_bh_update (line 350) | def multistep_uni_p_bh_update(
method multistep_uni_c_bh_update (line 486) | def multistep_uni_c_bh_update(
method index_for_timestep (line 628) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 643) | def _init_step_index(self, timestep):
method step (line 655) | def step(self,
method scale_model_input (line 741) | def scale_model_input(self, sample: torch.Tensor, *args,
method add_noise (line 758) | def add_noise(
method __len__ (line 799) | def __len__(self):
FILE: wan/utils/prompt_extend.py
class PromptOutput (line 101) | class PromptOutput(object):
method add_custom_field (line 108) | def add_custom_field(self, key: str, value) -> None:
class PromptExpander (line 112) | class PromptExpander:
method __init__ (line 114) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
method extend_with_img (line 119) | def extend_with_img(self,
method extend (line 128) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method decide_system_prompt (line 131) | def decide_system_prompt(self, tar_lang="ch"):
method __call__ (line 138) | def __call__(self,
class DashScopePromptExpander (line 157) | class DashScopePromptExpander(PromptExpander):
method __init__ (line 159) | def __init__(self,
method extend (line 196) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 232) | def extend_with_img(self,
class QwenPromptExpander (line 300) | class QwenPromptExpander(PromptExpander):
method __init__ (line 309) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
method extend (line 366) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
method extend_with_img (line 397) | def extend_with_img(self,
FILE: wan/utils/qwen_vl_utils.py
function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
function smart_resize (line 54) | def smart_resize(height: int,
function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
function smart_nframes (line 133) | def smart_nframes(
function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
function is_decord_available (line 215) | def is_decord_available() -> bool:
function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
function fetch_video (line 274) | def fetch_video(
function extract_vision_info (line 328) | def extract_vision_info(
function process_vision_info (line 344) | def process_vision_info(
FILE: wan/utils/utils.py
function rand_name (line 14) | def rand_name(length=8, suffix=''):
function cache_video (line 23) | def cache_video(tensor,
function cache_image (line 64) | def cache_image(tensor,
function str2bool (line 94) | def str2bool(v):
Condensed preview of 83 files, each showing path, character count, and a content snippet (2,163K chars of full structured content).
[
{
"path": "LICENSE.md",
"chars": 19105,
"preview": "# Attribution-NonCommercial-ShareAlike 4.0 International\n\nCreative Commons Corporation (“Creative Commons”) is not a law"
},
{
"path": "README.md",
"chars": 2966,
"preview": "<p align=\"center\">\n<h1 align=\"center\">Self Forcing Plus</h1>\n\nSelf-Forcing-Plus focuses on step distillation and CFG dis"
},
{
"path": "configs/default_config.yaml",
"chars": 403,
"preview": "independent_first_frame: false\nwarp_denoising_step: false\nweight_decay: 0.01\nsame_step_across_blocks: true\ndiscriminator"
},
{
"path": "configs/self_forcing_14b_dmd.yaml",
"chars": 1402,
"preview": "# generator_ckpt: checkpoints/ode_init.pt\ni2v: false\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: s"
},
{
"path": "configs/self_forcing_14b_i2v_dmd.yaml",
"chars": 1439,
"preview": "# generator_ckpt: checkpoints/ode_init.pt\ni2v: true\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: si"
},
{
"path": "configs/self_forcing_dmd.yaml",
"chars": 1352,
"preview": "generator_ckpt: checkpoints/ode_init.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: size\nfake_scor"
},
{
"path": "configs/self_forcing_sid.yaml",
"chars": 1281,
"preview": "generator_ckpt: checkpoints/ode_init.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: size\nfake_scor"
},
{
"path": "convert_checkpoint.py",
"chars": 3568,
"preview": "import torch\nimport argparse\nimport os\nimport gc\nfrom safetensors.torch import save_file\n\ndef main():\n # Set up argum"
},
{
"path": "demo.py",
"chars": 22068,
"preview": "\"\"\"\nDemo for Self-Forcing.\n\"\"\"\n\nimport os\nimport time\nimport base64\nimport argparse\nimport urllib.request\nfrom io import"
},
{
"path": "demo_utils/constant.py",
"chars": 1352,
"preview": "\nimport torch\n\n\nZERO_VAE_CACHE = [\n torch.zeros(1, 16, 2, 60, 104),\n torch.zeros(1, 384, 2, 60, 104),\n torch.ze"
},
{
"path": "demo_utils/memory.py",
"chars": 4417,
"preview": "# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils\n# Apache-2.0 License\n# By lllyasviel\n\nimport "
},
{
"path": "demo_utils/taehv.py",
"chars": 14157,
"preview": "#!/usr/bin/env python3\n\"\"\"\nTiny AutoEncoder for Hunyuan Video\n(DNN for encoding / decoding videos to Hunyuan Video's lat"
},
{
"path": "demo_utils/utils.py",
"chars": 17547,
"preview": "# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils\n# Apache-2.0 License\n# By lllyasviel\n\nimport "
},
{
"path": "demo_utils/vae.py",
"chars": 15414,
"preview": "from typing import List\nfrom einops import rearrange\nimport tensorrt as trt\nimport torch\nimport torch.nn as nn\n\nfrom dem"
},
{
"path": "demo_utils/vae_block3.py",
"chars": 11058,
"preview": "from typing import List\nfrom einops import rearrange\nimport torch\nimport torch.nn as nn\n\nfrom wan.modules.vae import Att"
},
{
"path": "demo_utils/vae_torch2trt.py",
"chars": 11889,
"preview": "# ---- INT8 (optional) ----\nfrom demo_utils.vae import (\n VAEDecoderWrapperSingle, # main nn."
},
{
"path": "inference.py",
"chars": 6988,
"preview": "import argparse\nimport torch\nimport os\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom torchvision import tra"
},
{
"path": "model/__init__.py",
"chars": 278,
"preview": "from .diffusion import CausalDiffusion\nfrom .causvid import CausVid\nfrom .dmd import DMD\nfrom .gan import GAN\nfrom .sid "
},
{
"path": "model/base.py",
"chars": 11837,
"preview": "from typing import Tuple\nfrom einops import rearrange\nfrom torch import nn\nimport torch.distributed as dist\nimport torch"
},
{
"path": "model/causvid.py",
"chars": 17252,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\n\n\nclass CausVid("
},
{
"path": "model/diffusion.py",
"chars": 5641,
"preview": "from typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_wrapper import WanDiffusionWrappe"
},
{
"path": "model/dmd.py",
"chars": 16281,
"preview": "from pipeline import SelfForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\nimpo"
},
{
"path": "model/gan.py",
"chars": 14232,
"preview": "import copy\nfrom pipeline import SelfForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Tuple\nim"
},
{
"path": "model/ode_regression.py",
"chars": 5962,
"preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_w"
},
{
"path": "model/sid.py",
"chars": 12638,
"preview": "from pipeline import SelfForcingTrainingPipeline\nfrom typing import Optional, Tuple\nimport torch\n\nfrom model.base import"
},
{
"path": "pipeline/__init__.py",
"chars": 653,
"preview": "from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline\nfrom .bidirectional_inference imp"
},
{
"path": "pipeline/bidirectional_diffusion_inference.py",
"chars": 4146,
"preview": "from tqdm import tqdm\nfrom typing import List\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMultistepSched"
},
{
"path": "pipeline/bidirectional_inference.py",
"chars": 3109,
"preview": "from typing import List\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\n\n"
},
{
"path": "pipeline/bidirectional_training.py",
"chars": 4416,
"preview": "from typing import List\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import Sche"
},
{
"path": "pipeline/causal_diffusion_inference.py",
"chars": 16465,
"preview": "from tqdm import tqdm\nfrom typing import List, Optional\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMult"
},
{
"path": "pipeline/causal_inference.py",
"chars": 14154,
"preview": "from typing import List, Optional\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVA"
},
{
"path": "pipeline/self_forcing_training.py",
"chars": 13179,
"preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
},
{
"path": "prompts/MovieGenVideoBench.txt",
"chars": 106634,
"preview": "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black lea"
},
{
"path": "prompts/MovieGenVideoBench_extended.txt",
"chars": 656377,
"preview": "A stylish woman strolls down a bustling Tokyo street, the warm glow of neon lights and animated city signs casting vibra"
},
{
"path": "prompts/vbench/all_dimension.txt",
"chars": 39728,
"preview": "In a still frame, a stop sign\na toilet, frozen in time\na laptop, frozen in time\nA tranquil tableau of alley\nA tranquil t"
},
{
"path": "prompts/vbench/all_dimension_longer.txt",
"chars": 597797,
"preview": "In a still frame, a weathered stop sign stands prominently at a quiet intersection, its red paint slightly faded and edg"
},
{
"path": "requirements.txt",
"chars": 530,
"preview": "torch>=2.4.0\ntorchvision>=0.19.0\nopencv-python>=4.9.0.80\ndiffusers==0.31.0\ntransformers>=4.49.0\ntokenizers>=0.20.3\naccel"
},
{
"path": "scripts/compute_vae_latent.py",
"chars": 7599,
"preview": "from utils.wan_wrapper import WanVAEWrapper\nimport torch.distributed as dist\nimport imageio.v3 as iio\nfrom datetime impo"
},
{
"path": "scripts/create_lmdb_14b_shards.py",
"chars": 6092,
"preview": "\"\"\"\npython create_lmdb_14b_shards.py \\\n--data_path /mnt/localssd/wanx_14b_data \\\n--lmdb_path /mnt/localssd/wanx_14B_shif"
},
{
"path": "scripts/create_lmdb_iterative.py",
"chars": 1734,
"preview": "from tqdm import tqdm\nimport numpy as np\nimport argparse\nimport torch\nimport lmdb\nimport glob\nimport os\n\nfrom utils.lmdb"
},
{
"path": "scripts/generate_ode_pairs.py",
"chars": 3799,
"preview": "from utils.distributed import launch_distributed_job\nfrom utils.scheduler import FlowMatchScheduler\nfrom utils.wan_wrapp"
},
{
"path": "setup.py",
"chars": 129,
"preview": "from setuptools import setup, find_packages\nsetup(\n name=\"self_forcing\",\n version=\"0.0.1\",\n packages=find_packa"
},
{
"path": "templates/demo.html",
"chars": 24394,
"preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width"
},
{
"path": "train.py",
"chars": 1629,
"preview": "import argparse\nimport os\nfrom omegaconf import OmegaConf\nimport wandb\n\nfrom trainer import DiffusionTrainer, GANTrainer"
},
{
"path": "trainer/__init__.py",
"chars": 297,
"preview": "from .diffusion import Trainer as DiffusionTrainer\nfrom .gan import Trainer as GANTrainer\nfrom .ode import Trainer as OD"
},
{
"path": "trainer/diffusion.py",
"chars": 10384,
"preview": "import gc\nimport logging\n\nfrom model import CausalDiffusion\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom ut"
},
{
"path": "trainer/distillation.py",
"chars": 19384,
"preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.dataset import TextDataset, Te"
},
{
"path": "trainer/gan.py",
"chars": 19999,
"preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.distributed import EMA_FSDP, f"
},
{
"path": "trainer/ode.py",
"chars": 9791,
"preview": "import gc\nimport logging\nfrom utils.dataset import ODERegressionLMDBDataset, cycle\nfrom model import ODERegression\nfrom "
},
{
"path": "utils/dataset.py",
"chars": 8272,
"preview": "from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb\nfrom torch.utils.data import Dataset\nimport num"
},
{
"path": "utils/distributed.py",
"chars": 4710,
"preview": "from datetime import timedelta\nfrom functools import partial\nimport os\nimport torch\nimport torch.distributed as dist\nfro"
},
{
"path": "utils/lmdb.py",
"chars": 2078,
"preview": "import numpy as np\n\n\ndef get_array_shape_from_lmdb(env, array_name):\n with env.begin() as txn:\n image_shape = "
},
{
"path": "utils/loss.py",
"chars": 2467,
"preview": "from abc import ABC, abstractmethod\nimport torch\n\n\nclass DenoisingLoss(ABC):\n @abstractmethod\n def __call__(\n "
},
{
"path": "utils/misc.py",
"chars": 1155,
"preview": "import numpy as np\nimport random\nimport torch\n\n\ndef set_seed(seed: int, deterministic: bool = False):\n \"\"\"\n Helper"
},
{
"path": "utils/scheduler.py",
"chars": 7979,
"preview": "from abc import abstractmethod, ABC\nimport torch\n\n\nclass SchedulerInterface(ABC):\n \"\"\"\n Base class for diffusion n"
},
{
"path": "utils/wan_wrapper.py",
"chars": 15538,
"preview": "import os\nimport types\nfrom typing import List, Optional\nimport torch\nfrom torch import nn\n\nfrom utils.scheduler import "
},
{
"path": "wan/README.md",
"chars": 92,
"preview": "Code in this folder is modified from https://github.com/Wan-Video/Wan2.1\nApache-2.0 License "
},
{
"path": "wan/__init__.py",
"chars": 107,
"preview": "from . import configs, distributed, modules\nfrom .image2video import WanI2V\nfrom .text2video import WanT2V\n"
},
{
"path": "wan/configs/__init__.py",
"chars": 1011,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .wan_t2v_14B import t2v_14B\nfrom .wan_t2v_"
},
{
"path": "wan/configs/shared_config.py",
"chars": 650,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n# -"
},
{
"path": "wan/configs/wan_i2v_14B.py",
"chars": 972,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
},
{
"path": "wan/configs/wan_t2v_14B.py",
"chars": 743,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/configs/wan_t2v_1_3B.py",
"chars": 760,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
},
{
"path": "wan/distributed/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "wan/distributed/fsdp.py",
"chars": 1077,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom functools import partial\n\nimport torch\nfro"
},
{
"path": "wan/distributed/xdit_context_parallel.py",
"chars": 5899,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
},
{
"path": "wan/image2video.py",
"chars": 13203,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/modules/__init__.py",
"chars": 365,
"preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
},
{
"path": "wan/modules/attention.py",
"chars": 5641,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n import flash_attn_interf"
},
{
"path": "wan/modules/causal_model.py",
"chars": 42167,
"preview": "from wan.modules.attention import attention\nfrom wan.modules.model import (\n WanRMSNorm,\n rope_apply,\n WanLayer"
},
{
"path": "wan/modules/clip.py",
"chars": 15919,
"preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
},
{
"path": "wan/modules/model.py",
"chars": 30768,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn"
},
{
"path": "wan/modules/t5.py",
"chars": 16910,
"preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
},
{
"path": "wan/modules/tokenizers.py",
"chars": 2431,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
},
{
"path": "wan/modules/vae.py",
"chars": 23735,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
},
{
"path": "wan/modules/xlm_roberta.py",
"chars": 4865,
"preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
},
{
"path": "wan/text2video.py",
"chars": 10241,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
},
{
"path": "wan/utils/__init__.py",
"chars": 339,
"preview": "from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,\n retrieve_timeste"
},
{
"path": "wan/utils/fm_solvers.py",
"chars": 40232,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
},
{
"path": "wan/utils/fm_solvers_unipc.py",
"chars": 32645,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
},
{
"path": "wan/utils/prompt_extend.py",
"chars": 30344,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
},
{
"path": "wan/utils/qwen_vl_utils.py",
"chars": 13044,
"preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
},
{
"path": "wan/utils/utils.py",
"chars": 3239,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
}
]
About this extraction
This page contains the full source code of the GoatWu/Self-Forcing-Plus GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 83 files (2.0 MB), approximately 534.2k tokens, and a symbol index with 598 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.