Repository: thu-ml/Causal-Forcing
Branch: main
Commit: 9d7fcaf94a54
Files: 154
Total size: 1.4 MB

Directory structure:
gitextract_cbc4vx5e/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── ar_diffusion_tf_chunkwise.yaml
│   ├── ar_diffusion_tf_framewise.yaml
│   ├── causal_cd_chunkwise.yaml
│   ├── causal_cd_framewise.yaml
│   ├── causal_forcing_dmd_chunkwise.yaml
│   ├── causal_forcing_dmd_framewise.yaml
│   ├── causal_forcing_dmd_framewise_1step.yaml
│   ├── causal_forcing_dmd_framewise_2step.yaml
│   ├── causal_ode_chunkwise.yaml
│   ├── causal_ode_framewise.yaml
│   └── default_config.yaml
├── demo.py
├── demo_utils/
│   ├── constant.py
│   ├── memory.py
│   ├── taehv.py
│   ├── utils.py
│   ├── vae.py
│   ├── vae_block3.py
│   └── vae_torch2trt.py
├── get_causal_ode_data_chunkwise.py
├── get_causal_ode_data_framewise.py
├── get_causal_ode_data_kv_optimized.py
├── inference.py
├── long_video/
│   ├── LICENSE
│   ├── README.md
│   ├── app.py
│   ├── configs/
│   │   ├── default_config.yaml
│   │   └── rolling_forcing_dmd.yaml
│   ├── inference.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── causvid.py
│   │   ├── diffusion.py
│   │   ├── dmd.py
│   │   ├── gan.py
│   │   ├── ode_regression.py
│   │   └── sid.py
│   ├── pipeline/
│   │   ├── __init__.py
│   │   ├── bidirectional_diffusion_inference.py
│   │   ├── bidirectional_inference.py
│   │   ├── causal_diffusion_inference.py
│   │   ├── rolling_forcing_inference.py
│   │   └── rolling_forcing_training.py
│   ├── prompts/
│   │   └── example_prompts.txt
│   ├── requirements.txt
│   ├── train.py
│   ├── trainer/
│   │   ├── __init__.py
│   │   ├── diffusion.py
│   │   ├── distillation.py
│   │   ├── gan.py
│   │   └── ode.py
│   ├── utils/
│   │   ├── dataset.py
│   │   ├── distributed.py
│   │   ├── lmdb.py
│   │   ├── loss.py
│   │   ├── misc.py
│   │   ├── scheduler.py
│   │   └── wan_wrapper.py
│   └── wan/
│       ├── README.md
│       ├── __init__.py
│       ├── configs/
│       │   ├── __init__.py
│       │   ├── shared_config.py
│       │   ├── wan_i2v_14B.py
│       │   ├── wan_t2v_14B.py
│       │   └── wan_t2v_1_3B.py
│       ├── distributed/
│       │   ├── __init__.py
│       │   ├── fsdp.py
│       │   └── xdit_context_parallel.py
│       ├── image2video.py
│       ├── modules/
│       │   ├── __init__.py
│       │   ├── attention.py
│       │   ├── causal_model.py
│       │   ├── clip.py
│       │   ├── model.py
│       │   ├── t5.py
│       │   ├── tokenizers.py
│       │   ├── vae.py
│       │   └── xlm_roberta.py
│       ├── text2video.py
│       └── utils/
│           ├── __init__.py
│           ├── fm_solvers.py
│           ├── fm_solvers_unipc.py
│           ├── prompt_extend.py
│           ├── qwen_vl_utils.py
│           └── utils.py
├── model/
│   ├── __init__.py
│   ├── base.py
│   ├── causvid.py
│   ├── diffusion.py
│   ├── dmd.py
│   ├── gan.py
│   ├── naive_consistency.py
│   ├── ode_regression.py
│   └── sid.py
├── pipeline/
│   ├── __init__.py
│   ├── bidirectional_diffusion_inference.py
│   ├── bidirectional_inference.py
│   ├── bidirectional_training.py
│   ├── causal_diffusion_inference.py
│   ├── causal_inference.py
│   ├── self_forcing_training.py
│   └── teacher_forcing_training.py
├── prompts/
│   ├── demos.txt
│   └── i2v/
│       └── target_crop_info_26-15.json
├── requirements.txt
├── setup.py
├── train.py
├── trainer/
│   ├── __init__.py
│   ├── diffusion.py
│   ├── distillation.py
│   ├── gan.py
│   ├── naive_cd.py
│   └── ode.py
├── utils/
│   ├── create_lmdb_iterative.py
│   ├── dataset.py
│   ├── distributed.py
│   ├── lmdb_.py
│   ├── loss.py
│   ├── merge_and_get_clean.py
│   ├── merge_lmdb.py
│   ├── misc.py
│   ├── ode_generation.py
│   ├── scheduler.py
│   └── wan_wrapper.py
└── wan/
    ├── README.md
    ├── __init__.py
    ├── configs/
    │   ├── __init__.py
    │   ├── shared_config.py
    │   ├── wan_i2v_14B.py
    │   ├── wan_t2v_14B.py
    │   └── wan_t2v_1_3B.py
    ├── distributed/
    │   ├── __init__.py
    │   ├── fsdp.py
    │   └── xdit_context_parallel.py
    ├── image2video.py
    ├── modules/
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── causal_model.py
    │   ├── clip.py
    │   ├── model.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    ├── text2video.py
    └── utils/
        ├── __init__.py
        ├── fm_solvers.py
        ├── fm_solvers_unipc.py
        ├── prompt_extend.py
        ├── qwen_vl_utils.py
        └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
*.egg-info

wan_models
checkpoints
output
dataset
prompts/vidprom_filtered_extended.txt
logs

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">

## Causal Forcing & Causal Forcing++
### Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation

<p align="center">
  <p align="center">
    <div>
    <a href="https://zhuhz22.github.io/" target="_blank">Hongzhou Zhu*</a><sup></sup>,
    <a href="https://gracezhao1997.github.io/" target="_blank">Min Zhao*</a><sup></sup> , 
    <a href="https://guandehe.github.io/" target="_blank">Guande He</a><sup></sup>, 
    <a href="https://scholar.google.com/citations?user=dxN1_X0AAAAJ&hl=en" target="_blank">Hang Su</a><sup></sup>,
    <a href="https://zhenxuan00.github.io/" target="_blank">Chongxuan Li</a><sup></sup> ,
    <a href="https://ml.cs.tsinghua.edu.cn/~jun/index.shtml" target="_blank">Jun Zhu</a><sup></sup>
  </div>
  <div>
      <sup></sup>Tsinghua University & Shengshu & UT Austin & RUC
  </div>


</div>
  </p>
  <h3 align="center"><a href="https://arxiv.org/abs/2602.02214">Causal Forcing</a> | <a href="https://arxiv.org/abs/2605.15141">Causal Forcing++</a> | <a href="https://thu-ml.github.io/CausalForcing.github.io">Website</a> | <a href="https://huggingface.co/zhuhz22/Causal-Forcing/tree/main">Models</a> | <a href="assets/wechat.jpg">WeChat</a>  | <a href="https://my.feishu.cn/wiki/AjBSwcjpqiN0ECkodIWcGDcMn4e?from=from_copylink">Document</a> </h3>
</p>



-----
The Causal Forcing series uses **Causal ODE** distillation or **Causal Consistency Distillation** as a theoretically correct initialization for asymmetric DMD, enabling real-time interactive video generation.

[Causal Forcing](https://arxiv.org/abs/2602.02214) significantly outperforms Self Forcing in **both visual quality and motion dynamics**, while keeping **the same training budget and inference efficiency**. We support both chunk-wise and **frame-wise** models, with the latter natively unifying T2V and **I2V**.

We further propose [**Causal Forcing++**](https://arxiv.org/abs/2605.15141), which replaces ODE distillation with **causal Consistency Distillation** to eliminate ODE data curation and improve performance, and releases the first **1-step/2-step frame-wise** models.

-----
<img width="2090" height="850" alt="overview" src="assets/pipeline.png" />

## Table of Contents

- [Quick Start](#quick-start)
  - [Installation](#installation) 
  - [Inference: 1/2/4-step T2V & I2V](#cli-inference)
  - [Long Video](#minute-level-long-video-generation)
- [Training Pipeline](#training) (Causal Forcing / Causal Forcing++)
  - [Stage 1: AR Diffusion](#training)
  - [Stage 2: Causal ODE (Causal Forcing)](#training) / [🔥Causal CD (Causal Forcing++)](#-stage-2-option-b-causal-consistency-distillation-initialization-causal-forcing)
  - [Stage 3: Asymmetric DMD](#stage-3-dmd)


| Models              | Checkpoints | Description|
|--------------------|---------------------------------------------------------------------------------------------------------------------------------------------|-------|
| Chunk-wise 4-step   | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/chunkwise/causal_forcing.pt)    | SOTA AR model on Wan1.3B outperforming Self Forcing|
| Frame-wise 4-step   | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/framewise/causal_forcing.pt)    | Rich dynamics and high quality |
| Frame-wise 2-step🔥 | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/causal-forcing%2B%2B/framewise-2step.pt)    | The first frame-wise 2-step model, **even better than 4-step**|
| Frame-wise 1-step   | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/causal-forcing%2B%2B/framewise-1step.pt)    | Extremely low latency |
| Long Video Generator   | 🤗 [Huggingface](https://huggingface.co/zhuhz22/Causal-Forcing/blob/main/chunkwise/longvideo.pt)    | Minute-long AR video generator |
| HY1.5-TI2V (8B)   | 🤗 [Huggingface](https://huggingface.co/MIN-Lab/minWM)   | Refer to [this repo](https://github.com/shengshu-ai/minWM). |
| Action-conditioned WM   | 🤗 [Huggingface](https://huggingface.co/MIN-Lab/minWM)    | Refer to [this repo](https://github.com/shengshu-ai/minWM). |

-----

https://github.com/user-attachments/assets/310f0cfa-e1bb-496d-8941-87f77b3271c0

## 🔥 News
- **2026.5.15**: We release [Causal Forcing++](https://arxiv.org/abs/2605.15141), supporting Causal Consistency Distillation for few-step initialization, and open-source **the first frame-wise 2-step AR model**, comparable to chunk-wise 4-step models!
- **2026.5.10**: Thanks to @[AshadowZ](https://github.com/AshadowZ), now our chunk-wise ODE data curation is **3x faster**!
- **2026.4.16**: **Optimized the Stage 2 🔥consistency distillation🔥 infrastructure for 3× faster training. Try it now!** We have also released the checkpoint.
- **2026.3.15** : [Rolling Sink](https://github.com/haodong2000/RollingSink), [Infinity-RoPE](https://github.com/yesiltepe-hidir/infinity-rope) and [Deep Forcing](https://cvlab-kaist.github.io/DeepForcing/) adopt Causal Forcing as one of the base models!
- **2026.2.28** : Added an [FAQ section](#faq--blog) on hot topics, specifically whether AR diffusion or causal ODE distillation is the better initialization.
- **2026.2.11** : We now support **I2V** generation! Feel free to try it [here](#i2v)!
- **2026.2.7** : Causal Forcing now supports [Rolling Forcing](https://github.com/TencentARC/RollingForcing), enabling minute-level long video generation!
- **2026.2.5** : Released causal consistency distillation (preview) as a substitute for ODE distillation, **free of generating ODE paired data**!
- **2026.2.2** : The [paper](https://arxiv.org/abs/2602.02214), [project page](https://thu-ml.github.io/CausalForcing.github.io/), and code are released.


## Quick Start

> The inference environment is identical to that of Self Forcing.

**NOTE**: Like CausVid and Self Forcing, Causal Forcing does not natively support videos longer than 81 frames. As a base training method, it is orthogonal to long-video techniques such as LongLive and Rolling Forcing. To use Causal Forcing as a long-video baseline, see [this extension](#minute-level-long-video-generation). **Directly using the Causal Forcing model trained on 5-second videos as a long-video baseline is extremely unfair.**
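
As a rough illustration of what the 81-frame limit means for the autoregressive rollout, the sketch below converts video frames to latent frames and counts generation blocks. It assumes the Wan2.1 VAE's causal 4× temporal compression (the first frame is encoded on its own), and it assumes 3 latent frames per block for chunk-wise models and 1 for frame-wise models, mirroring the `--num_frames_per_chunk` values mentioned for ODE data generation. The helper names are hypothetical and not part of this repository.

```python
def latent_frames(num_video_frames: int, temporal_stride: int = 4) -> int:
    """Latent length for a causal video VAE that encodes the first frame alone
    and compresses every following `temporal_stride` frames into one latent."""
    if (num_video_frames - 1) % temporal_stride != 0:
        raise ValueError("expected 1 + k * temporal_stride video frames")
    return 1 + (num_video_frames - 1) // temporal_stride


def num_ar_blocks(num_latent: int, frames_per_block: int) -> int:
    """Number of autoregressive generation steps over the latent sequence."""
    if num_latent % frames_per_block != 0:
        raise ValueError("latent length must be divisible by the block size")
    return num_latent // frames_per_block


if __name__ == "__main__":
    n = latent_frames(81)
    print(f"81 video frames -> {n} latent frames")
    print(f"chunk-wise (3 latent frames/block): {num_ar_blocks(n, 3)} blocks")
    print(f"frame-wise (1 latent frame/block): {num_ar_blocks(n, 1)} blocks")
```

Under these assumptions, 81 frames map to 21 latent frames, i.e. 7 chunk-wise blocks or 21 frame-wise steps.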


### Installation
```bash
conda create -n causal_forcing python=3.10 -y
conda activate causal_forcing
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
pip install flash-attn --no-build-isolation
python setup.py develop
```
### Download Checkpoints
```bash
hf download Wan-AI/Wan2.1-T2V-1.3B  --local-dir wan_models/Wan2.1-T2V-1.3B
hf download Wan-AI/Wan2.1-T2V-14B  --local-dir wan_models/Wan2.1-T2V-14B
# Causal Forcing
hf download zhuhz22/Causal-Forcing chunkwise/causal_forcing.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing framewise/causal_forcing.pt --local-dir checkpoints
# Causal Forcing++
hf download zhuhz22/Causal-Forcing causal-forcing++/framewise-2step.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing causal-forcing++/framewise-1step.pt --local-dir checkpoints
```
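
Before running inference, a quick check that every file landed where the inference commands expect it can save a failed launch. The hypothetical script below (not part of this repository) lists the relative paths taken from the download commands above, assuming that `hf download ... --local-dir checkpoints` keeps the repository-relative path (e.g. `checkpoints/chunkwise/causal_forcing.pt`), which is also what the inference commands rely on.

```python
from pathlib import Path

# Relative paths taken from the `hf download ... --local-dir ...` commands above.
EXPECTED_FILES = [
    "checkpoints/chunkwise/causal_forcing.pt",
    "checkpoints/framewise/causal_forcing.pt",
    "checkpoints/causal-forcing++/framewise-2step.pt",
    "checkpoints/causal-forcing++/framewise-1step.pt",
]
# Base Wan2.1 models are downloaded as whole folders.
EXPECTED_DIRS = [
    "wan_models/Wan2.1-T2V-1.3B",
    "wan_models/Wan2.1-T2V-14B",
]


def missing_assets(root: str = ".") -> list[str]:
    """Return the expected checkpoint files and model folders absent under `root`."""
    base = Path(root)
    missing = [p for p in EXPECTED_FILES if not (base / p).is_file()]
    missing += [p for p in EXPECTED_DIRS if not (base / p).is_dir()]
    return missing


if __name__ == "__main__":
    absent = missing_assets()
    if absent:
        print("Missing (run the download commands above):")
        for path in absent:
            print("  " + path)
    else:
        print("All expected checkpoints and base models are present.")
```

Run it from the repository root; drop entries from the lists for models you do not plan to use.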

### CLI Inference

#### T2V
Chunk-wise model:
```bash
python inference.py \
  --config_path configs/causal_forcing_dmd_chunkwise.yaml \
  --output_folder output/chunkwise \
  --checkpoint_path checkpoints/chunkwise/causal_forcing.pt \
  --data_path prompts/demos.txt
```

Frame-wise model:

```bash
# =============== Causal Forcing++ ================
# 2-step Causal Forcing++
python inference.py \
  --config_path configs/causal_forcing_dmd_framewise_2step.yaml \
  --output_folder output/framewise_2step_cf++ \
  --checkpoint_path checkpoints/causal-forcing++/framewise-2step.pt \
  --data_path prompts/demos.txt \
  --use_ema

# 1-step Causal Forcing++
python inference.py \
  --config_path configs/causal_forcing_dmd_framewise_1step.yaml \
  --output_folder output/framewise_1step_cf++ \
  --checkpoint_path checkpoints/causal-forcing++/framewise-1step.pt \
  --data_path prompts/demos.txt \
  --use_ema

# =============== Causal Forcing ================
# 4-step Causal Forcing
python inference.py \
  --config_path configs/causal_forcing_dmd_framewise.yaml \
  --output_folder output/framewise \
  --checkpoint_path  checkpoints/framewise/causal_forcing.pt \
  --data_path prompts/demos.txt \
  --use_ema 
```
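
The four T2V commands differ only in config, checkpoint, output folder, and whether `--use_ema` is passed (only the chunk-wise command omits it). As a hypothetical convenience wrapper, not part of this repository, the sketch below keeps that mapping in one table and builds the argument list using only the flags shown above.

```python
import shlex

# (config, checkpoint, output folder, use_ema) per released model, copied from
# the T2V commands above.
VARIANTS = {
    "chunkwise": ("configs/causal_forcing_dmd_chunkwise.yaml",
                  "checkpoints/chunkwise/causal_forcing.pt",
                  "output/chunkwise", False),
    "framewise": ("configs/causal_forcing_dmd_framewise.yaml",
                  "checkpoints/framewise/causal_forcing.pt",
                  "output/framewise", True),
    "framewise_2step": ("configs/causal_forcing_dmd_framewise_2step.yaml",
                        "checkpoints/causal-forcing++/framewise-2step.pt",
                        "output/framewise_2step_cf++", True),
    "framewise_1step": ("configs/causal_forcing_dmd_framewise_1step.yaml",
                        "checkpoints/causal-forcing++/framewise-1step.pt",
                        "output/framewise_1step_cf++", True),
}


def inference_cmd(variant: str, data_path: str = "prompts/demos.txt") -> list[str]:
    """Build the argument list for `python inference.py` for one released model."""
    config, ckpt, out_dir, use_ema = VARIANTS[variant]
    cmd = ["python", "inference.py",
           "--config_path", config,
           "--output_folder", out_dir,
           "--checkpoint_path", ckpt,
           "--data_path", data_path]
    if use_ema:
        cmd.append("--use_ema")
    return cmd


if __name__ == "__main__":
    # Print a copy-pasteable shell command for every variant.
    for name in VARIANTS:
        print(shlex.join(inference_cmd(name)))
```

To launch one directly from the repository root, pass the list to `subprocess.run(inference_cmd("framewise_2step"), check=True)`.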

#### I2V
> Our frame-wise setting natively supports I2V: simply use your conditioning image as the first latent frame.

```bash
python inference.py \
  --config_path configs/causal_forcing_dmd_framewise.yaml \
  --output_folder output/framewise \
  --checkpoint_path  checkpoints/framewise/causal_forcing.pt \
  --data_path prompts/i2v \
  --i2v \
  --use_ema
```
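
A conceptual sketch of why the frame-wise setting handles I2V without extra training: the rollout is causal, so fixing latent frame 0 to the encoded conditioning image and generating the remaining frames is the same loop as T2V. `framewise_rollout` and `denoise_next` are hypothetical stand-ins; the actual inference pipeline also involves noise schedules, multi-step denoising, and VAE encoding/decoding, all omitted here.

```python
from typing import Any, Callable, List, Optional


def framewise_rollout(
    num_latent_frames: int,
    denoise_next: Callable[[List[Any]], Any],
    first_frame: Optional[Any] = None,
) -> List[Any]:
    """Causal frame-by-frame rollout over latent frames.

    T2V: `first_frame` is None, so latent frame 0 is generated like any other.
    I2V: `first_frame` is the encoded conditioning image; it is kept as-is and
    generation starts at latent frame 1, conditioned on it.
    """
    frames: List[Any] = [] if first_frame is None else [first_frame]
    while len(frames) < num_latent_frames:
        # Each new frame sees only a copy of the frames before it (causal context).
        frames.append(denoise_next(list(frames)))
    return frames


if __name__ == "__main__":
    # Toy "denoiser" that only reports how much context it received.
    toy = lambda prefix: f"gen(ctx={len(prefix)})"
    print(framewise_rollout(4, toy, first_frame="image_latent"))
```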


### Minute-level Long Video Generation
Building on [Rolling Forcing](https://github.com/TencentARC/RollingForcing), we implement minute-level long video generation. See [here](./long_video) for details.

[Infinity-RoPE](https://github.com/yesiltepe-hidir/infinity-rope), [Deep Forcing](https://cvlab-kaist.github.io/DeepForcing/) and [Rolling Sink](https://github.com/haodong2000/RollingSink) also adopt Causal Forcing as one of their base models, enabling interactive (prompt-switchable) long video generation at the minute scale. You can also try them out at their repos.

## Training

<details>
<summary> Stage 1: Autoregressive Diffusion Training (Can be skipped by using our pretrained checkpoints. Click to expand.)</summary>

First download the dataset (we provide a 6K toy dataset here):
```bash
hf download zhuhz22/Causal-Forcing-data  --local-dir dataset
python utils/merge_and_get_clean.py
```
> If the download gets stuck, press Ctrl+C and rerun the command to resume it.


> For training on your own dataset, refer to [this issue](https://github.com/thu-ml/Causal-Forcing/issues/8).


Then train the AR-diffusion model:
- Framewise:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/ar_diffusion_tf_framewise.yaml \
    --logdir logs/ar_diffusion_framewise
  ```

- Chunkwise:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/ar_diffusion_tf_chunkwise.yaml \
    --logdir logs/ar_diffusion_chunkwise
  ```

> We recommend training for at least 2K steps; more steps (e.g., 5K-10K) lead to better performance.
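
Before committing 8 nodes × 8 GPUs to a training run, it can help to confirm that every worker sees the expected topology. The hypothetical script below (save it as, e.g., `check_launch.py`; it is not part of this repository) reads the per-process environment variables that `torchrun` exports. Launch it with the same `torchrun` flags as above, replacing `train.py --config_path ... --logdir ...` with `check_launch.py`.

```python
import os
from typing import Mapping, Optional


def launch_info(env: Optional[Mapping[str, str]] = None) -> dict:
    """Read the per-process variables that torchrun exports to each worker."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # process index on this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
        "local_world_size": int(env.get("LOCAL_WORLD_SIZE", 1)),  # processes on this node
    }


if __name__ == "__main__":
    # With --nnodes=8 --nproc_per_node=8, every worker should report
    # world_size=64 and local_world_size=8.
    print(launch_info())
```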

Inference to test training results:
```bash
python inference.py \
  --config_path configs/ar_diffusion_tf_{framewise OR chunkwise}.yaml \
  --output_folder output/{framewise OR chunkwise}_ar_diffusion \
  --checkpoint_path  checkpoints/{framewise OR chunkwise}/ar_diffusion.pt \
  --data_path prompts/demos.txt
```
</details>


<details>
<summary> Stage 2 (Option a): Causal ODE Initialization (Can be skipped by using our pretrained checkpoints. Click to expand.)</summary>

🔥 Thanks to @[AshadowZ](https://github.com/AshadowZ), now our chunk-wise ODE data curation is **3x faster**!

You can also use `bf16` precision to accelerate ODE data generation.

If you have skipped Stage 1, you need to download the pretrained models:
```bash
hf download zhuhz22/Causal-Forcing framewise/ar_diffusion.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/ar_diffusion.pt --local-dir checkpoints
```

In this stage, first generate the ODE-paired data:
```bash
# for the frame-wise model
torchrun --nproc_per_node=8 \
  get_causal_ode_data_framewise.py \
  --generator_ckpt checkpoints/framewise/ar_diffusion.pt \
  --rawdata_path dataset/clean_data \
  --output_folder dataset/ODE6KCausal_framewise_latents

python utils/create_lmdb_iterative.py \
  --data_path dataset/ODE6KCausal_framewise_latents \
  --lmdb_path dataset/ODE6KCausal_framewise

# for the chunk-wise model
torchrun --nproc_per_node=8 \
  get_causal_ode_data_chunkwise.py \
  --generator_ckpt checkpoints/chunkwise/ar_diffusion.pt \
  --rawdata_path dataset/clean_data \
  --output_folder dataset/ODE6KCausal_chunkwise_latents

python utils/create_lmdb_iterative.py \
  --data_path dataset/ODE6KCausal_chunkwise_latents \
  --lmdb_path dataset/ODE6KCausal_chunkwise

# 🔥NEW: Or use the optimized script (3x speedup) by @AshadowZ.
# Set MODE=chunkwise with FRAMES_PER_CHUNK=3, or MODE=framewise with FRAMES_PER_CHUNK=1.
MODE=chunkwise
FRAMES_PER_CHUNK=3
torchrun --nproc_per_node=8 \
  get_causal_ode_data_kv_optimized.py \
  --generator_ckpt checkpoints/${MODE}/ar_diffusion.pt \
  --rawdata_path dataset/clean_data \
  --output_folder dataset/ODE6KCausal_${MODE}_latents \
  --num_frames_per_chunk ${FRAMES_PER_CHUNK} \
  --generation_mode blockwise_kv
```
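
The frame-wise and chunk-wise pipelines above are each two commands whose paths must stay in sync: the `--output_folder` of the generation step is the `--data_path` of the LMDB conversion. As a hypothetical convenience (not part of this repository), the sketch below builds both commands from a single mode name, using only the scripts, flags, and paths shown above; it does not cover the optimized `blockwise_kv` script.

```python
import shlex


def ode_data_commands(mode: str, nproc: int = 8) -> list[list[str]]:
    """Return [generate_latents, build_lmdb] argument lists for `mode`."""
    if mode not in ("framewise", "chunkwise"):
        raise ValueError("mode must be 'framewise' or 'chunkwise'")
    latents = f"dataset/ODE6KCausal_{mode}_latents"  # shared by both steps
    generate = [
        "torchrun", f"--nproc_per_node={nproc}",
        f"get_causal_ode_data_{mode}.py",
        "--generator_ckpt", f"checkpoints/{mode}/ar_diffusion.pt",
        "--rawdata_path", "dataset/clean_data",
        "--output_folder", latents,
    ]
    to_lmdb = [
        "python", "utils/create_lmdb_iterative.py",
        "--data_path", latents,
        "--lmdb_path", f"dataset/ODE6KCausal_{mode}",
    ]
    return [generate, to_lmdb]


if __name__ == "__main__":
    for mode in ("framewise", "chunkwise"):
        for cmd in ode_data_commands(mode):
            print(shlex.join(cmd))
```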

Alternatively, directly download our prepared dataset (~300 GB):
```bash
hf download zhuhz22/Causal-Forcing-data  --local-dir dataset
python utils/merge_lmdb.py
```
> If the download gets stuck, press Ctrl+C and rerun the command to resume it.


Then train the ODE initialization models:
- Frame-wise:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_ode_framewise.yaml \
    --logdir logs/causal_ode_framewise
  ```
- Chunk-wise:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_ode_chunkwise.yaml \
    --logdir logs/causal_ode_chunkwise
  ```

> We recommend training for at least 1K steps; more steps (e.g., 5K-10K) lead to better performance.

To test the training results, run inference as described [here](#cli-inference).
</details>



### 🔥 Stage 2 (Option b): Causal Consistency Distillation Initialization ([Causal Forcing++](https://arxiv.org/abs/2605.15141))
Because generating ODE-paired data is very time-consuming, we also provide an alternative that matches (or exceeds) ODE distillation while requiring only ground-truth data, **with no ODE data generation needed!**

> Thanks to [@chijw's effort](https://github.com/thu-ml/Causal-Forcing/pull/20), now the EMA mechanism is more efficient!

- Frame-wise:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_cd_framewise.yaml \
    --logdir logs/causal_cd_framewise
  ```
- Chunk-wise:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_cd_chunkwise.yaml \
    --logdir logs/causal_cd_chunkwise
  ```

> We recommend training for at least 3K steps; more steps (e.g., 5K-10K) lead to better performance.

<details>
<summary>
You can also download the checkpoints directly (click to expand):
</summary>

```bash
hf download zhuhz22/Causal-Forcing framewise/causal_cd.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_cd.pt --local-dir checkpoints
```
</details>

To test the training results, run inference as described [here](#cli-inference).




### Stage 3: DMD

First download the prompt dataset:
```bash
hf download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
```

<details>

<summary>
If you have skipped Stage 2, you need to download our pretrained checkpoints (click to expand):
</summary>

```bash
hf download zhuhz22/Causal-Forcing framewise/causal_ode.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_ode.pt --local-dir checkpoints

# 🔥 Or you can also use the improved version (Causal Forcing++): causal CD instead of causal ODE
hf download zhuhz22/Causal-Forcing framewise/causal_cd.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_cd.pt --local-dir checkpoints
```
</details>

Then train the DMD models:

- Frame-wise model:
  ```bash
  # =============== Causal Forcing ================
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_forcing_dmd_framewise.yaml \
    --logdir logs/causal_forcing_dmd_framewise
  # =============== Causal Forcing++ ================
  # 2-step Causal Forcing++
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_forcing_dmd_framewise_2step.yaml \
    --logdir logs/causal_forcing_dmd_framewise_2step

  # 1-step Causal Forcing++
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_forcing_dmd_framewise_1step.yaml \
    --logdir logs/causal_forcing_dmd_framewise_1step
  ```

- Chunk-wise model:
  ```bash
  torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_forcing_dmd_chunkwise.yaml \
    --logdir logs/causal_forcing_dmd_chunkwise
  ```
> With a batch size of 64, we recommend training for no more than 1K steps; otherwise, motion dynamics may degrade. With a smaller batch size (e.g., 8), more training steps are preferable.
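
As a rough guide to the batch sizes above: the launch commands use 8 nodes × 8 GPUs, and the DMD configs set `batch_size: 1`. Assuming one sample per GPU and plain data parallelism, the global batch size is the product computed below. Gradient accumulation (the `grad_accum` key appears in the CD configs) multiplies it further. The role of the separate `total_batch_size` key is not covered here, and `global_batch_size` is a hypothetical helper.

```python
def global_batch_size(nnodes: int, nproc_per_node: int, per_gpu_batch: int,
                      grad_accum: int = 1) -> int:
    """Samples contributing to one optimizer step (assumes plain data parallelism)."""
    return nnodes * nproc_per_node * per_gpu_batch * grad_accum


print(global_batch_size(8, 8, 1))  # 8 nodes x 8 GPUs -> 64, the "batch size 64" setting
print(global_batch_size(1, 8, 1))  # a single 8-GPU node -> 8
```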

The resulting DMD models are the final models used for video generation.

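
The 1-step and 2-step configs set a short `denoising_step_list` for most blocks but a 4-step `denoising_step_list_first_chunk`. All DMD configs also set `warp_denoising_step: true` with `timestep_shift: 5.0`.

The sketch below illustrates one plausible reading of these keys:
- the first block uses the first-chunk list, and later blocks use the regular list;
- warping maps each nominal step t through the shifted flow-matching schedule σ' = sσ / (1 + (s - 1)σ), with σ = t / 1000 (a shift commonly used by Wan-style schedulers).

Both behaviors are assumptions, and `block_schedule` / `shift_timestep` are hypothetical names.

```python
from typing import List, Optional, Sequence


def shift_timestep(t: float, shift: float = 5.0, num_train_timestep: int = 1000) -> float:
    """Map a nominal step t in [0, num_train_timestep] through the shifted schedule."""
    sigma = t / num_train_timestep
    return num_train_timestep * shift * sigma / (1.0 + (shift - 1.0) * sigma)


def block_schedule(block_idx: int, steps: Sequence[int],
                   first_chunk_steps: Optional[Sequence[int]] = None,
                   warp: bool = True, shift: float = 5.0) -> List[float]:
    """Timesteps used to denoise one block (assumed semantics, see lead-in)."""
    use_first = block_idx == 0 and first_chunk_steps is not None
    chosen = first_chunk_steps if use_first else steps
    return [shift_timestep(t, shift) if warp else float(t) for t in chosen]


# 2-step framewise config: 4 steps for the first frame, 2 steps afterwards
print(block_schedule(0, [1000, 500], [1000, 750, 500, 250]))  # roughly [1000, 937.5, 833.3, 625]
print(block_schedule(1, [1000, 500], [1000, 750, 500, 250]))  # roughly [1000, 833.3]
```

The shift keeps the warped steps concentrated at high noise levels. That is why a nominal 4-step list such as 1000/750/500/250 ends up well above its face values.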
## FAQ & Blog 
See the [FAQ](https://my.feishu.cn/wiki/AjBSwcjpqiN0ECkodIWcGDcMn4e) and the [blog](https://zhuanlan.zhihu.com/p/2002114039493461457) (currently in Chinese).

<details>
<summary> Typical Questions (click to expand)
</summary>

**Why use a bidirectional teacher in the DMD stage?**
- Q: In the DMD stage, do you still use a bidirectional teacher? Why not an AR teacher?
- A: Yes. DMD only requires the student to match the teacher’s final distribution, not its generation trajectory, so a bidirectional teacher is fine. Also, bidirectional diffusion models are typically stronger than AR diffusion models, so they make better teachers.
  
- Q: Then why must the ODE (or Consistency Distillation) stage use an AR teacher?
- A: ODE/CD requires the student and teacher to follow the same trajectory, so their structures must match; an AR student cannot be trajectory-aligned with a bidirectional teacher.
  
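
For reference, here is the generator gradient in the original DMD formulation. The notation is illustrative, and this repository's weighting or normalization may differ. The gradient uses only two score estimates evaluated at noised student samples x_t:
- s_real comes from the teacher (the `real_name` model in the DMD configs, e.g. Wan2.1-T2V-14B);
- s_fake comes from a fake-score network trained on student samples.

This is why DMD places no constraint on the teacher's sampling trajectory. It is often implemented through a stop-gradient (sg) surrogate loss whose gradient with respect to the student sample x = G_θ(z) equals the weighted score difference g:

```math
\begin{aligned}
\nabla_\theta D_{\mathrm{KL}} &\approx \mathbb{E}_{z,\,t,\,\epsilon}\!\left[ w_t\,\alpha_t\,\big(s_{\mathrm{fake}}(x_t, t) - s_{\mathrm{real}}(x_t, t)\big)\,\frac{\partial G_\theta(z)}{\partial \theta} \right], \qquad x_t = \alpha_t\,G_\theta(z) + \sigma_t\,\epsilon, \\
\mathcal{L}_{\mathrm{surr}} &= \tfrac{1}{2}\,\big\lVert x - \mathrm{sg}(x - g) \big\rVert_2^2 \;\Longrightarrow\; \frac{\partial \mathcal{L}_{\mathrm{surr}}}{\partial x} = x - (x - g) = g .
\end{aligned}
```
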
🔥🔥 **ODE initialization or multi-step AR diffusion initialization?**
- Q: Which is better as initialization: a “proper” ODE initialization or directly using multi-step AR diffusion?
- A: We compare the two in Appendix C2. Overall, proper ODE initialization is better: multi-step AR diffusion initialization + DMD occasionally yields grid-like or waxy/greasy results. A key reason is that DMD is inherently few-step, so the fair comparison is in the few-step regime, where a few-step diffusion teacher is much weaker than an ODE-distilled teacher. Without ODE distillation, DMD must both close the step gap and handle an additional conditioning gap from self-rollout: early few-step errors corrupt the history and are amplified across frames (large exposure bias), which increases the burden on DMD. Training can still converge, but typically to worse quality than with ODE initialization. Moreover, with ODE initialization, DMD needs only very few training steps (e.g., ~100), which reduces the risk of dynamics degradation from long DMD training.
  
**Can frame-level non-injectivity appear in the actual training dataset?**
- Q: Regarding the “one-to-many” analysis in the ODE stage: since a single frame’s latent is very high-dimensional, isn’t the probability of two latents being exactly identical extremely small?
- A: Yes, but the key point is not whether the dataset literally contains identical samples; it is whether a well-defined function exists in the mathematical sense. Visual data live in a continuous space, so even in 1D, two samples being exactly identical is extremely unlikely. However, the theoretical existence of exact collisions is enough to break the function property and make the mapping ill-defined.

</details>
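
The non-injectivity argument in the last FAQ answer can be illustrated with a toy example. This is an illustration, not the paper's analysis. If the same input appears with two different regression targets, no function can fit both. Under squared error, the optimal prediction at that input is the average of its targets, which matches neither. String keys stand in for colliding frame latents, and `l2_optimal_predictor` is a hypothetical helper.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Hashable, Iterable, List, Tuple


def l2_optimal_predictor(pairs: Iterable[Tuple[Hashable, float]]) -> Dict[Hashable, float]:
    """Per distinct input, the squared-error minimizer is the mean of its targets."""
    groups: Dict[Hashable, List[float]] = defaultdict(list)
    for x, y in pairs:
        groups[x].append(y)
    return {x: mean(ys) for x, ys in groups.items()}


# "frame_a" collides: the same input is paired with two different targets
pairs = [("frame_a", 1.0), ("frame_a", -1.0), ("frame_b", 0.5)]
print(l2_optimal_predictor(pairs))  # {'frame_a': 0.0, 'frame_b': 0.5}
```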

## Acknowledgements
- This codebase is built on top of the open-source implementations of [CausVid](https://github.com/tianweiy/CausVid), [Self Forcing](https://github.com/guandeh17/Self-Forcing), [Rolling Forcing](https://github.com/TencentARC/RollingForcing), and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
- Thanks to @[chijw](https://github.com/chijw) for improving the EMA mechanism. 
- Thanks to @[AshadowZ](https://github.com/AshadowZ) for improving the causal ODE data curation efficiency.
- The 1-step and 2-step variants of Causal Forcing++ apply the first-frame 4-step technique proposed by [ASD](https://github.com/BigAandSmallq/SAD). We thank ASD for its contribution.

## References
If you find the method useful, please cite
```
@article{zhu2026causal,
  title={Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation},
  author={Zhu, Hongzhou and Zhao, Min and He, Guande and Su, Hang and Li, Chongxuan and Zhu, Jun},
  journal={arXiv preprint arXiv:2602.02214},
  year={2026}
}

@article{zhao2026causal,
  title={Causal Forcing++: Scalable Few-Step Autoregressive Diffusion Distillation for Real-Time Interactive Video Generation},
  author={Zhao, Min and Zhu, Hongzhou and Zheng, Kaiwen and Zhou, Zihan and Yan, Bokai and Li, Xinyuan and Yang, Xiao and Li, Chongxuan and Zhu, Jun},
  journal={arXiv preprint arXiv:2605.15141},
  year={2026}
}
```


================================================
FILE: configs/ar_diffusion_tf_chunkwise.yaml
================================================
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
warp_denoising_step: true 
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 2000000000
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: diffusion
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data 
teacher_forcing: true # TF/DF


================================================
FILE: configs/ar_diffusion_tf_framewise.yaml
================================================
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
warp_denoising_step: true 
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 2000000000
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: diffusion
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data 
teacher_forcing: true # TF/DF


================================================
FILE: configs/causal_cd_chunkwise.yaml
================================================
generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: consistency_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data
grad_accum: 1
spatial_sf: true
teacher_forcing: True
causal: True
is_causal: True
discrete_cd_N: 48


================================================
FILE: configs/causal_cd_framewise.yaml
================================================
generator_ckpt: checkpoints/framewise/ar_diffusion.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-1.3B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: consistency_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: dataset/clean_data
grad_accum: 1
spatial_sf: true
teacher_forcing: True
causal: True
is_causal: True
discrete_cd_N: 48


================================================
FILE: configs/causal_forcing_dmd_chunkwise.yaml
================================================
generator_ckpt: checkpoints/chunkwise/causal_ode.pt # 🔥 or checkpoints/chunkwise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt 

================================================
FILE: configs/causal_forcing_dmd_framewise.yaml
================================================
generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt

================================================
FILE: configs/causal_forcing_dmd_framewise_1step.yaml
================================================
generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
denoising_step_list_first_chunk: #  Causal Forcing++ with 1/2 steps applies the first-frame 4-step technique proposed by [ASD](https://github.com/BigAandSmallq/SAD). We thank ASD for its contribution.
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt

================================================
FILE: configs/causal_forcing_dmd_framewise_2step.yaml
================================================
generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 500
denoising_step_list_first_chunk: #  Causal Forcing++ with 1/2 steps applies the first-frame 4-step technique proposed by [ASD](https://github.com/BigAandSmallq/SAD). We thank ASD for its contribution.
- 1000
- 750
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
sharding_strategy: hybrid_full
lr: 2.0e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 8
log_iters: 250
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 1
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0
data_path: prompts/vidprom_filtered_extended.txt

================================================
FILE: configs/causal_ode_chunkwise.yaml
================================================
generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt
generator_grad:
  model: true
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true
generator_task: causal_video
generator_fsdp_wrap_strategy: size
text_encoder_fsdp_wrap_strategy: size
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
wandb_name: ode
sharding_strategy: hybrid_full
lr: 2.0e-06
beta1: 0.9
beta2: 0.999
data_path: dataset/ODE6KCausal_chunkwise
batch_size: 1
log_iters: 1000
trainer: ode
gradient_checkpointing: true
num_frame_per_block: 3
model_kwargs:
  timestep_shift: 5.0

================================================
FILE: configs/causal_ode_framewise.yaml
================================================
generator_ckpt: checkpoints/framewise/ar_diffusion.pt
generator_grad:
  model: true
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true
generator_task: causal_video
generator_fsdp_wrap_strategy: size
text_encoder_fsdp_wrap_strategy: size
mixed_precision: true
seed: 0
wandb_host: https://api.wandb.ai
wandb_key: {your key}
wandb_entity: {your entity}
wandb_project: {your project}
wandb_name: ode
sharding_strategy: hybrid_full
lr: 2.0e-06
beta1: 0.9
beta2: 0.999
data_path: dataset/ODE6KCausal_framewise
batch_size: 1
log_iters: 1000
trainer: ode
gradient_checkpointing: true
num_frame_per_block: 1 
model_kwargs:
  timestep_shift: 5.0


================================================
FILE: configs/default_config.yaml
================================================
independent_first_frame: false
warp_denoising_step: false
weight_decay: 0.01
same_step_across_blocks: true
discriminator_lr_multiplier: 1.0
last_step_only: false
i2v: false
num_training_frames: 21
gc_interval: 100
context_noise: 0
causal: true

ckpt_step: 0
prompt_name: MovieGenVideoBench
prompt_path: prompts/MovieGenVideoBench.txt
eval_first_n: 64
num_samples: 1
height: 480
width: 832
num_frames: 81

================================================
FILE: demo.py
================================================
"""
Demo for Causal-Forcing.
"""

import os
import re
import random
import time
import base64
import argparse
import hashlib
import subprocess
import urllib.request
from io import BytesIO
from PIL import Image
import numpy as np
import torch
from omegaconf import OmegaConf
from flask import Flask, render_template, jsonify
from flask_socketio import SocketIO, emit
import queue
from threading import Thread, Event

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
from demo_utils.utils import generate_timestamp
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=5001)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument("--checkpoint_path", type=str, default='checkpoints/framewise/causal_forcing.pt')
parser.add_argument("--config_path", type=str, default='configs/causal_forcing_dmd_framewise.yaml')
parser.add_argument('--trt', action='store_true')
args = parser.parse_args()

print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40

# Load models
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)

text_encoder = WanTextEncoder()

# Global variables for dynamic model switching
current_vae_decoder = None
current_use_taehv = False
fp8_applied = False
torch_compile_applied = False
global frame_number
frame_number = 0
anim_name = ""
frame_rate = 6

def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Initialize VAE decoder based on the selected option"""
    global current_vae_decoder, current_use_taehv

    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        current_vae_decoder = VAETRTWrapper()
        return current_vae_decoder

    if use_taehv:
        from demo_utils.taehv import TAEHV
        # Check if taew2_1.pth exists in checkpoints folder, download if missing
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
                print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
            except Exception as e:
                print(f"Failed to download taew2_1.pth: {e}")
                raise

        class DotDict(dict):
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                # n, c, t, h, w = latents.shape
                # low-memory, set parallel=True for faster + higher memory
                return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)

        current_vae_decoder = TAEHVDiffusersWrapper()
    else:
        current_vae_decoder = VAEDecoderWrapper()
        vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
        decoder_state_dict = {}
        for key, value in vae_state_dict.items():
            if 'decoder.' in key or 'conv2' in key:
                decoder_state_dict[key] = value
        current_vae_decoder.load_state_dict(decoder_state_dict)

    current_vae_decoder.eval()
    current_vae_decoder.to(dtype=torch.float16)
    current_vae_decoder.requires_grad_(False)
    current_vae_decoder.to(gpu)
    current_use_taehv = use_taehv

    print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
    return current_vae_decoder


# Initialize with default VAE
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

transformer = WanDiffusionWrapper(is_causal=True)
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
transformer.load_state_dict(state_dict['generator_ema'])

text_encoder.eval()
transformer.eval()

transformer.to(dtype=torch.float16)
text_encoder.to(dtype=torch.bfloat16)

text_encoder.requires_grad_(False)
transformer.requires_grad_(False)

pipeline = CausalInferencePipeline(
    config,
    device=gpu,
    generator=transformer,
    text_encoder=text_encoder,
    vae=vae_decoder
)

if low_memory:
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
    text_encoder.to(gpu)
transformer.to(gpu)

# Flask and SocketIO setup
app = Flask(__name__)
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
socketio = SocketIO(app, cors_allowed_origins="*")

generation_active = False
stop_event = Event()
frame_send_queue = queue.Queue()
sender_thread = None
models_compiled = False


def tensor_to_base64_frame(frame_tensor):
    """Convert a single frame tensor to base64 image string."""
    global frame_number, anim_name
    # Clamp and normalize to 0-255
    frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
    frame = frame.to(torch.uint8).cpu().numpy()

    # CHW -> HWC
    if len(frame.shape) == 3:
        frame = np.transpose(frame, (1, 2, 0))

    # Convert to PIL Image
    if frame.shape[2] == 3:  # RGB
        image = Image.fromarray(frame, 'RGB')
    else:  # Handle other formats
        image = Image.fromarray(frame)

    # Convert to base64
    buffer = BytesIO()
    image.save(buffer, format='JPEG', quality=100)
    os.makedirs("./images/%s" % anim_name, exist_ok=True)
    frame_number += 1
    image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"


def frame_sender_worker():
    """Background thread that processes frame send queue non-blocking."""
    global frame_send_queue, generation_active, stop_event

    print("📡 Frame sender thread started")

    while True:
        frame_data = None
        try:
            # Get frame data from queue
            frame_data = frame_send_queue.get(timeout=1.0)

            if frame_data is None:  # Shutdown signal
                frame_send_queue.task_done()  # Mark shutdown signal as done
                break

            frame_tensor, frame_index, block_index, job_id = frame_data

            # Convert tensor to base64
            base64_frame = tensor_to_base64_frame(frame_tensor)

            # Send via SocketIO
            try:
                socketio.emit('frame_ready', {
                    'data': base64_frame,
                    'frame_index': frame_index,
                    'block_index': block_index,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"⚠️ Failed to send frame {frame_index}: {e}")

            frame_send_queue.task_done()

        except queue.Empty:
            # Check if we should continue running
            if not generation_active and frame_send_queue.empty():
                break
        except Exception as e:
            print(f"❌ Frame sender error: {e}")
            # Make sure to mark task as done even if there's an error
            if frame_data is not None:
                try:
                    frame_send_queue.task_done()
                except Exception as e:
                    print(f"❌ Failed to mark frame task as done: {e}")
            break

    print("📡 Frame sender thread stopped")


@torch.no_grad()
def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
    """Generate video and push frames immediately to frontend."""
    global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name

    try:
        generation_active = True
        stop_event.clear()
        job_id = generate_timestamp()

        # Start frame sender thread if not already running
        if sender_thread is None or not sender_thread.is_alive():
            sender_thread = Thread(target=frame_sender_worker, daemon=True)
            sender_thread.start()

        # Emit progress updates
        def emit_progress(message, progress):
            try:
                socketio.emit('progress', {
                    'message': message,
                    'progress': progress,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"❌ Failed to emit progress: {e}")

        emit_progress('Starting generation...', 0)

        # Handle VAE decoder switching
        if use_taehv != current_use_taehv:
            emit_progress('Switching VAE decoder...', 2)
            print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
            current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
            # Update pipeline with new VAE decoder
            pipeline.vae = current_vae_decoder

        # Handle FP8 quantization
        if enable_fp8 and not fp8_applied:
            emit_progress('Applying FP8 quantization...', 3)
            print("🔧 Applying FP8 quantization to transformer")
            from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
            quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
            fp8_applied = True

        # Text encoding
        emit_progress('Encoding text prompt...', 8)
        conditional_dict = text_encoder(text_prompts=[prompt])
        for key, value in conditional_dict.items():
            conditional_dict[key] = value.to(dtype=torch.float16)
        if low_memory:
            gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
            move_model_to_device_with_memory_preservation(
                text_encoder, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

        # Handle torch.compile if enabled
        torch_compile_applied = enable_torch_compile
        if enable_torch_compile and not models_compiled:
            # Compile transformer and decoder
            transformer.compile(mode="max-autotune-no-cudagraphs")
            if not current_use_taehv and not low_memory and not args.trt:
                current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")

        # Initialize generation
        emit_progress('Initializing generation...', 12)

        rnd = torch.Generator(gpu).manual_seed(seed)
        # all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)

        pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
        pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)

        noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)

        # Generation parameters
        num_blocks = 7
        current_start_frame = 0
        num_input_frames = 0
        all_num_frames = [pipeline.num_frame_per_block] * num_blocks
        if current_use_taehv:
            vae_cache = None
        else:
            vae_cache = ZERO_VAE_CACHE
            for i in range(len(vae_cache)):
                vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)

        total_frames_sent = 0
        generation_start_time = time.time()

        emit_progress('Generating frames... (frontend handles timing)', 15)

        for idx, current_num_frames in enumerate(all_num_frames):
            if not generation_active or stop_event.is_set():
                break

            progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15

            # Special message for first block with torch.compile
            if idx == 0 and torch_compile_applied and not models_compiled:
                emit_progress(
                    f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
                print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
                models_compiled = True
            else:
                emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
                print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")

            block_start_time = time.time()

            noisy_input = noise[:, current_start_frame -
                                num_input_frames:current_start_frame + current_num_frames - num_input_frames]

            # Denoising loop
            denoising_start = time.time()
            for index, current_timestep in enumerate(pipeline.denoising_step_list):
                if not generation_active or stop_event.is_set():
                    break

                timestep = torch.ones([1, current_num_frames], device=noise.device,
                                      dtype=torch.int64) * current_timestep

                _, denoised_pred = transformer(
                    noisy_image_or_video=noisy_input,
                    conditional_dict=conditional_dict,
                    timestep=timestep,
                    kv_cache=pipeline.kv_cache1,
                    crossattn_cache=pipeline.crossattn_cache,
                    current_start=current_start_frame * pipeline.frame_seq_length
                )
                # Re-noise the prediction to the next timestep, except after the final step
                if index < len(pipeline.denoising_step_list) - 1:
                    next_timestep = pipeline.denoising_step_list[index + 1]
                    noisy_input = pipeline.scheduler.add_noise(
                        denoised_pred.flatten(0, 1),
                        torch.randn_like(denoised_pred.flatten(0, 1)),
                        next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                    ).unflatten(0, denoised_pred.shape[:2])

            if not generation_active or stop_event.is_set():
                break

            denoising_time = time.time() - denoising_start
            print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")

            # Record output
            # all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred

            # Update KV cache for next block
            if idx != len(all_num_frames) - 1:
                transformer(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=torch.zeros_like(timestep),
                    kv_cache=pipeline.kv_cache1,
                    crossattn_cache=pipeline.crossattn_cache,
                    current_start=current_start_frame * pipeline.frame_seq_length,
                )

            # Decode to pixels and send frames immediately
            print(f"🎨 Decoding block {idx+1} to pixels...")
            decode_start = time.time()
            if args.trt:
                all_current_pixels = []
                for i in range(denoised_pred.shape[1]):
                    is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
                        torch.tensor(0.0).cuda().half()
                    outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
                    # outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
                    current_pixels, vae_cache = outputs[0], outputs[1:]
                    print(current_pixels.max(), current_pixels.min())
                    all_current_pixels.append(current_pixels.clone())
                pixels = torch.cat(all_current_pixels, dim=1)
                if idx == 0:
                    pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
            else:
                if current_use_taehv:
                    if vae_cache is None:
                        vae_cache = denoised_pred
                    else:
                        denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
                        vae_cache = denoised_pred[:, -3:, :, :, :]
                    pixels = current_vae_decoder.decode(denoised_pred)
                    print(f"denoised_pred shape: {denoised_pred.shape}")
                    print(f"pixels shape: {pixels.shape}")
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
                    else:
                        pixels = pixels[:, 12:, :, :, :]

                else:
                    pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block

            decode_time = time.time() - decode_start
            print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")

            # Queue frames for non-blocking sending
            block_frames = pixels.shape[1]
            print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
            queue_start = time.time()

            for frame_idx in range(block_frames):
                if not generation_active or stop_event.is_set():
                    break

                frame_tensor = pixels[0, frame_idx].cpu()

                # Queue frame data in non-blocking way
                frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
                total_frames_sent += 1

            queue_time = time.time() - queue_start
            block_time = time.time() - block_start_time
            print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")

            current_start_frame += current_num_frames

        generation_time = time.time() - generation_start_time
        print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")

        # Wait for all frames to be sent before completing
        emit_progress('Waiting for all frames to be sent...', 97)
        print("⏳ Waiting for all frames to be sent...")
        frame_send_queue.join()  # Wait for all queued frames to be processed
        print("✅ All frames sent successfully!")

        generate_mp4_from_images("./images", "./videos/" + anim_name + ".mp4", frame_rate)
        # Final progress update
        emit_progress('Generation complete!', 100)

        try:
            socketio.emit('generation_complete', {
                'message': 'Video generation completed!',
                'total_frames': total_frames_sent,
                'generation_time': f"{generation_time:.2f}s",
                'job_id': job_id
            })
        except Exception as e:
            print(f"❌ Failed to emit generation complete: {e}")

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        try:
            socketio.emit('error', {
                'message': f'Generation failed: {str(e)}',
                'job_id': job_id
            })
        except Exception as e:
            print(f"❌ Failed to emit error: {e}")
    finally:
        generation_active = False
        stop_event.set()

        # Clean up sender thread
        try:
            frame_send_queue.put(None)
        except Exception as e:
            print(f"❌ Failed to put None in frame_send_queue: {e}")


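def generate_mp4_from_images(image_directory, output_video_path, fps=24):
    """
    Generate an MP4 video from the numbered JPEG frames of the current animation.

    Frames are read from ``<image_directory>/<anim_name>/<anim_name>_%03d.jpg``,
    where ``anim_name`` is the module-level global set by ``handle_start_generation``.

    :param image_directory: Path to the directory containing the per-animation frame folders.
    :param output_video_path: Path where the output MP4 will be saved.
    :param fps: Frames per second for the output video.
    """
    global anim_name
    os.makedirs(os.path.dirname(output_video_path) or '.', exist_ok=True)
    # Construct the ffmpeg command; -y overwrites an existing file instead of prompting on stdin
    cmd = [
        'ffmpeg',
        '-y',
        '-framerate', str(fps),
        '-i', os.path.join(image_directory, anim_name, anim_name + '_%03d.jpg'),  # Adjust the pattern if necessary
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        output_video_path
    ]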
    try:
        subprocess.run(cmd, check=True)
        print(f"Video saved to {output_video_path}")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def calculate_sha256(data):
    # Convert data to bytes if it's not already
    if isinstance(data, str):
        data = data.encode()
    # Calculate SHA-256 hash
    sha256_hash = hashlib.sha256(data).hexdigest()
    return sha256_hash

# Socket.IO event handlers
@socketio.on('connect')
def handle_connect():
    print('Client connected')
    emit('status', {'message': 'Connected to frontend-buffered demo server'})


@socketio.on('disconnect')
def handle_disconnect():
    print('Client disconnected')


@socketio.on('start_generation')
def handle_start_generation(data):
    global generation_active, frame_number, anim_name, frame_rate

    if generation_active:
        emit('error', {'message': 'Generation already in progress'})
        return

    prompt = data.get('prompt', '')
    # Validate before marking generation as active; otherwise an empty prompt would
    # leave generation_active stuck at True and block every later request
    if not prompt:
        emit('error', {'message': 'Prompt is required'})
        return

    # Reset the frame counter only once a new job is actually accepted
    frame_number = 0

    seed = data.get('seed', -1)
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # Extract words up to the first punctuation mark (fall back to the first line)
    words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip()
    if not words_up_to_punctuation:
        words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()

    # Calculate SHA-256 hash of the entire prompt
    sha256_hash = calculate_sha256(prompt)

    # Build anim_name from up to 20 characters of the extracted words, the seed,
    # and the first 10 hex digits of the hash
    anim_name = f"{words_up_to_punctuation[:20]}_{seed}_{sha256_hash[:10]}"

    enable_torch_compile = data.get('enable_torch_compile', False)
    enable_fp8 = data.get('enable_fp8', False)
    use_taehv = data.get('use_taehv', False)
    frame_rate = data.get('fps', 6)
    generation_active = True

    # Start generation in background thread
    socketio.start_background_task(generate_video_stream, prompt, seed,
                                   enable_torch_compile, enable_fp8, use_taehv)
    emit('status', {'message': 'Generation started - frames will be sent immediately'})


@socketio.on('stop_generation')
def handle_stop_generation():
    global generation_active, stop_event, frame_send_queue
    generation_active = False
    stop_event.set()

    # Signal sender thread to stop (will be processed after current frames)
    try:
        frame_send_queue.put(None)
    except Exception as e:
        print(f"❌ Failed to put None in frame_send_queue: {e}")

    emit('status', {'message': 'Generation stopped'})

# Web routes


@app.route('/')
def index():
    return render_template('demo.html')


@app.route('/api/status')
def api_status():
    return jsonify({
        'generation_active': generation_active,
        'free_vram_gb': get_cuda_free_memory_gb(gpu),
        'fp8_applied': fp8_applied,
        'torch_compile_applied': torch_compile_applied,
        'current_use_taehv': current_use_taehv
    })


if __name__ == '__main__':
    print(f"🚀 Starting demo on http://{args.host}:{args.port}")
    socketio.run(app, host=args.host, port=args.port, debug=False)
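The TAEHV decode branch of `generate_video_stream` trims pixel frames so consecutive blocks join without overlap: the first block drops 3 warm-up frames, and later blocks prepend the previous block's latents as context and then drop the 12 frames decoded from that context. The sketch below reproduces only that bookkeeping, assuming `pipeline.num_frame_per_block == 3` (implied by 7 blocks over the 21-frame noise tensor) and that TAEHV emits 4 pixel frames per latent frame. `pixel_frames_per_block` is a hypothetical helper, not part of the demo.

```python
def pixel_frames_per_block(num_blocks=7, latents_per_block=3, frames_per_latent=4,
                           warmup_trim=3):
    """Pixel frames queued per block by the TAEHV branch (illustrative sketch)."""
    counts = []
    for idx in range(num_blocks):
        if idx == 0:
            # First block: decode only its own latents, then drop the warm-up frames
            decoded = latents_per_block * frames_per_latent
            counts.append(decoded - warmup_trim)
        else:
            # Later blocks: the previous block's latents are prepended as context,
            # and the frames decoded from that context are dropped again
            decoded = 2 * latents_per_block * frames_per_latent
            counts.append(decoded - latents_per_block * frames_per_latent)
    return counts


counts = pixel_frames_per_block()
print(counts, sum(counts))  # [9, 12, 12, 12, 12, 12, 12] 81
```

Under these assumptions the 21 latent frames yield 81 pixel frames in total, which is the frame count `generate_mp4_from_images` later stitches together.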


================================================
FILE: demo_utils/constant.py
================================================

import torch


ZERO_VAE_CACHE = [
    torch.zeros(1, 16, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 192, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832)
]

feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
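`ZERO_VAE_CACHE` holds 32 tensors whose sizes grow with the decoder's spatial upsampling. A quick way to see what that costs is to sum the element counts from the shapes above. The sketch copies those shapes into plain tuples so it runs without PyTorch; `CACHE_SHAPES` and `cache_bytes` are hypothetical names used only here. Because `torch.zeros` defaults to `float32`, the module-level constant takes roughly twice the `float16` figure in host memory until the demo converts it.

```python
from math import prod

# Shapes copied from ZERO_VAE_CACHE: (batch, channels, cached_frames, height, width)
CACHE_SHAPES = (
    [(1, 16, 2, 60, 104)]
    + [(1, 384, 2, 60, 104)] * 11
    + [(1, 192, 2, 120, 208)]
    + [(1, 384, 2, 120, 208)] * 6
    + [(1, 192, 2, 240, 416)] * 6
    + [(1, 96, 2, 480, 832)] * 7
)


def cache_bytes(shapes, bytes_per_element=2):
    """Total bytes if every cache tensor uses the given element size (2 = float16)."""
    return sum(prod(shape) for shape in shapes) * bytes_per_element


total = cache_bytes(CACHE_SHAPES)
print(len(CACHE_SHAPES), round(total / 1024**3, 2))  # 32 1.76
```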


================================================
FILE: demo_utils/memory.py
================================================
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel

import torch


cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
gpu_complete_modules = []


class DynamicSwapInstaller:
    @staticmethod
    def _install_module(module: torch.nn.Module, **kwargs):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
                    else:
                        return p.to(**kwargs)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(**kwargs)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })

        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, **kwargs):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, **kwargs)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return


def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
    if hasattr(model, 'scale_shift_table'):
        model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
        return

    for k, p in model.named_modules():
        if hasattr(p, 'weight'):
            p.to(target_device)
            return


def get_cuda_free_memory_gb(device=None):
    if device is None:
        device = gpu

    memory_stats = torch.cuda.memory_stats(device)
    bytes_active = memory_stats['active_bytes.all.current']
    bytes_reserved = memory_stats['reserved_bytes.all.current']
    bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
    bytes_inactive_reserved = bytes_reserved - bytes_active
    bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
    return bytes_total_available / (1024 ** 3)


def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=target_device)

    model.to(device=target_device)
    torch.cuda.empty_cache()
    return


def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=cpu)

    model.to(device=cpu)
    torch.cuda.empty_cache()
    return


def unload_complete_models(*args):
    for m in gpu_complete_modules + list(args):
        m.to(device=cpu)
        print(f'Unloaded {m.__class__.__name__} as complete.')

    gpu_complete_modules.clear()
    torch.cuda.empty_cache()
    return


def load_model_as_complete(model, target_device, unload=True):
    if unload:
        unload_complete_models()

    model.to(device=target_device)
    print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')

    gpu_complete_modules.append(model)
    return
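Two ideas drive the helpers above: `get_cuda_free_memory_gb` counts memory the driver reports as free plus memory PyTorch's caching allocator has reserved but is not actively using, and `move_model_to_device_with_memory_preservation` moves submodules one by one until free memory falls to the preservation threshold. The sketch below replays both with plain numbers so it runs without a GPU. The function names are hypothetical, and the toy treats each entry of `module_sizes_gb` as one movable module, ignoring that the real loop walks every entry of `model.modules()` and only moves those with a `weight` attribute.

```python
def total_available_gb(free_bytes, reserved_bytes, active_bytes):
    # Same formula as get_cuda_free_memory_gb: driver-reported free memory plus
    # reserved-but-inactive memory held by PyTorch's caching allocator
    return (free_bytes + (reserved_bytes - active_bytes)) / (1024 ** 3)


def modules_moved(module_sizes_gb, free_gb, preserved_gb):
    # Toy replay of the partial-move loop: stop as soon as free memory
    # is at or below the preservation threshold
    moved = 0
    for size in module_sizes_gb:
        if free_gb <= preserved_gb:
            return moved
        free_gb -= size
        moved += 1
    # Loop finished: the real helper then moves the whole model
    return moved


print(total_available_gb(6 * 1024**3, 3 * 1024**3, 1 * 1024**3))  # 8.0
print(modules_moved([2, 2, 2, 2], free_gb=7, preserved_gb=2))      # 3
```

Note that the check happens before each move, so the last module moved can push free memory below the threshold.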


================================================
FILE: demo_utils/taehv.py
================================================
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3


class MemBlock(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
                                  conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))


class TPool(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))


class TGrow(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)


def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.
    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data.
    """
    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N * T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
        # need to fix :(
        out = []
        # iterate over input timesteps and also iterate over blocks.
        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
        # it's actually a ***graph traversal*** problem! so let's make a queue
        work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
        # in addition to manually managing our queue, we also need to manually manage our progressbar.
        # we'll update it for every source node that we consume.
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        # we'll also need a separate addressable memory per node as well
        mem = [None] * len(model)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached end of the graph, append result to output list
                out.append(xt)
            else:
                # fetch the block to process
                b = model[i]
                if isinstance(b, MemBlock):
                    # mem blocks are simple since we're visiting the graph in causal order
                    if mem[i] is None:
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt_new, i + 1))
                elif isinstance(b, TPool):
                    # pool blocks are miserable
                    if mem[i] is None:
                        mem[i] = []  # pool memory is itself a queue of inputs to pool
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in invalid state, we should have pooled before this
                        raise ValueError("TPool memory holds more inputs than its stride; it should have pooled earlier")
                    elif len(mem[i]) < b.stride:
                        # pool mem is not yet full, go back to processing the work queue
                        pass
                    else:
                        # pool mem is ready, run the pool block
                        N, C, H, W = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
                        # reset the pool mem
                        mem[i] = []
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt, i + 1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C, H, W = xt.shape
                    # each tgrow has multiple successor nodes
                    for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt_next, i + 1))
                else:
                    # normal block with no funny business
                    xt = b(xt)
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt, i + 1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x


class TAEHV(nn.Module):
    latent_channels = 16
    image_channels = 3

    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
        """Initialize pretrained TAEHV from the given checkpoint.

        Args:
            checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
            decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
            decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEHV.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )
        if checkpoint_path is not None:
            self.load_state_dict(self.patch_tgrow_layers(torch.load(
                checkpoint_path, map_location="cpu", weights_only=True)))

    def patch_tgrow_layers(self, sd):
        """Trim checkpoint TGrow conv weights when a layer's temporal stride is smaller than in the checkpoint (keeps the last-timestep output channels).

        Args:
            sd: state dict to patch
        """
        new_sd = self.state_dict()
        for i, layer in enumerate(self.decoder):
            if isinstance(layer, TGrow):
                key = f"decoder.{i}.conv.weight"
                if sd[key].shape[0] > new_sd[key].shape[0]:
                    # take the last-timestep output channels
                    sd[key] = sd[key][-new_sd[key].shape[0]:]
        return sd

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once.
              (this is faster but may require more memory).
              if False, frames will be processed sequentially.
        Returns NTCHW latent tensor with ~Gaussian values.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=False):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=16) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once.
              (this is faster but may require more memory).
              if False, frames will be processed sequentially.
        Returns NTCHW RGB tensor with ~[0, 1] values.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        # return x[:, self.frames_to_trim:]
        return x

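    def forward(self, x):
        # TAEHV is used through encode_video() / decode_video(); there is no single forward pass
        raise NotImplementedError("Use TAEHV.encode_video() or TAEHV.decode_video()")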


@torch.no_grad()
def main():
    """Run TAEHV roundtrip reconstruction on the given video paths."""
    import os
    import sys
    import cv2  # no highly esteemed deed is commemorated here

    class VideoTensorReader:
        def __init__(self, video_file_path):
            self.cap = cv2.VideoCapture(video_file_path)
            assert self.cap.isOpened(), f"Could not load {video_file_path}"
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        def __iter__(self):
            return self

        def __next__(self):
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                raise StopIteration  # End of video or error
            return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)  # BGR HWC -> RGB CHW

    class VideoTensorWriter:
        def __init__(self, video_file_path, width_height, fps=30):
            self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
            assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"

        def write(self, frame_tensor):
            assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
            self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
                              cv2.COLOR_RGB2BGR))  # RGB CHW -> BGR HWC

        def __del__(self):
            if hasattr(self, 'writer'):
                self.writer.release()

    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    dtype = torch.float16
    checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
    checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
    print(
        f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
    taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
    for video_path in sys.argv[1:]:
        print(f"Processing {video_path}...")
        video_in = VideoTensorReader(video_path)
        video = torch.stack(list(video_in), 0)[None]
        # convert to device tensor with values in [0, 1]
        vid_dev = video.to(dev, dtype).div_(255.0)
        if video.numel() < 100_000_000:
            print(f"  {video_path} seems small enough, will process all frames in parallel")
            vid_enc = taehv.encode_video(vid_dev)
            print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc)
            print(f"  Decoded {video_path} -> {vid_dec.shape}")
        else:
            print(f"  {video_path} seems large, will process each frame sequentially")
            vid_enc = taehv.encode_video(vid_dev, parallel=False)
            print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc, parallel=False)
            print(f"  Decoded {video_path} -> {vid_dec.shape}")
        video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
        video_out = VideoTensorWriter(
            video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
        for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
            video_out.write(frame)
        print(f"  Saved to {video_out_path}")
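

# Illustrative sketch (hypothetical helper, not part of TAEHV): main() above
# chooses between parallel and frame-by-frame processing by checking whether
# the uint8 video tensor [1, T, 3, H, W] has fewer than 100_000_000 elements.
# This mirrors that heuristic so the cutoff can be estimated before loading a
# file; the parameter names and defaults are assumptions.
def _sketch_should_process_in_parallel(num_frames, height, width, channels=3,
                                       max_elements=100_000_000):
    """Return True if a [1, T, C, H, W] video would take the parallel path."""
    return num_frames * channels * height * width < max_elements

# Example: 81 frames at 480x832 is 97,044,480 elements (parallel path), while
# 121 frames at the same size is 144,967,680 elements (sequential path).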


if __name__ == "__main__":
    main()


================================================
FILE: demo_utils/utils.py
================================================
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel

import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision

from PIL import Image


def min_resize(x, m):
    if x.shape[0] < x.shape[1]:
        s0 = m
        s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
        s1 = m
    new_max = max(s1, s0)
    raw_max = max(x.shape[0], x.shape[1])
    if new_max < raw_max:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    y = cv2.resize(x, (s1, s0), interpolation=interpolation)
    return y
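

# Illustrative sketch (hypothetical helper): the target-size arithmetic used by
# min_resize, returned as (new_height, new_width) so it can be checked without
# OpenCV. cv2.resize takes dsize as (width, height), which is why min_resize
# passes (s1, s0). min_resize picks INTER_AREA when shrinking and
# INTER_LANCZOS4 when enlarging.
def _sketch_min_resize_shape(height, width, m):
    """Scale (height, width) so that the shorter side becomes m."""
    if height < width:
        return m, int(float(m) / float(height) * float(width))
    return int(float(m) / float(width) * float(height)), m

# Example: a 480x640 frame with m=240 becomes 240x320.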


def d_resize(x, y):
    H, W, C = y.shape
    new_min = min(H, W)
    raw_min = min(x.shape[0], x.shape[1])
    if new_min < raw_min:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    y = cv2.resize(x, (W, H), interpolation=interpolation)
    return y


def resize_and_center_crop(image, target_width, target_height):
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return np.array(cropped_image)


def resize_and_center_crop_pytorch(image, target_width, target_height):
    B, C, H, W = image.shape

    if H == target_height and W == target_width:
        return image

    scale_factor = max(target_width / W, target_height / H)
    resized_width = int(round(W * scale_factor))
    resized_height = int(round(H * scale_factor))

    resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)

    top = (resized_height - target_height) // 2
    left = (resized_width - target_width) // 2
    cropped = resized[:, :, top:top + target_height, left:left + target_width]

    return cropped
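

# Illustrative sketch (hypothetical helper): the scale-then-center-crop
# arithmetic shared by resize_and_center_crop and its PyTorch variant. The
# image is scaled by the larger of the two ratios so that both sides cover the
# target, then the excess is split evenly between opposite edges.
def _sketch_center_crop_box(height, width, target_height, target_width):
    """Return ((resized_h, resized_w), (top, left, bottom, right))."""
    scale = max(target_width / width, target_height / height)
    resized_h = int(round(height * scale))
    resized_w = int(round(width * scale))
    top = (resized_h - target_height) // 2
    left = (resized_w - target_width) // 2
    return (resized_h, resized_w), (top, left, top + target_height, left + target_width)

# Example: 720x1280 -> 480x832 first resizes to 480x853, then keeps columns 10..841.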


def resize_without_crop(image, target_width, target_height):
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


def just_crop(image, w, h):
    if h == image.shape[0] and w == image.shape[1]:
        return image

    original_height, original_width = image.shape[:2]
    k = min(original_height / h, original_width / w)
    new_width = int(round(w * k))
    new_height = int(round(h * k))
    x_start = (original_width - new_width) // 2
    y_start = (original_height - new_height) // 2
    cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
    return cropped_image


def write_to_json(data, file_path):
    temp_file_path = file_path + ".tmp"
    with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
        json.dump(data, temp_file, indent=4)
    os.replace(temp_file_path, file_path)
    return
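

# Illustrative sketch (hypothetical helper): write_to_json writes to a fixed
# "<file>.tmp" path and then os.replace()s it over the target, so readers see
# either the old file or the new one, never a half-written one. This variant
# assumes concurrent writers are possible: it uses a unique temp file in the
# same directory (os.replace is only atomic within one filesystem) and fsyncs
# the data before the rename.
import json
import os
import tempfile


def _sketch_atomic_write_json(data, file_path):
    """Atomically replace file_path with the JSON encoding of data."""
    directory = os.path.dirname(os.path.abspath(file_path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as tmp_file:
            json.dump(data, tmp_file, indent=4)
            tmp_file.flush()
            os.fsync(tmp_file.fileno())
        os.replace(tmp_path, file_path)
    except BaseException:
        # Clean up the temp file if anything failed before the rename
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise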


def read_from_json(file_path):
    with open(file_path, 'rt', encoding='utf-8') as file:
        data = json.load(file)
    return data


def get_active_parameters(m):
    return {k: v for k, v in m.named_parameters() if v.requires_grad}


def cast_training_params(m, dtype=torch.float32):
    result = {}
    for n, param in m.named_parameters():
        if param.requires_grad:
            param.data = param.to(dtype)
            result[n] = param
    return result


def separate_lora_AB(parameters, B_patterns=None):
    parameters_normal = {}
    parameters_B = {}

    if B_patterns is None:
        B_patterns = ['.lora_B.', '__zero__']

    for k, v in parameters.items():
        if any(B_pattern in k for B_pattern in B_patterns):
            parameters_B[k] = v
        else:
            parameters_normal[k] = v

    return parameters_normal, parameters_B
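

# Illustrative sketch (hypothetical helper): one way a split like the one made
# by separate_lora_AB could be consumed, giving LoRA "B" matrices their own
# optimizer learning rate. The b_lr_multiplier knob and its default of 2.0 are
# assumptions for illustration, not settings taken from this repository.
import torch


def _sketch_lora_param_groups(parameters, lr, b_lr_multiplier=2.0, B_patterns=None):
    """Build torch.optim param groups from a {name: Parameter} dict."""
    if B_patterns is None:
        B_patterns = ['.lora_B.', '__zero__']
    normal, lora_b = [], []
    for name, param in parameters.items():
        (lora_b if any(p in name for p in B_patterns) else normal).append(param)
    groups = [{"params": normal, "lr": lr},
              {"params": lora_b, "lr": lr * b_lr_multiplier}]
    return [g for g in groups if g["params"]]  # drop empty groups

# Usage (model is hypothetical):
#   torch.optim.AdamW(_sketch_lora_param_groups(get_active_parameters(model), 1e-4))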


def set_attr_recursive(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)
    return


def print_tensor_list_size(tensors):
    total_size = 0
    total_elements = 0

    if isinstance(tensors, dict):
        tensors = tensors.values()

    for tensor in tensors:
        total_size += tensor.nelement() * tensor.element_size()
        total_elements += tensor.nelement()

    total_size_MB = total_size / (1024 ** 2)
    total_elements_B = total_elements / 1e9

    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_size_MB:.2f} MB")
    print(f"Total number of parameters: {total_elements_B:.3f} billion")
    return


@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    batch_size = a.size(0)

    if b is None:
        b = torch.zeros_like(a)

    if mask_a is None:
        mask_a = torch.rand(batch_size) < probability_a

    mask_a = mask_a.to(a.device)
    mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
    result = torch.where(mask_a, a, b)
    return result
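

# Illustrative sketch (hypothetical helper): batch_mixture selects whole
# samples, because the [B] mask is reshaped to [B, 1, ..., 1] and broadcast by
# torch.where. A common use of this pattern (assumed here, not confirmed by
# this file) is condition dropout for classifier-free guidance training, where
# a random subset of samples has its conditioning replaced by zeros; this is
# roughly batch_mixture(cond, probability_a=1 - drop_prob).
import torch


def _sketch_drop_condition(cond, drop_prob=0.1, generator=None):
    """Zero out each sample of cond independently with probability drop_prob."""
    batch_size = cond.size(0)
    keep = torch.rand(batch_size, generator=generator) >= drop_prob
    keep = keep.to(cond.device).reshape((batch_size,) + (1,) * (cond.dim() - 1))
    return torch.where(keep, cond, torch.zeros_like(cond))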


@torch.no_grad()
def zero_module(module):
    for p in module.parameters():
        p.detach().zero_()
    return module


@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    data = m.weight.data.clone()

    assert int(data.shape[1]) >= k

    data[:, :k] = data[:, :k] * alpha
    m.weight.data = data.contiguous().clone()
    return m


def freeze_module(m):
    if not hasattr(m, '_forward_inside_frozen_module'):
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    m.forward = torch.no_grad()(m.forward)
    return m


def get_latest_safetensors(folder_path):
    safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))

    if not safetensors_files:
        raise ValueError('No file to resume!')

    latest_file = max(safetensors_files, key=os.path.getmtime)
    latest_file = os.path.abspath(os.path.realpath(latest_file))
    return latest_file


def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    tags = tags_str.split(', ')
    tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
    prompt = ', '.join(tags)
    return prompt


def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
    if round_to_int:
        numbers = np.round(numbers).astype(int)
    return numbers.tolist()


def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    edges = np.linspace(0, 1, n + 1)
    points = np.random.uniform(edges[:-1], edges[1:])
    numbers = inclusive + (exclusive - inclusive) * points
    if round_to_int:
        numbers = np.round(numbers).astype(int)
    return numbers.tolist()
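

# Illustrative sketch (hypothetical helper): uniform_random_by_intervals is
# stratified sampling. [0, 1) is split into n equal bins and one point is drawn
# uniformly inside each bin, so the n samples come out sorted and evenly
# spread instead of clustering. This version takes an explicit numpy Generator
# (an assumption, for reproducibility) instead of the global np.random state.
import numpy as np


def _sketch_stratified_uniform(low, high, n, rng=None):
    """Draw n sorted samples in [low, high), exactly one per equal-width bin."""
    rng = np.random.default_rng() if rng is None else rng
    edges = np.linspace(0.0, 1.0, n + 1)
    points = rng.uniform(edges[:-1], edges[1:])
    return low + (high - low) * points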


def soft_append_bcthw(history, current, overlap=0):
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
    output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)

    return output.to(history)


def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
    b, c, t, h, w = x.shape

    per_row = b
    for p in [6, 5, 4, 3, 2]:
        if b % p == 0:
            per_row = p
            break

    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
    return x
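

# Illustrative sketch (hypothetical helpers): how save_bcthw_as_mp4 tiles a
# batch into one video. It uses the largest of 6, 5, 4, 3, 2 that divides B as
# the number of columns (falling back to a single row of B), then einops turns
# [B, C, T, H, W] into T frames of shape [rows * H, cols * W, C].
import einops
import torch


def _sketch_grid_columns(batch_size):
    for p in [6, 5, 4, 3, 2]:
        if batch_size % p == 0:
            return p
    return batch_size


def _sketch_tile_bcthw(x):
    """Tile [B, C, T, H, W] into [T, rows * H, cols * W, C]."""
    return einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=_sketch_grid_columns(x.shape[0]))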


def save_bcthw_as_png(x, output_filename):
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(x, output_filename)
    return output_filename


def save_bchw_as_png(x, output_filename):
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, 'b c h w -> c h (b w)')
    torchvision.io.write_png(x, output_filename)
    return output_filename


def add_tensors_with_padding(tensor1, tensor2):
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

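    # allocate on the inputs' device/dtype so GPU and half-precision tensors work
    padded_tensor1 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)
    padded_tensor2 = torch.zeros(new_shape, dtype=tensor2.dtype, device=tensor2.device)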

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    result = padded_tensor1 + padded_tensor2
    return result


def print_free_mem():
    torch.cuda.empty_cache()
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    free_mem_mb = free_mem / (1024 ** 2)
    total_mem_mb = total_mem / (1024 ** 2)
    print(f"Free memory: {free_mem_mb:.2f} MB")
    print(f"Total memory: {total_mem_mb:.2f} MB")
    return


def print_gpu_parameters(device, state_dict, log_count=1):
    summary = {"device": device, "keys_count": len(state_dict)}

    logged_params = {}
    for i, (key, tensor) in enumerate(state_dict.items()):
        if i >= log_count:
            break
        logged_params[key] = tensor.flatten()[:3].tolist()

    summary["params"] = logged_params

    print(str(summary))
    return


def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    # Split text into words; empty or whitespace-only text yields a blank image
    words = text.split()
    if not words:
        return np.array(txt)

    # Greedily pack words into lines that fit within the image width
    lines = []
    current_line = words[0]

    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word

    lines.append(current_line)

    # Draw the text line by line
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]

    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)


def blue_mark(x):
    x = x.copy()
    c = x[:, :, 2]
    b = cv2.blur(c, (9, 9))
    x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
    return x


def green_mark(x):
    x = x.copy()
    x[:, :, 2] = -1
    x[:, :, 0] = -1
    return x


def frame_mark(x):
    x = x.copy()
    x[:64] = -1
    x[-64:] = -1
    x[:, :8] = 1
    x[:, -8:] = 1
    return x


@torch.inference_mode()
def pytorch2numpy(imgs):
    results = []
    for x in imgs:
        y = x.movedim(0, -1)
        y = y * 127.5 + 127.5
        y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        results.append(y)
    return results


@torch.inference_mode()
def numpy2pytorch(imgs):
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    h = h.movedim(-1, 1)
    return h


@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    if zero_out:
        return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
    else:
        return torch.cat([x, x[:count]], dim=0)


def weighted_mse(a, b, weight):
    return torch.mean(weight.float() * (a.float() - b.float()) ** 2)


def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    x = (x - x_min) / (x_max - x_min)
    x = max(0.0, min(x, 1.0))
    x = x ** sigma
    return y_min + x * (y_max - y_min)


def expand_to_dims(x, target_dims):
    return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))


def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    if tensor is None:
        return None

    first_dim = tensor.shape[0]

    if first_dim == batch_size:
        return tensor

    if batch_size % first_dim != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    repeat_times = batch_size // first_dim

    return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))


def dim5(x):
    return expand_to_dims(x, 5)


def dim4(x):
    return expand_to_dims(x, 4)


def dim3(x):
    return expand_to_dims(x, 3)


def crop_or_pad_yield_mask(x, length):
    B, F, C = x.shape
    device = x.device
    dtype = x.dtype

    if F < length:
        y = torch.zeros((B, length, C), dtype=dtype, device=device)
        mask = torch.zeros((B, length), dtype=torch.bool, device=device)
        y[:, :F, :] = x
        mask[:, :F] = True
        return y, mask

    return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
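

# Illustrative sketch (hypothetical helper): crop_or_pad_yield_mask handles a
# batch whose sequences already share one length F. For a ragged list of
# [F_i, C] tensors (for example, text-encoder outputs of different lengths,
# an assumed use case), the same "fixed length plus validity mask" result can
# be built with pad_sequence.
import torch
from torch.nn.utils.rnn import pad_sequence


def _sketch_pad_ragged_with_mask(sequences, length):
    """Crop/pad a list of [F_i, C] tensors to [B, length, C] plus a [B, length] bool mask."""
    cropped = [s[:length] for s in sequences]
    padded = pad_sequence(cropped, batch_first=True)  # [B, max_i F_i, C], zero padded
    if padded.shape[1] < length:
        extra = padded.new_zeros(padded.shape[0], length - padded.shape[1], padded.shape[2])
        padded = torch.cat([padded, extra], dim=1)
    lengths = torch.tensor([s.shape[0] for s in cropped], device=padded.device)
    mask = torch.arange(length, device=padded.device)[None, :] < lengths[:, None]
    return padded, mask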


def extend_dim(x, dim, minimal_length, zero_pad=False):
    original_length = int(x.shape[dim])

    if original_length >= minimal_length:
        return x

    if zero_pad:
        padding_shape = list(x.shape)
        padding_shape[dim] = minimal_length - original_length
        padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
    else:
        idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
        last_element = x[idx]
        padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)

    return torch.cat([x, padding], dim=dim)


def lazy_positional_encoding(t, repeats=None):
    if not isinstance(t, list):
        t = [t]

    from diffusers.models.embeddings import get_timestep_embedding

    te = torch.tensor(t)
    te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)

    if repeats is None:
        return te

    te = te[:, None, :].expand(-1, repeats, -1)

    return te
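

# Illustrative sketch (hypothetical helper): lazy_positional_encoding delegates
# to diffusers' get_timestep_embedding. To our understanding, with
# flip_sin_to_cos=True, downscale_freq_shift=0.0 and scale=1.0 that function
# computes the standard sinusoidal embedding below (frequencies
# exp(-ln(10000) * i / half) for i < half, cosine half first). Treat this as a
# reference sketch and compare against diffusers before relying on it.
import math
import torch


def _sketch_sinusoidal_embedding(t, dim=256, max_period=10000):
    """Return a [N, dim] embedding for N timesteps (dim assumed even)."""
    half = dim // 2
    exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    args = torch.as_tensor(t, dtype=torch.float32).reshape(-1)[:, None] * torch.exp(exponent)[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)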


def state_dict_offset_merge(A, B, C=None):
    result = {}
    keys = A.keys()

    for key in keys:
        A_value = A[key]
        B_value = B[key].to(A_value)

        if C is None:
            result[key] = A_value + B_value
        else:
            C_value = C[key].to(A_value)
            result[key] = A_value + B_value - C_value

    return result
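

# Illustrative sketch (hypothetical helper): state_dict_offset_merge computes
# A + B - C, i.e. it adds the delta B - C (for example a fine-tune minus its
# base model) on top of another model A. The alpha factor below, which scales
# that delta, is an assumption added for illustration; alpha=1.0 reproduces
# A + B - C.
import torch


def _sketch_scaled_offset_merge(A, B, C, alpha=1.0):
    """Return {k: A[k] + alpha * (B[k] - C[k])} with B and C cast to A's dtype/device."""
    return {k: a + alpha * (B[k].to(a) - C[k].to(a)) for k, a in A.items()}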


def state_dict_weighted_merge(state_dicts, weights):
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")

    if not state_dicts:
        return {}

    total_weight = sum(weights)

    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")

    normalized_weights = [w / total_weight for w in weights]

    keys = state_dicts[0].keys()
    result = {}

    for key in keys:
        result[key] = state_dicts[0][key] * normalized_weights[0]

        for i in range(1, len(state_dicts)):
            state_dict_value = state_dicts[i][key].to(result[key])
            result[key] += state_dict_value * normalized_weights[i]

    return result


def group_files_by_folder(all_files):
    grouped_files = {}

    for file in all_files:
        folder_name = os.path.basename(os.path.dirname(file))
        if folder_name not in grouped_files:
            grouped_files[folder_name] = []
        grouped_files[folder_name].append(file)

    list_of_lists = list(grouped_files.values())
    return list_of_lists


def generate_timestamp():
    now = datetime.datetime.now()
    timestamp = now.strftime('%y%m%d_%H%M%S')
    milliseconds = f"{int(now.microsecond / 1000):03d}"
    random_number = random.randint(0, 9999)
    return f"{timestamp}_{milliseconds}_{random_number}"


def write_PIL_image_with_png_info(image, metadata, path):
    from PIL.PngImagePlugin import PngInfo

    png_info = PngInfo()
    for key, value in metadata.items():
        png_info.add_text(key, value)

    image.save(path, "PNG", pnginfo=png_info)
    return image


def torch_safe_save(content, path):
    torch.save(content, path + '_tmp')
    os.replace(path + '_tmp', path)
    return path


def move_optimizer_to_device(optimizer, device):
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)


================================================
FILE: demo_utils/vae.py
================================================
from typing import List
from einops import rearrange
import tensorrt as trt
import torch
import torch.nn as nn

from demo_utils.constant import ALL_INPUTS_NAMES, ZERO_VAE_CACHE
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, Upsample

CACHE_T = 2


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache_1, feat_cache_2):
        h = self.shortcut(x)
        feat_cache = feat_cache_1
        out_feat_cache = []
        for layer in self.residual:
            if isinstance(layer, CausalConv3d):
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache is not None:
                    # prepend the last cached frame of the previous chunk
                    cache_x = torch.cat([
                        feat_cache[:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                        dim=2)
                x = layer(x, feat_cache)
                out_feat_cache.append(cache_x)
                feat_cache = feat_cache_2
            else:
                x = layer(x)
        return x + h, *out_feat_cache
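

# Illustrative sketch (hypothetical helper): the feat_cache logic in this file
# implements streaming causal convolution. A temporal kernel of size 3 needs
# the 2 frames preceding each chunk, so each layer keeps the last CACHE_T = 2
# input frames; when a chunk has a single frame, the last cached frame is
# prepended so the cache still holds 2 frames. This sketch covers only the
# temporal axis (kernel (3, 1, 1), no spatial padding) and shows that chunked
# processing reproduces the full-sequence result.
import torch
import torch.nn.functional as F


def _sketch_streaming_causal_conv(x, weight, cache=None):
    """Apply a causal (k, 1, 1) conv to x [B, C, T, H, W]; return (out, new_cache).

    cache holds the last k - 1 input frames of the previous chunk, or None for
    the first chunk (which is zero-padded instead).
    """
    k = weight.shape[2]
    if cache is None:
        cache = x.new_zeros(x.shape[0], x.shape[1], k - 1, x.shape[3], x.shape[4])
    x_padded = torch.cat([cache, x], dim=2)
    out = F.conv3d(x_padded, weight)
    return out, x_padded[:, :, -(k - 1):].clone()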


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, is_first_frame, feat_cache):
        if self.mode == 'upsample3d':
            b, c, t, h, w = x.size()
            # x, out_feat_cache = torch.cond(
            #     is_first_frame,
            #     lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
            #     lambda: self.temporal_conv(x, feat_cache),
            # )
            x, out_feat_cache = self.temporal_conv(x, is_first_frame, feat_cache)
            out_feat_cache = torch.cond(
                is_first_frame,
                lambda: feat_cache.clone().contiguous(),
                lambda: out_feat_cache.clone().contiguous(),
            )
            # if is_first_frame:
            #     x = torch.cat([torch.zeros_like(x), x], dim=2)
            #     out_feat_cache = feat_cache.clone()
            # else:
            #     x, out_feat_cache = self.temporal_conv(x, feat_cache)
        else:
            out_feat_cache = None
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
        return x, out_feat_cache

    def temporal_conv(self, x, is_first_frame, feat_cache):
        b, c, t, h, w = x.size()
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache is not None:
            cache_x = torch.cat([
                torch.zeros_like(cache_x),
                cache_x
            ], dim=2)
        x = torch.cond(
            is_first_frame,
            lambda: torch.cat([torch.zeros_like(x), x], dim=1).contiguous(),
            lambda: self.time_conv(x, feat_cache).contiguous(),
        )
        # x = self.time_conv(x, feat_cache)
        out_feat_cache = cache_x

        x = x.reshape(b, 2, c, t, h, w)
        x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                        3)
        x = x.reshape(b, c, t * 2, h, w)
        return x.contiguous(), out_feat_cache.contiguous()

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)
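

# Illustrative sketch (hypothetical helper): in Resample.temporal_conv,
# time_conv maps C channels to 2 * C, and the reshape/stack/reshape sequence
# turns those two channel halves into two consecutive output frames per input
# frame, which is how 'upsample3d' doubles the frame count. The sketch
# isolates that interleaving step.
import torch


def _sketch_interleave_channel_halves(y):
    """Map [B, 2 * C, T, H, W] to [B, C, 2 * T, H, W].

    Output frame 2t comes from the first channel half, frame 2t + 1 from the second.
    """
    b, c2, t, h, w = y.shape
    c = c2 // 2
    y = y.reshape(b, 2, c, t, h, w)
    y = torch.stack((y[:, 0], y[:, 1]), 3)  # [B, C, T, 2, H, W]
    return y.reshape(b, c, t * 2, h, w)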


class VAEDecoderWrapperSingle(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = VAEDecoder3d()
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)
        self.z_dim = 16
        self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)

    def forward(
            self,
            z: torch.Tensor,
            is_first_frame: torch.Tensor,
            *feat_cache: List[torch.Tensor]
    ):
        # from [batch_size, num_frames, num_channels, height, width]
        # to [batch_size, num_channels, num_frames, height, width]
        z = z.permute(0, 2, 1, 3, 4)
        assert z.shape[2] == 1
        feat_cache = list(feat_cache)
        is_first_frame = is_first_frame.bool()

        device, dtype = z.device, z.dtype
        scale = [self.mean.to(device=device, dtype=dtype),
                 1.0 / self.std.to(device=device, dtype=dtype)]

        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        x = self.conv2(z)
        out, feat_cache = self.decoder(x, is_first_frame, feat_cache=feat_cache)
        out = out.clamp_(-1, 1)
        # from [batch_size, num_channels, num_frames, height, width]
        # to [batch_size, num_frames, num_channels, height, width]
        out = out.permute(0, 2, 1, 3, 4)
        return out, feat_cache
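

# Illustrative sketch (hypothetical helpers): VAEDecoderWrapperSingle undoes
# the per-channel latent normalization before decoding. Since
# scale[1] = 1 / std, "z / scale[1] + scale[0]" equals z * std + mean. The
# forward direction shown here, (z - mean) / std, is the inverse implied by
# that arithmetic. Both helpers take the [B, C, T, H, W] layout used after the
# wrapper's permute.
import torch


def _sketch_normalize_latent(z, mean, std):
    """[B, C, T, H, W] latent -> (z - mean) / std, per channel."""
    return (z - mean.view(1, -1, 1, 1, 1)) / std.view(1, -1, 1, 1, 1)


def _sketch_denormalize_latent(z, mean, std):
    """Inverse of _sketch_normalize_latent: z * std + mean, per channel."""
    return z * std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1)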


class VAEDecoder3d(nn.Module):
    def __init__(self,
                 dim=96,
                 z_dim=16,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample
        self.cache_t = 2
        self.decoder_conv_num = 32

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(
            self,
            x: torch.Tensor,
            is_first_frame: torch.Tensor,
            feat_cache: List[torch.Tensor]
    ):
        idx = 0
        out_feat_cache = []

        # conv1
        cache_x = x[:, :, -self.cache_t:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # prepend the last cached frame of the previous chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                dim=2)
        x = self.conv1(x, feat_cache[idx])
        out_feat_cache.append(cache_x)
        idx += 1

        # middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
                idx += 2
                out_feat_cache.append(out_feat_cache_1)
                out_feat_cache.append(out_feat_cache_2)
            else:
                x = layer(x)

        # upsamples
        for layer in self.upsamples:
            if isinstance(layer, Resample):
                x, cache_x = layer(x, is_first_frame, feat_cache[idx])
                if cache_x is not None:
                    out_feat_cache.append(cache_x)
                    idx += 1
            else:
                x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
                idx += 2
                out_feat_cache.append(out_feat_cache_1)
                out_feat_cache.append(out_feat_cache_2)

        # head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                cache_x = x[:, :, -self.cache_t:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # prepend the last cached frame of the previous chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                        dim=2)
                x = layer(x, feat_cache[idx])
                out_feat_cache.append(cache_x)
                idx += 1
            else:
                x = layer(x)
        return x, out_feat_cache
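

# Illustrative sketch (hypothetical helper): the channel schedule built in
# VAEDecoder3d.__init__. Each 'upsample2d'/'upsample3d' Resample ends in a
# Conv2d(dim, dim // 2), so every stage after the first starts from half of
# the previous stage's width; that is what the "in_dim = in_dim // 2" branch
# accounts for (for four stages, i in {1, 2, 3} is the same as i > 0). With
# the defaults this gives three 2x spatial upsamplings and two temporal ones.
def _sketch_decoder_channel_schedule(dim=96, dim_mult=(1, 2, 4, 4),
                                     temperal_upsample=(True, True, False)):
    """Return one (stage_in_dim, stage_out_dim, resample_mode) tuple per stage."""
    dim_mult = list(dim_mult)
    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
    stages = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        if i > 0:
            in_dim = in_dim // 2  # the previous Resample halved the channels
        mode = None
        if i != len(dim_mult) - 1:
            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
        stages.append((in_dim, out_dim, mode))
    return stages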


class VAETRTWrapper():
    def __init__(self):
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
            self.engine: trt.ICudaEngine = rt.deserialize_cuda_engine(f.read())

        self.context: trt.IExecutionContext = self.engine.create_execution_context()
        self.stream = torch.cuda.current_stream().cuda_stream

        # ──────────────────────────────
        # 2️⃣  Feed the engine with tensors
        #     (name-based API in TRT ≥10)
        # ──────────────────────────────
        self.dtype_map = {
            trt.float32: torch.float32,
            trt.float16: torch.float16,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
        }
        test_input = torch.zeros(1, 16, 1, 60, 104).cuda().half()
        is_first_frame = torch.tensor(1.0).cuda().half()
        test_cache_inputs = [c.cuda().half() for c in ZERO_VAE_CACHE]
        test_inputs = [test_input, is_first_frame] + test_cache_inputs

        # keep references so buffers stay alive
        self.device_buffers, self.outputs = {}, []

        # ---- inputs ----
        for i, name in enumerate(ALL_INPUTS_NAMES):
            tensor, scale = test_inputs[i], 1 / 127
            tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)

            # dynamic shapes
            if -1 in self.engine.get_tensor_shape(name):
                # new API
                self.context.set_input_shape(name, tuple(tensor.shape))

            # replaces bindings[]
            self.context.set_tensor_address(name, int(tensor.data_ptr()))
            self.device_buffers[name] = tensor                             # keep pointer alive

        # ---- (after all input shapes are known) infer output shapes ----
        # infer_shapes() propagates the input shapes to the outputs
        self.context.infer_shapes()

        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            # get_tensor_mode replaces binding_is_input
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                shape = tuple(self.context.get_tensor_shape(name))
                dtype = self.dtype_map[self.engine.get_tensor_dtype(name)]
                out = torch.empty(shape, dtype=dtype, device="cuda").contiguous()

                self.context.set_tensor_address(name, int(out.data_ptr()))
                self.outputs.append(out)
                self.device_buffers[name] = out

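    # helper: symmetric INT8 quantization on the fly, applied only when the
    # engine declares an int8 input (q = clamp(round(t / scale), -128, 127))
    def quantize_if_needed(self, t, expected_dtype, scale):
        if expected_dtype == trt.int8 and t.dtype != torch.int8:
            t = torch.clamp((t / scale).round(), -128, 127).to(torch.int8).contiguous()
        # callers must keep a reference to the returned tensor so that its
        # device pointer stays valid until the engine has run
        return t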

    def forward(self, *test_inputs):
        for i, name in enumerate(ALL_INPUTS_NAMES):
            tensor, scale = test_inputs[i], 1 / 127
            tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
            self.context.set_tensor_address(name, int(tensor.data_ptr()))
            self.device_buffers[name] = tensor

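        self.context.execute_async_v3(stream_handle=self.stream)
        torch.cuda.current_stream().synchronize()
        # NOTE: self.outputs are persistent buffers that the next call overwrites;
        # clone them if the results must outlive this call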
        return self.outputs
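

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, stdlib only): the symmetric INT8
# scheme applied by `quantize_if_needed` with scale = 1/127 maps a value x to
#     q = clamp(round(x / scale), -128, 127)
# and back via x_hat = q * scale. For |x| <= 127 * scale (here |x| <= 1) the
# round-trip error is at most scale / 2; larger values saturate at 127.
# ---------------------------------------------------------------------------
def quantize_int8_sketch(values, scale=1 / 127):
    """Quantize floats to ints in [-128, 127]; Python's round() rounds half
    to even, like torch.round."""
    return [max(-128, min(127, round(v / scale))) for v in values]


def dequantize_int8_sketch(codes, scale=1 / 127):
    """Inverse mapping, useful for estimating the quantization error."""
    return [q * scale for q in codes]

# Example:
#   quantize_int8_sketch([1.0, -1.0, 0.0, 2.0]) -> [127, -127, 0, 127]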


================================================
FILE: demo_utils/vae_block3.py
================================================
from typing import List
from einops import rearrange
import torch
import torch.nn as nn

from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, ResidualBlock, Upsample


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode
        self.cache_t = 2

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -self.cache_t:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                            dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                            dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)
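

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, stdlib only): in 'upsample3d' mode
# `time_conv` doubles the channels (dim -> 2 * dim); the reshape/stack/reshape
# sequence in Resample.forward then turns the two channel halves of every
# input frame into two consecutive output frames:
#     output frame 2*i   <- channels [0, c)  of input frame i
#     output frame 2*i+1 <- channels [c, 2c) of input frame i
# Each frame is a flat list of channel values here, standing in for a
# [C, H, W] slice of the real [B, C, T, H, W] tensor.
# ---------------------------------------------------------------------------
def temporal_upsample_interleave_sketch(frames, c):
    out = []
    for frame in frames:
        assert len(frame) == 2 * c, "expected 2*c channels per frame"
        out.append(frame[:c])   # first channel half -> even output frame
        out.append(frame[c:])   # second channel half -> odd output frame
    return out

# Example (c = 2, two input frames -> four output frames):
#   temporal_upsample_interleave_sketch([[1, 2, 10, 20], [3, 4, 30, 40]], 2)
#   -> [[1, 2], [10, 20], [3, 4], [30, 40]]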


class VAEDecoderWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = VAEDecoder3d()
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)
        self.z_dim = 16
        self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)

    def forward(
            self,
            z: torch.Tensor,
            *feat_cache: List[torch.Tensor]
    ):
        # from [batch_size, num_frames, num_channels, height, width]
        # to [batch_size, num_channels, num_frames, height, width]
        z = z.permute(0, 2, 1, 3, 4)
        feat_cache = list(feat_cache)
        print("Length of feat_cache: ", len(feat_cache))

        device, dtype = z.device, z.dtype
        scale = [self.mean.to(device=device, dtype=dtype),
                 1.0 / self.std.to(device=device, dtype=dtype)]

        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            if i == 0:
                out, feat_cache = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=feat_cache)
            else:
                out_, feat_cache = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=feat_cache)
                out = torch.cat([out, out_], 2)

        out = out.float().clamp_(-1, 1)
        # from [batch_size, num_channels, num_frames, height, width]
        # to [batch_size, num_frames, num_channels, height, width]
        out = out.permute(0, 2, 1, 3, 4)
        return out, feat_cache
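

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper): VAEDecoderWrapper.forward undoes
# the per-channel latent normalization with z / (1 / std) + mean, which is
# algebraically z * std + mean. Shown for one value per channel; the real code
# broadcasts `mean`/`std` over [B, C, T, H, W] via .view(1, C, 1, 1, 1).
# ---------------------------------------------------------------------------
def unscale_latent_sketch(z_per_channel, mean, std):
    assert len(z_per_channel) == len(mean) == len(std)
    return [z * s + m for z, m, s in zip(z_per_channel, mean, std)]

# Example: unscale_latent_sketch([0.0, 1.0, -1.0], [1.0, 2.0, 0.5], [2.0, 3.0, 4.0])
#   -> [1.0, 5.0, -3.5]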


class VAEDecoder3d(nn.Module):
    def __init__(self,
                 dim=96,
                 z_dim=16,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample
        self.cache_t = 2
        self.decoder_conv_num = 32

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks; stages 1-3 start from half the
            # channels because the preceding Resample's Conv2d outputs dim // 2
            if i in (1, 2, 3):
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(
            self,
            x: torch.Tensor,
            feat_cache: List[torch.Tensor]
    ):
        feat_idx = [0]

        # conv1
        idx = feat_idx[0]
        cache_x = x[:, :, -self.cache_t:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                dim=2)
        x = self.conv1(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1

        # middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # upsamples
        for layer in self.upsamples:
            x = layer(x, feat_cache, feat_idx)

        # head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -self.cache_t:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x, feat_cache
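

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, stdlib only): the channel plan that
# VAEDecoder3d.__init__ builds with its default arguments. Every Resample halves
# the channel count (its Conv2d maps dim -> dim // 2), which is why stages 1-3
# start from in_dim // 2. Three (2, 2) upsampling stages give 8x spatial
# upsampling, e.g. a 60 x 104 latent decodes to 480 x 832 pixels.
# ---------------------------------------------------------------------------
def decoder_stage_plan_sketch(dim=96, dim_mult=(1, 2, 4, 4),
                              temperal_upsample=(True, True, False)):
    dim_mult = list(dim_mult)
    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
    plan = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        if i in (1, 2, 3):
            in_dim = in_dim // 2
        if i != len(dim_mult) - 1:
            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
        else:
            mode = None  # last stage has no Resample
        plan.append((in_dim, out_dim, mode))
    return plan

# Example: decoder_stage_plan_sketch()
#   -> [(384, 384, 'upsample3d'), (192, 384, 'upsample3d'),
#       (192, 192, 'upsample2d'), (96, 96, None)]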


================================================
FILE: demo_utils/vae_torch2trt.py
================================================
from demo_utils.vae import (
    VAEDecoderWrapperSingle,   # decoder nn.Module to export
    ZERO_VAE_CACHE             # initial decoder feature cache (fixes input ordering)
)
import pycuda.driver as cuda   # used by the INT8 calibrator
import pycuda.autoinit  # noqa: F401  (creates the CUDA context)

import sys
from pathlib import Path

import torch
import tensorrt as trt

from utils.dataset import ShardingLMDBDataset

data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=0
)

# ─────────────────────────────────────────────────────────
# 1️⃣  Dummy inputs for export
#     (the model itself is imported from demo_utils/vae.py above)
# ─────────────────────────────────────────────────────────

# --- dummy tensors: one latent frame [B, T, C, H, W], a first-frame flag,
#     and one tensor per decoder cache slot ---
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
dummy_cache_input = [
    torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
    for s in ZERO_VAE_CACHE               # keep exactly the same ordering
]
inputs = [dummy_input, is_first_frame, *dummy_cache_input]

# ─────────────────────────────────────────────────────────
# 2️⃣  Export → ONNX
# ─────────────────────────────────────────────────────────
model = VAEDecoderWrapperSingle().half().cuda().eval()

vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
decoder_state_dict = {}
for key, value in vae_state_dict.items():
    if 'decoder.' in key or 'conv2' in key:
        decoder_state_dict[key] = value
model.load_state_dict(decoder_state_dict)
model = model.half().cuda().eval()                          # all input shapes are static (see optimisation profile)

onnx_path = Path("vae_decoder.onnx")
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
all_inputs_names = ["z", "use_cache"] + feat_names

with torch.inference_mode():
    torch.onnx.export(
        model,
        tuple(inputs),                                        # must be a tuple
        onnx_path.as_posix(),
        input_names=all_inputs_names,
        output_names=["rgb_out", "cache_out"],
        opset_version=17,
        do_constant_folding=True,
        dynamo=True
    )
print(f"✅  ONNX graph saved to {onnx_path.resolve()}")

# (Optional) quick sanity-check with ONNX-Runtime
try:
    import onnxruntime as ort
    sess = ort.InferenceSession(onnx_path.as_posix(),
                                providers=["CUDAExecutionProvider"])
    ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
    _ = sess.run(None, ort_inputs)
    print("✅  ONNX graph is executable")
except Exception as e:
    print("⚠️  ONNX check failed:", e)

# ─────────────────────────────────────────────────────────
# 3️⃣  Build the TensorRT engine
# ─────────────────────────────────────────────────────────
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open(onnx_path, "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        sys.exit("❌  ONNX → TRT parsing failed")

# ─────────────────────────────────────────────────────────
# helper: version-agnostic workspace limit
# ─────────────────────────────────────────────────────────


def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
    """
    TRT < 10.x  →  config.max_workspace_size
    TRT ≥ 10.x  →  config.set_memory_pool_limit(...)
    """
    if hasattr(config, "max_workspace_size"):                     # TRT 8 / 9
        config.max_workspace_size = bytes_
    else:                                                         # TRT 10+
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                     bytes_)

# ─────────────────────────────────────────────────────────
# (optional) INT-8 calibrator
# ─────────────────────────────────────────────────────────
# Only keep this block if you really need an INT-8 engine; for FP16 only,
# also drop the INT-8 section of the builder config below.


class VAECalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, loader, cache="calibration.cache", max_batches=10):
        super().__init__()
        self.loader = iter(loader)
        self.batch_size = loader.batch_size or 1
        self.max_batches = max_batches
        self.count = 0
        self.cache_file = cache
        self.stream = cuda.Stream()
        self.dev_ptrs = {}

    # --- TRT 10 needs BOTH spellings ---
    def get_batch_size(self):
        return self.batch_size

    def getBatchSize(self):
        return self.batch_size

    def get_batch(self, names):
        if self.count >= self.max_batches:
            return None

        # roll the decoder cache forward by a random number of frames
        # (0 to 10 inclusive) so calibration also sees non-initial cache states
        import random
        vae_idx = random.randint(0, 10)
        data = next(self.loader)

        latent = data['ode_latent'][0][:, :1]
        is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
        feat_cache = ZERO_VAE_CACHE
        for i in range(vae_idx):
            inputs = [latent, is_first_frame, *feat_cache]
            with torch.inference_mode():
                outputs = model(*inputs)
            latent = data['ode_latent'][0][:, i + 1:i + 2]
            is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
            feat_cache = outputs[1:]

        # -------- ensure context is current --------
        z_np = latent.cpu().numpy().astype('float32')

        ptrs = []                # list[int] – one entry per name
        for name in names:         # <-- match TRT's binding order
            if name == "z":
                arr = z_np
            elif name == "use_cache":
                arr = is_first_frame.cpu().numpy().astype('float32')
            else:
                idx = int(name.split('_')[-1])   # "vae_cache_17" -> 17
                arr = feat_cache[idx].cpu().numpy().astype('float32')

            if name not in self.dev_ptrs:
                self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)

            cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
            ptrs.append(int(self.dev_ptrs[name]))   # ***int() is required***

        self.stream.synchronize()
        self.count += 1
        print(f"Calibration batch {self.count}/{self.max_batches}")
        return ptrs

    # --- calibration-cache helpers (both spellings) ---
    def read_calibration_cache(self):
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def readCalibrationCache(self):
        return self.read_calibration_cache()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

    def writeCalibrationCache(self, cache):
        self.write_calibration_cache(cache)
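

# ─────────────────────────────────────────────────────────
# Illustrative sketch (hypothetical helper, stdlib only): get_batch() must
# return one device pointer per input in the order of the `names` list that
# TensorRT passes in. This reproduces only that name → array lookup
# (including the "vae_cache_17" → 17 parsing) on plain Python objects.
# ─────────────────────────────────────────────────────────
def order_calibration_inputs_sketch(names, z, use_cache, caches):
    ordered = []
    for name in names:
        if name == "z":
            ordered.append(z)
        elif name == "use_cache":
            ordered.append(use_cache)
        else:
            ordered.append(caches[int(name.split('_')[-1])])
    return ordered

# Example:
#   order_calibration_inputs_sketch(["vae_cache_1", "z", "use_cache", "vae_cache_0"],
#                                   "Z", "U", ["C0", "C1"])
#   -> ["C1", "Z", "U", "C0"]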


# ─────────────────────────────────────────────────────────
# Builder-config + optimisation profile
# ─────────────────────────────────────────────────────────
config = builder.create_builder_config()
set_workspace(config, 4 << 30)                    # 4 GB

# ► enable FP16 if possible
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# ► enable INT-8  (delete this block if you don’t need it)
if cuda is not None:
    config.set_flag(trt.BuilderFlag.INT8)
    # calibration batches come from the ODE-latent LMDB dataloader defined above
    calib = VAECalibrator(dataloader)
    # use a setter if this TRT build exposes one, otherwise assign the property
    if hasattr(config, "set_int8_calibrator"):
        config.set_int8_calibrator(calib)
    else:
        config.int8_calibrator = calib

# ---- optimisation profile ----
profile = builder.create_optimization_profile()
profile.set_shape(all_inputs_names[0],            # latent z
                  min=(1, 1, 16, 60, 104),
                  opt=(1, 1, 16, 60, 104),
                  max=(1, 1, 16, 60, 104))
profile.set_shape("use_cache",               # scalar flag
                  min=(1,), opt=(1,), max=(1,))
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
    profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)

config.add_optimization_profile(profile)
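

# ─────────────────────────────────────────────────────────
# Illustrative sketch (hypothetical helper, stdlib only): an input shape is
# accepted by an optimisation profile only if every dimension lies within
# [min, max]. Because min == opt == max for every input above, the engine is
# effectively static: any other shape (e.g. a 2-frame latent) is rejected.
# ─────────────────────────────────────────────────────────
def shape_fits_profile_sketch(shape, min_shape, max_shape):
    if not (len(shape) == len(min_shape) == len(max_shape)):
        return False
    return all(lo <= d <= hi for d, lo, hi in zip(shape, min_shape, max_shape))

# Example:
#   fixed = (1, 1, 16, 60, 104)
#   shape_fits_profile_sketch(fixed, fixed, fixed)               -> True
#   shape_fits_profile_sketch((1, 2, 16, 60, 104), fixed, fixed) -> False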

# ─────────────────────────────────────────────────────────
# Build the engine  (API changed in TRT-10)
# ─────────────────────────────────────────────────────────
print("⚙️  Building engine … (can take a minute)")

plan_path = Path("checkpoints/vae_decoder_int8.trt")
plan_path.parent.mkdir(parents=True, exist_ok=True)       # write_bytes() fails if the folder is missing

if hasattr(builder, "build_serialized_network"):          # TRT ≥ 8 (the only option in TRT 10)
    serialized_engine = builder.build_serialized_network(network, config)
    assert serialized_engine is not None, "build_serialized_network() failed"
    engine_bytes = serialized_engine                      # keep for smoke-test
else:                                                     # older TRT
    engine = builder.build_engine(network, config)
    assert engine is not None, "build_engine() returned None"
    engine_bytes = engine.serialize()

plan_path.write_bytes(engine_bytes)

print(f"✅  TensorRT engine written to {plan_path.resolve()}")

# ─────────────────────────────────────────────────────────
# 4️⃣  Quick smoke-test with the brand-new engine
# ─────────────────────────────────────────────────────────
with trt.Runtime(TRT_LOGGER) as rt:
    engine = rt.deserialize_cuda_engine(engine_bytes)
    context = engine.create_execution_context()
    stream = torch.cuda.current_stream().cuda_stream

    # pre-allocate device buffers once
    device_buffers, outputs = {}, []
    dtype_map = {trt.float32: torch.float32,
                 trt.float16: torch.float16,
                 trt.int8:    torch.int8,
                 trt.int32:   torch.int32}

    for name, tensor in zip(all_inputs_names, inputs):
        if -1 in engine.get_tensor_shape(name):            # dynamic input
            context.set_input_shape(name, tensor.shape)
        context.set_tensor_address(name, int(tensor.data_ptr()))
        device_buffers[name] = tensor

    context.infer_shapes()                                 # propagate ⇢ outputs
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            shape = tuple(context.get_tensor_shape(name))
            dtype = dtype_map[engine.get_tensor_dtype(name)]
            out = torch.empty(shape, dtype=dtype, device="cuda")
            context.set_tensor_address(name, int(out.data_ptr()))
            outputs.append(out)
            print(f"output {name} shape: {shape}")

    context.execute_async_v3(stream_handle=stream)
    torch.cuda.current_stream().synchronize()
    print("✅  TRT execution OK – first output shape:", outputs[0].shape)


================================================
FILE: get_causal_ode_data_chunkwise.py
================================================
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
from utils.scheduler import FlowMatchScheduler
from utils.distributed import launch_distributed_job

import torch.distributed as dist
from tqdm import tqdm
import argparse
import torch
import math
import os
from utils.dataset import LatentLMDBDataset

def init_model(device):
    model = WanDiffusionWrapper(is_causal=True).to(device).to(torch.float32)
    model.model.num_frame_per_block = 3  # chunk-wise: 3 latent frames per causal block
    encoder = WanTextEncoder().to(device).to(torch.float32)
    

    scheduler = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
    scheduler.sigmas = scheduler.sigmas.to(device)

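    # Chinese negative prompt (kept verbatim for the text encoder), roughly:
    # "garish colors, overexposed, static, blurry details, subtitles, style, artwork,
    # painting, frame, still, overall grayish, worst quality, low quality, JPEG
    # compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
    # poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
    # motionless frame, cluttered background, three legs, crowded background,
    # walking backwards"
    sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'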

    unconditional_dict = encoder(
        text_prompts=[sample_neg_prompt]
    )

    return model, encoder, scheduler, unconditional_dict
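

# Illustrative sketch (hypothetical helper, stdlib only): the classifier-free
# guidance combination used in main(),
#     flow = f_uncond + guidance_scale * (f_cond - f_uncond),
# written element-wise on plain lists. guidance_scale = 1 returns f_cond and
# guidance_scale = 0 returns f_uncond.
def cfg_combine_sketch(f_uncond, f_cond, guidance_scale=6.0):
    assert len(f_uncond) == len(f_cond)
    return [u + guidance_scale * (c - u) for u, c in zip(f_uncond, f_cond)]

# Example: cfg_combine_sketch([1.0, 0.0], [3.0, 0.5], 6.0) -> [13.0, 3.0]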


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--rawdata_path", type=str, required=True)
    parser.add_argument("--generator_ckpt", type=str, required=True)
    parser.add_argument("--guidance_scale", type=float, default=6.0)


    args = parser.parse_args()

    launch_distributed_job()
    global_rank = dist.get_rank()

    device = torch.cuda.current_device()

    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model, encoder, scheduler, unconditional_dict = init_model(device=device)
    state_dict = torch.load(args.generator_ckpt, map_location="cpu")
        
    gen_sd = state_dict["generator"]
    fixed = {}
    for k, v in gen_sd.items():
        if k.startswith("model._fsdp_wrapped_module."):
            k = k.replace("model._fsdp_wrapped_module.", "", 1)
        if k.startswith("model."):
            k = k.replace("model.", "", 1)
        fixed[k] = v
    state_dict = fixed
    model.model.load_state_dict(
        state_dict, strict=True
    )



    dataset = LatentLMDBDataset(args.rawdata_path)

    if global_rank == 0:
        os.makedirs(args.output_folder, exist_ok=True)
        
    total_steps = int(math.ceil(len(dataset) / dist.get_world_size()))
    for index in tqdm(
        range(total_steps), disable=(dist.get_rank() != 0),
    ):
        prompt_index = index * dist.get_world_size() + dist.get_rank()
        if prompt_index >= len(dataset):
            continue
        sample = dataset[prompt_index]
        prompt = sample["prompts"]
       
        clean_latent = sample["clean_latent"].to(device).unsqueeze(0)
        
        

        conditional_dict = encoder(
            text_prompts=prompt
        )

        latents = torch.randn(
            [1, 21, 16, 60, 104], dtype=torch.float32, device=device
        )
        
        noisy_input = []

        for progress_id, t in enumerate(tqdm(scheduler.timesteps, disable=(dist.get_rank() != 0))):
            timestep = t * \
                torch.ones([1, 21], device=device, dtype=torch.float32)
            noisy_input.append(latents)
            f_cond, x0_pred_cond = model(
                latents, conditional_dict, timestep, clean_x = clean_latent
            )

            f_uncond, x0_pred_uncond = model(
                latents, unconditional_dict, timestep, clean_x = clean_latent
            )

            flow_pred = f_uncond + args.guidance_scale * (
                f_cond - f_uncond
            )
            
            
            latents = scheduler.step(
                flow_pred.flatten(0, 1),
                timestep.flatten(0, 1),
                latents.flatten(0, 1)
            ).unflatten(dim=0, sizes=flow_pred.shape[:2])

        noisy_input.append(latents)
        noisy_input.append(clean_latent)
        
        noisy_inputs = torch.stack(noisy_input, dim=1)

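        # keep the states entering steps 0, 12, 24 and 36, the final ODE
        # output (-2) and the clean latent from the dataset (-1)
        noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -2, -1]]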

        stored_data = noisy_inputs

        torch.save(
            {prompt: stored_data.cpu().detach()},
            os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
        )

    dist.barrier()


if __name__ == "__main__":
    main()
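

# Illustrative sketch (hypothetical helper, stdlib only): `noisy_input` collects
# one state per sampling step (48 entries, assuming scheduler.timesteps has one
# entry per inference step), then the final ODE output and the clean latent,
# i.e. 48 + 2 = 50 entries. The subset [0, 12, 24, 36, -2, -1] therefore maps
# to the absolute positions computed below.
def ode_snapshot_positions_sketch(num_steps=48, keep=(0, 12, 24, 36, -2, -1)):
    total = num_steps + 2
    return [k % total for k in keep]

# Example: ode_snapshot_positions_sketch() -> [0, 12, 24, 36, 48, 49]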


================================================
FILE: get_causal_ode_data_framewise.py
================================================
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
from utils.scheduler import FlowMatchScheduler
from utils.distributed import launch_distributed_job

import torch.distributed as dist
from tqdm import tqdm
import argparse
import torch
import math
import os
from utils.dataset import LatentLMDBDataset

def init_model(device):
    model = WanDiffusionWrapper(is_causal=True).to(device).to(torch.float32)
    model.model.num_frame_per_block = 1  # frame-wise: 1 latent frame per causal block
    encoder = WanTextEncoder().to(device).to(torch.float32)
    

    scheduler = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
    scheduler.sigmas = scheduler.sigmas.to(device)

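    # Chinese negative prompt (kept verbatim for the text encoder), roughly:
    # "garish colors, overexposed, static, blurry details, subtitles, style, artwork,
    # painting, frame, still, overall grayish, worst quality, low quality, JPEG
    # compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
    # poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
    # motionless frame, cluttered background, three legs, crowded background,
    # walking backwards"
    sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'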

    unconditional_dict = encoder(
        text_prompts=[sample_neg_prompt]
    )

    return model, encoder, scheduler, unconditional_dict
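

# Illustrative sketch (hypothetical helpers, stdlib only). This ASSUMES that
# utils/scheduler.FlowMatchScheduler follows the DiffSynth-Studio flow-matching
# scheduler: with extra_one_step=True the sigmas are the first num_steps points
# of linspace(1, sigma_min, num_steps + 1), shifted with
#     sigma' = shift * sigma / (1 + (shift - 1) * sigma),
# timesteps = 1000 * sigma', and step() is an Euler update
#     x_next = x + v * (sigma_next - sigma), with sigma_next = 0 after the last step.
def flow_match_sigmas_sketch(num_steps=48, shift=5.0, sigma_min=0.0):
    sigmas = []
    for i in range(num_steps):  # drops the final point of the (num_steps + 1)-point linspace
        s = 1.0 + (sigma_min - 1.0) * i / num_steps
        sigmas.append(shift * s / (1 + (shift - 1) * s))
    return sigmas


def euler_step_sketch(x, v, sigma, sigma_next):
    return [xi + vi * (sigma_next - sigma) for xi, vi in zip(x, v)]

# Example: with x = (1 - sigma) * x0 + sigma * eps and v = eps - x0, a single
# step from sigma = 1 to 0 recovers x0 (x0 = 2, eps = 5):
#   euler_step_sketch([5.0], [3.0], 1.0, 0.0) -> [2.0]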


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--rawdata_path", type=str, required=True)
    parser.add_argument("--generator_ckpt", type=str, required=True)
    parser.add_argument("--guidance_scale", type=float, default=6.0)


    args = parser.parse_args()

    launch_distributed_job()
    global_rank = dist.get_rank()

    device = torch.cuda.current_device()

    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model, encoder, scheduler, unconditional_dict = init_model(device=device)
    state_dict = torch.load(args.generator_ckpt, map_location="cpu")
        
    gen_sd = state_dict["generator"]
    fixed = {}
    for k, v in gen_sd.items():
        if k.startswith("model._fsdp_wrapped_module."):
            k = k.replace("model._fsdp_wrapped_module.", "", 1)
        if k.startswith("model."):
            k = k.replace("model.", "", 1)
        fixed[k] = v
    state_dict = fixed
    model.model.load_state_dict(
        state_dict, strict=True
    )



    dataset = LatentLMDBDataset(args.rawdata_path)

    if global_rank == 0:
        os.makedirs(args.output_folder, exist_ok=True)
        
    total_steps = int(math.ceil(len(dataset) / dist.get_world_size()))
    for index in tqdm(
        range(total_steps), disable=(dist.get_rank() != 0),
    ):
        prompt_index = index * dist.get_world_size() + dist.get_rank()
        if prompt_index >= len(dataset):
            continue
        sample = dataset[prompt_index]
        prompt = sample["prompts"]
        clean_latent = sample["clean_latent"].to(device).unsqueeze(0)

        conditional_dict = encoder(
            text_prompts=[prompt]
        )

        # Start the ODE solve from pure Gaussian noise
        latents = torch.randn(
            [1, 21, 16, 60, 104], dtype=torch.float32, device=device
        )

        # Collects the latent before every scheduler step
        noisy_input = []

        for progress_id, t in enumerate(tqdm(scheduler.timesteps, disable=(dist.get_rank() != 0))):
            timestep = t * \
                torch.ones([1, 21], device=device, dtype=torch.float32)
            noisy_input.append(latents)
            f_cond, x0_pred_cond = model(
                latents, conditional_dict, timestep, clean_x = clean_latent
            )

            f_uncond, x0_pred_uncond = model(
                latents, unconditional_dict, timestep, clean_x = clean_latent
            )

            # Classifier-free guidance: f_uncond + w * (f_cond - f_uncond)
            flow_pred = f_uncond + args.guidance_scale * (
                f_cond - f_uncond
            )

            latents = scheduler.step(
                flow_pred.flatten(0, 1),
                timestep.flatten(0, 1),
                latents.flatten(0, 1)
            ).unflatten(dim=0, sizes=flow_pred.shape[:2])

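        noisy_input.append(latents)
        noisy_input.append(clean_latent)

        noisy_inputs = torch.stack(noisy_input, dim=1)

        # Keep the initial noise, three intermediate states, the final ODE
        # output (-2) and the clean latent (-1)
        noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -2, -1]]

        stored_data = noisy_inputs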

        torch.save(
            {prompt: stored_data.cpu().detach()},
            os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
        )

    dist.barrier()


if __name__ == "__main__":
    main()
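The script keeps only six of the recorded latents per prompt. To see which noise levels those six entries correspond to, here is a minimal CPU-only sketch that mirrors the bookkeeping with string labels instead of tensors. It assumes that `scheduler.timesteps` has 48 entries and that `utils.scheduler.FlowMatchScheduler` follows the DiffSynth-style formula (`linspace(1, 0, steps + 1)[:-1]` followed by the time shift `s*x / (1 + (s - 1)*x)`); `shifted_sigmas` and `label_trajectory` are hypothetical helpers, not part of the repository.

```python
from typing import List

NUM_STEPS = 48  # num_inference_steps in init_model
SHIFT = 5.0     # FlowMatchScheduler(shift=5.0, ...)
TRAJECTORY_INDICES = [0, 12, 24, 36, -2, -1]


def shifted_sigmas(num_steps: int, shift: float,
                   sigma_max: float = 1.0, sigma_min: float = 0.0) -> List[float]:
    # Assumed schedule: num_steps + 1 evenly spaced sigmas from sigma_max to
    # sigma_min with the last one dropped (extra_one_step=True), then the
    # flow-matching time shift  shift * x / (1 + (shift - 1) * x).
    n = num_steps + 1
    base = [sigma_max + (sigma_min - sigma_max) * k / (n - 1) for k in range(n)][:-1]
    return [shift * x / (1 + (shift - 1) * x) for x in base]


def label_trajectory(num_steps: int, shift: float) -> List[str]:
    # Mirror the list the script builds: one entry per step (the latent
    # *before* that step), then the final ODE output, then the clean latent.
    labels = [f"sigma={s:.4f}" for s in shifted_sigmas(num_steps, shift)]
    labels.append("ode_output")  # result of the last scheduler step
    labels.append("clean")       # clean_latent from the LMDB dataset
    return labels


trajectory = label_trajectory(NUM_STEPS, SHIFT)
print(len(trajectory))
print([trajectory[i] for i in TRAJECTORY_INDICES])
```

Under these assumptions the three stored intermediate states sit at sigma 0.9375, 0.8333 and 0.6250, so the time shift places all of them in the high-noise half of the trajectory.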


================================================
FILE: get_causal_ode_data_kv_optimized.py
================================================
import argparse
import math
import os

import torch
import torch.distributed as dist
from tqdm import tqdm

from utils.dataset import LatentLMDBDataset
from utils.distributed import launch_distributed_job
from utils.ode_generation import (
    CausalODETrajectoryGenerator,
    merge_cfg_prompt_embeds,
)
from utils.scheduler import FlowMatchScheduler
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder


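# Wan's default negative prompt in Chinese; it discourages e.g. overexposure, static or
# blurry frames, subtitles, low/JPEG-artifact quality, malformed hands, faces and limbs,
# cluttered backgrounds, and people walking backwards.
DEFAULT_NEGATIVE_PROMPT = (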
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
    "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
    "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
    "杂乱的背景,三条腿,背景人很多,倒着走"
)
DEFAULT_NUM_INFERENCE_STEPS = 48
DEFAULT_SCHEDULER_SHIFT = 5.0
DEFAULT_TARGET_NUM_FRAMES = 21
DEFAULT_TRAJECTORY_INDICES = [0, 12, 24, 36, -2, -1]


def normalize_generator_state_dict(state_dict: dict) -> dict:
    if "generator" in state_dict:
        state_dict = state_dict["generator"]
    elif "generator_ema" in state_dict:
        state_dict = state_dict["generator_ema"]

    fixed = {}
    for k, v in state_dict.items():
        if k.startswith("model._fsdp_wrapped_module."):
            k = k.replace("model._fsdp_wrapped_module.", "", 1)
        if k.startswith("model."):
            k = k.replace("model.", "", 1)
        fixed[k] = v
    return fixed


def init_model(
    device,
    num_frame_per_block: int,
    scheduler_shift: float,
    num_inference_steps: int,
    negative_prompt: str,
):
    model = WanDiffusionWrapper(is_causal=True).to(device).to(torch.float32)
    model.model.num_frame_per_block = num_frame_per_block
    encoder = WanTextEncoder().to(device).to(torch.float32)

    scheduler = FlowMatchScheduler(
        shift=scheduler_shift,
        sigma_min=0.0,
        extra_one_step=True,
    )
    scheduler.set_timesteps(
        num_inference_steps=num_inference_steps,
        denoising_strength=1.0,
    )
    scheduler.sigmas = scheduler.sigmas.to(device)

    unconditional_dict = encoder(text_prompts=[negative_prompt])
    return model, encoder, scheduler, unconditional_dict


def prepare_clean_latent(
    sample: dict,
    target_num_frames: int | None,
    device,
) -> torch.Tensor:
    clean_latent = sample["clean_latent"].to(device).unsqueeze(0)
    if target_num_frames is None:
        return clean_latent

    if clean_latent.shape[1] < target_num_frames:
        raise ValueError(
            "clean_latent has fewer frames than requested: "
            f"{clean_latent.shape[1]} < {target_num_frames}"
        )
    if clean_latent.shape[1] != target_num_frames:
        clean_latent = clean_latent[:, :target_num_frames, ...]
    return clean_latent.contiguous()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--rawdata_path", type=str, required=True)
    parser.add_argument("--generator_ckpt", type=str, required=True)
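    parser.add_argument(
        "--num_frames_per_chunk",
        dest="num_frame_per_block",  # the rest of the script reads args.num_frame_per_block
        type=int,
        required=True,
        help="Number of latent frames per causal block (1 = frame-wise).",
    )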
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    parser.add_argument(
        "--generation_mode",
        type=str,
        default="full",
        choices=["full", "blockwise_kv"],
    )

    args = parser.parse_args()

    launch_distributed_job()
    global_rank = dist.get_rank()

    device = torch.cuda.current_device()

    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model, encoder, scheduler, unconditional_dict = init_model(
        device=device,
        num_frame_per_block=args.num_frame_per_block,
        scheduler_shift=DEFAULT_SCHEDULER_SHIFT,
        num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    )
    state_dict = torch.load(args.generator_ckpt, map_location="cpu")
    model.model.load_state_dict(
        normalize_generator_state_dict(state_dict),
        strict=True,
    )

    dataset = LatentLMDBDataset(
        args.rawdata_path,
        max_pair=int(1e8),
    )

    if global_rank == 0:
        os.makedirs(args.output_folder, exist_ok=True)

    trajectory_generator = CausalODETrajectoryGenerator(
        model=model,
        scheduler=scheduler,
        num_frame_per_block=args.num_frame_per_block,
        num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
        guidance_scale=args.guidance_scale,
    )

    total_steps = int(math.ceil(len(dataset) / dist.get_world_size()))
    for index in tqdm(
        range(total_steps), disable=(global_rank != 0),
    ):
        prompt_index = index * dist.get_world_size() + global_rank
        if prompt_index >= len(dataset):
            continue

        output_path = os.path.join(args.output_folder, f"{prompt_index:05d}.pt")

        sample = dataset[prompt_index]
        prompt = sample["prompts"]
        clean_latent = prepare_clean_latent(
            sample=sample,
            target_num_frames=DEFAULT_TARGET_NUM_FRAMES,
            device=device,
        )

        conditional_dict = encoder(text_prompts=[prompt])
        paired_conditional_dict = merge_cfg_prompt_embeds(
            conditional_dict=conditional_dict,
            unconditional_dict=unconditional_dict,
        )
        initial_noise = torch.randn_like(clean_latent, dtype=torch.float32)
        stored_data = trajectory_generator.generate(
            clean_latent=clean_latent,
            paired_conditional_dict=paired_conditional_dict,
            trajectory_indices=DEFAULT_TRAJECTORY_INDICES,
            generation_mode=args.generation_mode,
            initial_noise=initial_noise,
        )

        torch.save(
            {prompt: stored_data.cpu().detach()},
            output_path,
        )

    dist.barrier()


if __name__ == "__main__":
    main()
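Because `normalize_generator_state_dict` checks for `"generator"` first, a checkpoint that contains both raw and EMA weights is always loaded with the raw ones. The toy example below copies the function's logic onto plain dicts (no torch needed) to make that precedence and the prefix stripping visible; the key names in `ckpt` are made up for illustration.

```python
from typing import Any, Dict


def normalize_generator_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Same logic as the script: prefer "generator" over "generator_ema",
    # then strip an FSDP wrapper prefix and a leading "model." prefix.
    if "generator" in state_dict:
        state_dict = state_dict["generator"]
    elif "generator_ema" in state_dict:
        state_dict = state_dict["generator_ema"]

    fixed = {}
    for k, v in state_dict.items():
        if k.startswith("model._fsdp_wrapped_module."):
            k = k.replace("model._fsdp_wrapped_module.", "", 1)
        if k.startswith("model."):
            k = k.replace("model.", "", 1)
        fixed[k] = v
    return fixed


# Hypothetical checkpoint holding both raw and EMA weights.
ckpt = {
    "generator": {"model._fsdp_wrapped_module.blocks.0.weight": 1,
                  "model.head.bias": 2},
    "generator_ema": {"model.blocks.0.weight": 3},
}
print(normalize_generator_state_dict(ckpt))
# {'blocks.0.weight': 1, 'head.bias': 2}  <- EMA entry ignored
```

To generate ODE pairs from EMA weights with this script, one would pass a checkpoint that contains only `generator_ema` (or change the order of the two checks).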


================================================
FILE: inference.py
================================================
import argparse
import torch
import os
from omegaconf import OmegaConf
from tqdm import tqdm
from torchvision import transforms
from torchvision.io import write_video
from einops import rearrange
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import json

from pipeline import (
    CausalDiffusionInferencePipeline,
    CausalInferencePipeline,
)
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed

from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to the config file")
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21, help="Number of latent frames to generate")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
parser.add_argument("--report_timing", action="store_true",
                    help="Report first-chunk latency and FPS (the first prompt is treated as warm-up). Only tested on A800 for the Causal Forcing++ latency numbers; we make no claims for other hardware such as H100. For H100 results, refer to those reported in the Self Forcing paper.")
args = parser.parse_args()

# Initialize distributed inference
if "LOCAL_RANK" in os.environ:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()
    
else:
    device = torch.device("cuda")
    local_rank = 0
    world_size = 1

set_seed(args.seed)

print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40

torch.set_grad_enabled(False)

config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)

# Initialize pipeline
if hasattr(config, 'denoising_step_list'):
    # Few-step inference
    pipeline = CausalInferencePipeline(config, device=device)
else:
    # Multi-step diffusion inference
    pipeline = CausalDiffusionInferencePipeline(config, device=device)

if args.checkpoint_path:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    key = 'generator_ema' if args.use_ema else 'generator'
    gen_sd = state_dict[key]

    try:
        pipeline.generator.load_state_dict(gen_sd)
    except RuntimeError:
        fixed = {}
        for k, v in gen_sd.items():
            if k.startswith("model._fsdp_wrapped_module."):
                k = k.replace("model._fsdp_wrapped_module.", "model.", 1)
            fixed[k] = v
        pipeline.generator.load_state_dict(fixed, strict=False)

pipeline = pipeline.to(dtype=torch.bfloat16)
if low_memory:
    DynamicSwapInstaller.install_model(pipeline.text_encoder, device=gpu)
else:
    pipeline.text_encoder.to(device=gpu)
pipeline.generator.to(device=gpu)
pipeline.vae.to(device=gpu)


# Create dataset
if args.i2v:
    assert not dist.is_initialized(), "I2V does not support distributed inference yet"
    transform = transforms.Compose([
        transforms.Resize((480, 832)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
    dataset = TextDataset(prompt_path=args.data_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")

if args.report_timing and num_prompts < 2:
    print(f"[WARN] --report_timing requires at least 2 prompts "
          f"(got {num_prompts}); timing disabled.")
    args.report_timing = False

if dist.is_initialized():
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
    sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)

# Create output directory (only on main process to avoid race conditions)
if local_rank == 0:
    os.makedirs(args.output_folder, exist_ok=True)

if dist.is_initialized():
    dist.barrier()

def encode(self, videos: torch.Tensor) -> torch.Tensor:
    device, dtype = videos[0].device, videos[0].dtype
    scale = [self.mean.to(device=device, dtype=dtype),
             1.0 / self.std.to(device=device, dtype=dtype)]
    output = [
        self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
        for u in videos
    ]

    output = torch.stack(output, dim=0)
    return output


for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
    if isinstance(batch_data, dict):
        batch = batch_data
    elif isinstance(batch_data, list):
        batch = batch_data[0]  # First (and only) item in the batch
    else:
        raise TypeError(f"Unexpected batch type: {type(batch_data)}")
    idx = batch['idx'].item()

    all_video = []
    num_generated_frames = 0  # Number of generated (latent) frames

    if args.i2v:
        assert config.num_frame_per_block == 1, "I2V currently supports only the frame-wise model."
        # For image-to-video, batch contains image and caption
        prompt = batch['prompts'][0]  # Get caption from batch
        output_path = os.path.join(args.output_folder, f'{prompt[:100]}.mp4')
        if os.path.exists(output_path):
            print('Video has been generated. Pass!')
            continue
        # Reshape the image to [B=1, C, T=1, H, W] and move it to the GPU
        image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)

        # Encode the input image as the first latent frame
        initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
        prompts = [prompt]
        # The image supplies the first latent frame; sample noise for the remaining ones
        sampled_noise = torch.randn(
            [1, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
        )
    else:
        # For text-to-video, batch is just the text prompt
        prompt = batch['prompts'][0]
        output_path = os.path.join(args.output_folder, f'{prompt[:100]}.mp4')
        if os.path.exists(output_path):
            print('Video has been generated. Pass!')
            continue
        extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
        if extended_prompt is not None:
            prompts = [extended_prompt] 
        else:
            prompts = [prompt] 

        initial_latent = None
        sampled_noise = torch.randn(
            [1, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
        )

    # Skip timing for the first sample, which serves as warm-up
    sample_report_timing = args.report_timing and i >= 1
    video, latents = pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts,
        return_latents=True,
        initial_latent=initial_latent,
        report_timing=sample_report_timing,
    )
    if sample_report_timing:
        latency = pipeline.first_chunk_time
        elapsed = pipeline.last_generation_time
        num_pixel_frames = video.shape[1]
        fps = num_pixel_frames / elapsed if elapsed > 0 else float('inf')
        print(f"[Sample {i}] {num_pixel_frames} frames, "
              f"latency ↓ {latency:.2f}s, FPS ↑ {fps:.2f}")
        # Only tested on A800, for the Causal Forcing++ paper's latency and throughput numbers.
        # We make no claims for other hardware such as H100; for H100 results, refer to those
        # reported in the Self Forcing paper. We do not guarantee that our FPS/latency
        # measurement protocol is identical to the one used in the Self Forcing paper.
    current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
    all_video.append(current_video)
    num_generated_frames += latents.shape[1]

    # Final output video
    clean_latent = latents[0].cpu() 
    video = 255.0 * torch.cat(all_video, dim=1)

    # Clear VAE cache
    pipeline.vae.model.clear_cache()

    output_path = os.path.join(args.output_folder, f'{prompt[:100]}.mp4')
    write_video(output_path, video[0], fps=16)
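`inference.py` times every sample except the first, treating sample 0 as warm-up, and reports FPS as decoded frames divided by generation time. A minimal stdlib sketch of that protocol follows; `measure_fps` and `fake_generate` are hypothetical, with `fake_generate` standing in for `pipeline.inference`. With a real GPU pipeline one would also need `torch.cuda.synchronize()` before reading the timer, since CUDA kernels run asynchronously; whether the pipeline's own `last_generation_time` does this is not shown here.

```python
import time
from typing import Callable, List, Tuple


def measure_fps(generate: Callable[[int], int], num_samples: int) -> List[Tuple[int, float]]:
    """Return (sample_index, fps) for every sample after the warm-up one.

    `generate(i)` returns the number of decoded pixel frames for sample i.
    """
    results = []
    for i in range(num_samples):
        start = time.perf_counter()
        num_frames = generate(i)
        elapsed = time.perf_counter() - start
        if i == 0:
            continue  # warm-up sample: first-call overheads would skew the numbers
        fps = num_frames / elapsed if elapsed > 0 else float("inf")
        results.append((i, fps))
    return results


def fake_generate(i: int) -> int:
    time.sleep(0.01)  # pretend to render
    return 81         # assumed: 21 latent frames decoded to 81 pixel frames


for idx, fps in measure_fps(fake_generate, num_samples=3):
    print(f"[Sample {idx}] FPS {fps:.1f}")
```

As in the script, at least two samples are needed before any timing is reported.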

       


================================================
FILE: long_video/LICENSE
================================================
Tencent is pleased to support the community by making RollingForcing available.

Copyright (C)  2025 Tencent.  All rights reserved. 

The open-source software and/or models included in this distribution may have been modified by Tencent (“Tencent Modifications”). All Tencent Modifications are Copyright (C) Tencent.

RollingForcing is licensed under the License Terms of RollingForcing, except for the third-party components listed below, which remain licensed under their respective original terms. RollingForcing does not impose any additional restrictions beyond those specified in the original licenses of these third-party components. Users are required to comply with all applicable terms and conditions of the original licenses and to ensure that the use of these third-party components conforms to all relevant laws and regulations.

For the avoidance of doubt, RollingForcing refers solely to training code, inference code, parameters, and weights made publicly available by Tencent in accordance with the License Terms of RollingForcing. 

Terms of the License Terms of RollingForcing:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and /or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

- You agree to use RollingForcing only for academic purposes, and refrain from using it for any commercial or production purposes under any circumstances.

- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.



Dependencies and Licenses:

This open-source project, RollingForcing, builds upon the following open-source models and/or software components, each of which remains licensed under its original license. Certain models or software may include modifications made by Tencent (“Tencent Modifications”), which are Copyright (C) Tencent.

In case you believe there have been errors in the attribution below, you may submit the concerns to us for review and correction.

Open Source Model Licensed under the  Apache-2.0:
--------------------------------------------------------------------
1. Wan-AI/Wan2.1-T2V-1.3B
Copyright (c) 2025 Wan Team

Terms of the Apache-2.0:
--------------------------------------------------------------------
Apache License
                       Version 2.0, January 2004
                    http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

================================================
FILE: long_video/README.md
================================================
Building on [Rolling Forcing](https://github.com/TencentARC/RollingForcing), we implement minute-level long video generation.

### Installation

```bash
conda activate causal_forcing
pip install tensorboard opencv-python packaging
```

### CLI inference
```bash
hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download zhuhz22/Causal-Forcing chunkwise/longvideo.pt --local-dir ../checkpoints

python inference.py \
    --config_path configs/rolling_forcing_dmd.yaml \
    --output_folder videos/rolling_forcing_dmd \
    --checkpoint_path ../checkpoints/chunkwise/longvideo.pt \
    --data_path prompts/example_prompts.txt \
    --num_output_frames 252 \
    --use_ema
```
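To relate `--num_output_frames 252` to the "minute-level" claim: the flag counts latent frames. Assuming the Wan2.1 VAE's 4x temporal compression (the first latent frame decodes to one video frame, each later one to four) and the 16 fps output used by the top-level `inference.py`, the arithmetic works out as sketched below. The Gradio demo (`app.py`) additionally requires the latent frame count to be a multiple of 3 between 21 and 252. The helper names are illustrative only.

```python
def latent_to_pixel_frames(num_latent_frames: int, temporal_stride: int = 4) -> int:
    # Assumed Wan2.1 VAE convention: 1 pixel frame for the first latent frame,
    # `temporal_stride` pixel frames for each subsequent latent frame.
    if num_latent_frames < 1:
        raise ValueError("need at least one latent frame")
    return 1 + temporal_stride * (num_latent_frames - 1)


def duration_seconds(num_latent_frames: int, fps: float = 16.0) -> float:
    return latent_to_pixel_frames(num_latent_frames) / fps


def is_valid_demo_length(num_latent_frames: int) -> bool:
    # Mirrors the check in app.py's predict(): multiple of 3, within [21, 252].
    return num_latent_frames % 3 == 0 and 21 <= num_latent_frames <= 252


for n in (21, 252):
    print(n, latent_to_pixel_frames(n), f"{duration_seconds(n):.1f}s", is_valid_demo_length(n))
# 21 81 5.1s True
# 252 1005 62.8s True
```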

### Training
```bash
hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download Wan-AI/Wan2.1-T2V-14B --local-dir wan_models/Wan2.1-T2V-14B

torchrun --nproc_per_node=8 \
  --rdzv_backend=c10d \
  --rdzv_endpoint 127.0.0.1:29500 \
  train.py \
  -- \
  --config_path configs/rolling_forcing_dmd.yaml \
  --logdir logs/rolling_forcing_dmd
```
> We recommend training for 3000 steps.

### Acknowledgements
We adopt [Rolling Forcing](https://github.com/TencentARC/RollingForcing) as our long video generation framework and only modify the ODE initialization.


================================================
FILE: long_video/app.py
================================================
import os
import argparse
import time
from typing import Optional

import torch
from torchvision.io import write_video
from omegaconf import OmegaConf
from einops import rearrange
import gradio as gr

from pipeline import CausalDiffusionInferencePipeline, CausalInferencePipeline


# -----------------------------
# Globals (loaded once per process)
# -----------------------------
_PIPELINE: Optional[torch.nn.Module] = None
_DEVICE: Optional[torch.device] = None


def _ensure_gpu():
    if not torch.cuda.is_available():
        raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
    # Bind to GPU:0 by default
    torch.cuda.set_device(0)


def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
    global _PIPELINE, _DEVICE
    if _PIPELINE is not None:
        return _PIPELINE

    _ensure_gpu()
    _DEVICE = torch.device("cuda:0")

    # Check for the Wan base model before building the pipeline, so a missing
    # download produces a friendly error instead of failing inside the constructor
    wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
    if not os.path.isdir(wan_dir):
        raise gr.Error(
            "Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
            "Please download it first, e.g.:\n"
            "huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
        )

    # Load and merge configs
    config = OmegaConf.load(config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    # This demo always uses the few-step causal (rolling forcing) pipeline
    pipeline = CausalInferencePipeline(config, device=_DEVICE)

    # Load checkpoint if provided
    if checkpoint_path and os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if use_ema and 'generator_ema' in state_dict:
            state_dict_to_load = state_dict['generator_ema']
            # Remove possible FSDP prefix
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict_to_load.items():
                new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
            state_dict_to_load = new_state_dict
        else:
            state_dict_to_load = state_dict.get('generator', state_dict)
        pipeline.generator.load_state_dict(state_dict_to_load, strict=False)

    # The codebase assumes bfloat16 on GPU
    pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
    pipeline.eval()

    _PIPELINE = pipeline
    return _PIPELINE


def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
    os.makedirs(output_dir, exist_ok=True)

    def predict(prompt: str, num_frames: int) -> str:
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a non-empty text prompt.")

        num_frames = int(num_frames)
        if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
            raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")

        pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)

        # Prepare inputs
        prompts = [prompt.strip()]
        noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)

        # inference_mode already disables autograd; torch.autocast replaces the deprecated torch.cuda.amp.autocast
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            video = pipeline.inference_rolling_forcing(
                noise=noise,
                text_prompts=prompts,
                return_latents=False,
                initial_latent=None,
            )

        # video: [B=1, T, C, H, W] in [0,1]
        video = rearrange(video, 'b t c h w -> b t h w c')[0]
        video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()

        # Save to a unique filepath
        safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
        ts = int(time.time())
        filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
        write_video(filepath, video_uint8, fps=16)
        print(f"Saved generated video to {filepath}")

        return filepath

    return predict


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
                        help='Path to the model config')
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
                        help='Path to the Rolling Forcing checkpoint (.pt). If the file does not exist, only the base weights are used.')
    parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
    parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
    parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Gradio server host')
    parser.add_argument('--server_port', type=int, default=7860, help='Gradio server port')
    args = parser.parse_args()

    predict = build_predict(
        config_path=args.config_path,
        checkpoint_path=args.checkpoint_path,
        output_dir=args.output_dir,
        use_ema=not args.no_ema,
    )

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
            gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
        ],
        outputs=gr.Video(label="Generated Video", format="mp4"),
        title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
        description=(
            "Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
            "**Note:** although Rolling Forcing generates videos autoregressively, the current Gradio demo does not support streaming outputs, so the entire video is generated before it is displayed.\n"
            "\n"
            "If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing). Your support is crucial for sustaining this open-source project. "
            "You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
        ),
        allow_flagging='never',
    )

    try:
        # Gradio <= 3.x
        demo.queue(concurrency_count=1, max_size=2)
    except TypeError:
        # Gradio >= 4.x
        demo.queue(max_size=2)
    demo.launch(server_name=args.server_name, server_port=args.server_port, show_error=True)


if __name__ == "__main__":
    main()


================================================
FILE: long_video/configs/default_config.yaml
================================================
independent_first_frame: false
warp_denoising_step: false
weight_decay: 0.01
same_step_across_blocks: true
discriminator_lr_multiplier: 1.0
last_step_only: false
i2v: false
num_training_frames: 27
gc_interval: 100
context_noise: 0
causal: true

ckpt_step: 0
prompt_name: MovieGenVideoBench
prompt_path: prompts/MovieGenVideoBench.txt
eval_first_n: 64
num_samples: 1
height: 480
width: 832
num_frames: 81

================================================
FILE: long_video/configs/rolling_forcing_dmd.yaml
================================================
generator_ckpt: ../checkpoints/chunkwise/causal_ode.pt
generator_fsdp_wrap_strategy: size
real_score_fsdp_wrap_strategy: size
fake_score_fsdp_wrap_strategy: size
real_name: Wan2.1-T2V-14B
text_encoder_fsdp_wrap_strategy: size
denoising_step_list:
- 1000
- 750
- 500
- 250
warp_denoising_step: true # if true, denoising_step_list must not contain 0
ts_schedule: false
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
denoising_loss_type: flow
mixed_precision: true
seed: 0
sharding_strategy: hybrid_full
lr: 1.5e-06
lr_critic: 4.0e-07
beta1: 0.0
beta2: 0.999
beta1_critic: 0.0
beta2_critic: 0.999
data_path: ../prompts/vidprom_filtered_extended.txt 
batch_size: 1
ema_weight: 0.99
ema_start_step: 200
total_batch_size: 64
log_iters: 1000
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
dfake_gen_update_ratio: 5
image_or_video_shape:
- 1
- 21
- 16
- 60
- 104
distribution_loss: dmd
trainer: score_distillation
gradient_checkpointing: true
num_frame_per_block: 3
load_raw_video: false
model_kwargs:
  timestep_shift: 5.0

================================================
FILE: long_video/inference.py
================================================
import argparse
import torch
import os
from omegaconf import OmegaConf
from collections import OrderedDict
from tqdm import tqdm
from torchvision import transforms
from torchvision.io import write_video
from einops import rearrange
import torch.distributed as dist
import imageio
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pipeline import (
    CausalDiffusionInferencePipeline,
    CausalInferencePipeline
)
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to the config file")
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint file (.pt)")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21,
                    help="Number of latent frames to generate")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
parser.add_argument("--save_with_index", action="store_true",
                    help="Whether to save the video using the index or prompt as the filename")
args = parser.parse_args()

# Initialize distributed inference
if "LOCAL_RANK" in os.environ:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()
    set_seed(args.seed + local_rank)
else:
    device = torch.device("cuda")
    local_rank = 0
    world_size = 1
    set_seed(args.seed)

torch.set_grad_enabled(False)

config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)

# Initialize pipeline
if hasattr(config, 'denoising_step_list'):
    # Few-step inference
    pipeline = CausalInferencePipeline(config, device=device)
else:
    # Multi-step diffusion inference
    pipeline = CausalDiffusionInferencePipeline(config, device=device)

if args.checkpoint_path:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    if args.use_ema:
        state_dict_to_load = state_dict['generator_ema']
        def remove_fsdp_prefix(state_dict):
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                if "_fsdp_wrapped_module." in key:
                    new_key = key.replace("_fsdp_wrapped_module.", "")
                    new_state_dict[new_key] = value
                else:
                    new_state_dict[key] = value
            return new_state_dict
        state_dict_to_load = remove_fsdp_prefix(state_dict_to_load)
    else:
        state_dict_to_load = state_dict['generator']
    pipeline.generator.load_state_dict(state_dict_to_load)

pipeline = pipeline.to(device=device, dtype=torch.bfloat16)

# Create dataset
if args.i2v:
    assert not dist.is_initialized(), "I2V does not support distributed inference yet"
    transform = transforms.Compose([
        transforms.Resize((480, 832)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
    dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")

if dist.is_initialized():
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
    sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)

# Create output directory (only on main process to avoid race conditions)
if local_rank == 0:
    os.makedirs(args.output_folder, exist_ok=True)

if dist.is_initialized():
    dist.barrier()


def encode(self, videos: torch.Tensor) -> torch.Tensor:
    device, dtype = videos[0].device, videos[0].dtype
    scale = [self.mean.to(device=device, dtype=dtype),
             1.0 / self.std.to(device=device, dtype=dtype)]
    output = [
        self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
        for u in videos
    ]

    output = torch.stack(output, dim=0)
    return output


for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
    idx = batch_data['idx'].item()

    # With batch_size=1 the DataLoader still wraps the single item in a batch container;
    # unpack it for convenience
    if isinstance(batch_data, dict):
        batch = batch_data
    elif isinstance(batch_data, list):
        batch = batch_data[0]  # First (and only) item in the batch

    all_video = []
    num_generated_frames = 0  # Number of generated (latent) frames

    if args.i2v:
        # For image-to-video, batch contains image and caption
        prompt = batch['prompts'][0]  # Get caption from batch
        prompts = [prompt] * args.num_samples

        # Process the image
        image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=
SYMBOL INDEX (1104 symbols across 103 files)

FILE: demo.py
  function initialize_vae_decoder (line 60) | def initialize_vae_decoder(use_taehv=False, use_trt=False):
  function tensor_to_base64_frame (line 162) | def tensor_to_base64_frame(frame_tensor):
  function frame_sender_worker (line 190) | def frame_sender_worker():
  function generate_video_stream (line 242) | def generate_video_stream(prompt, seed, enable_torch_compile=False, enab...
  function generate_mp4_from_images (line 509) | def generate_mp4_from_images(image_directory, output_video_path, fps=24):
  function calculate_sha256 (line 533) | def calculate_sha256(data):
  function handle_connect (line 543) | def handle_connect():
  function handle_disconnect (line 549) | def handle_disconnect():
  function handle_start_generation (line 554) | def handle_start_generation(data):
  function handle_stop_generation (line 597) | def handle_stop_generation():
  function index (line 614) | def index():
  function api_status (line 619) | def api_status():

FILE: demo_utils/memory.py
  class DynamicSwapInstaller (line 13) | class DynamicSwapInstaller:
    method _install_module (line 15) | def _install_module(module: torch.nn.Module, **kwargs):
    method _uninstall_module (line 43) | def _uninstall_module(module: torch.nn.Module):
    method install_model (line 49) | def install_model(model: torch.nn.Module, **kwargs):
    method uninstall_model (line 55) | def uninstall_model(model: torch.nn.Module):
  function fake_diffusers_current_device (line 61) | def fake_diffusers_current_device(model: torch.nn.Module, target_device:...
  function get_cuda_free_memory_gb (line 72) | def get_cuda_free_memory_gb(device=None):
  function move_model_to_device_with_memory_preservation (line 85) | def move_model_to_device_with_memory_preservation(model, target_device, ...
  function offload_model_from_device_for_memory_preservation (line 101) | def offload_model_from_device_for_memory_preservation(model, target_devi...
  function unload_complete_models (line 117) | def unload_complete_models(*args):
  function load_model_as_complete (line 127) | def load_model_as_complete(model, target_device, unload=True):

FILE: demo_utils/taehv.py
  function conv (line 16) | def conv(n_in, n_out, **kwargs):
  class Clamp (line 20) | class Clamp(nn.Module):
    method forward (line 21) | def forward(self, x):
  class MemBlock (line 25) | class MemBlock(nn.Module):
    method __init__ (line 26) | def __init__(self, n_in, n_out):
    method forward (line 33) | def forward(self, x, past):
  class TPool (line 37) | class TPool(nn.Module):
    method __init__ (line 38) | def __init__(self, n_f, stride):
    method forward (line 43) | def forward(self, x):
  class TGrow (line 48) | class TGrow(nn.Module):
    method __init__ (line 49) | def __init__(self, n_f, stride):
    method forward (line 54) | def forward(self, x):
  function apply_model_with_memblocks (line 60) | def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
  class TAEHV (line 159) | class TAEHV(nn.Module):
    method __init__ (line 163) | def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(...
    method patch_tgrow_layers (line 195) | def patch_tgrow_layers(self, sd):
    method encode_video (line 210) | def encode_video(self, x, parallel=True, show_progress_bar=True):
    method decode_video (line 222) | def decode_video(self, x, parallel=True, show_progress_bar=False):
    method forward (line 236) | def forward(self, x):
  function main (line 241) | def main():

FILE: demo_utils/utils.py
  function min_resize (line 19) | def min_resize(x, m):
  function d_resize (line 36) | def d_resize(x, y):
  function resize_and_center_crop (line 48) | def resize_and_center_crop(image, target_width, target_height):
  function resize_and_center_crop_pytorch (line 66) | def resize_and_center_crop_pytorch(image, target_width, target_height):
  function resize_without_crop (line 85) | def resize_without_crop(image, target_width, target_height):
  function just_crop (line 94) | def just_crop(image, w, h):
  function write_to_json (line 108) | def write_to_json(data, file_path):
  function read_from_json (line 116) | def read_from_json(file_path):
  function get_active_parameters (line 122) | def get_active_parameters(m):
  function cast_training_params (line 126) | def cast_training_params(m, dtype=torch.float32):
  function separate_lora_AB (line 135) | def separate_lora_AB(parameters, B_patterns=None):
  function set_attr_recursive (line 151) | def set_attr_recursive(obj, attr, value):
  function print_tensor_list_size (line 159) | def print_tensor_list_size(tensors):
  function batch_mixture (line 180) | def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
  function zero_module (line 196) | def zero_module(module):
  function supress_lower_channels (line 203) | def supress_lower_channels(m, k, alpha=0.01):
  function freeze_module (line 213) | def freeze_module(m):
  function get_latest_safetensors (line 221) | def get_latest_safetensors(folder_path):
  function generate_random_prompt_from_tags (line 232) | def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=...
  function interpolate_numbers (line 239) | def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
  function uniform_random_by_intervals (line 246) | def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=Fa...
  function soft_append_bcthw (line 255) | def soft_append_bcthw(history, current, overlap=0):
  function save_bcthw_as_mp4 (line 269) | def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
  function save_bcthw_as_png (line 286) | def save_bcthw_as_png(x, output_filename):
  function save_bchw_as_png (line 295) | def save_bchw_as_png(x, output_filename):
  function add_tensors_with_padding (line 304) | def add_tensors_with_padding(tensor1, tensor2):
  function print_free_mem (line 323) | def print_free_mem():
  function print_gpu_parameters (line 333) | def print_gpu_parameters(device, state_dict, log_count=1):
  function visualize_txt_as_img (line 348) | def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans...
  function blue_mark (line 386) | def blue_mark(x):
  function green_mark (line 394) | def green_mark(x):
  function frame_mark (line 401) | def frame_mark(x):
  function pytorch2numpy (line 411) | def pytorch2numpy(imgs):
  function numpy2pytorch (line 422) | def numpy2pytorch(imgs):
  function duplicate_prefix_to_suffix (line 429) | def duplicate_prefix_to_suffix(x, count, zero_out=False):
  function weighted_mse (line 436) | def weighted_mse(a, b, weight):
  function clamped_linear_interpolation (line 440) | def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
  function expand_to_dims (line 447) | def expand_to_dims(x, target_dims):
  function repeat_to_batch_size (line 451) | def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
  function dim5 (line 468) | def dim5(x):
  function dim4 (line 472) | def dim4(x):
  function dim3 (line 476) | def dim3(x):
  function crop_or_pad_yield_mask (line 480) | def crop_or_pad_yield_mask(x, length):
  function extend_dim (line 495) | def extend_dim(x, dim, minimal_length, zero_pad=False):
  function lazy_positional_encoding (line 513) | def lazy_positional_encoding(t, repeats=None):
  function state_dict_offset_merge (line 530) | def state_dict_offset_merge(A, B, C=None):
  function state_dict_weighted_merge (line 547) | def state_dict_weighted_merge(state_dicts, weights):
  function group_files_by_folder (line 574) | def group_files_by_folder(all_files):
  function generate_timestamp (line 587) | def generate_timestamp():
  function write_PIL_image_with_png_info (line 595) | def write_PIL_image_with_png_info(image, metadata, path):
  function torch_safe_save (line 606) | def torch_safe_save(content, path):
  function move_optimizer_to_device (line 612) | def move_optimizer_to_device(optimizer, device):

FILE: demo_utils/vae.py
  class ResidualBlock (line 13) | class ResidualBlock(nn.Module):
    method __init__ (line 15) | def __init__(self, in_dim, out_dim, dropout=0.0):
    method forward (line 29) | def forward(self, x, feat_cache_1, feat_cache_2):
  class Resample (line 51) | class Resample(nn.Module):
    method __init__ (line 53) | def __init__(self, dim, mode):
    method forward (line 73) | def forward(self, x, is_first_frame, feat_cache):
    method temporal_conv (line 105) | def temporal_conv(self, x, is_first_frame, feat_cache):
    method init_weight (line 127) | def init_weight(self, conv):
    method init_weight2 (line 139) | def init_weight2(self, conv):
  class VAEDecoderWrapperSingle (line 151) | class VAEDecoderWrapperSingle(nn.Module):
    method __init__ (line 152) | def __init__(self):
    method forward (line 168) | def forward(
  class VAEDecoder3d (line 199) | class VAEDecoder3d(nn.Module):
    method __init__ (line 200) | def __init__(self,
    method forward (line 254) | def forward(
  class VAETRTWrapper (line 318) | class VAETRTWrapper():
    method __init__ (line 319) | def __init__(self):
    method quantize_if_needed (line 376) | def quantize_if_needed(self, t, expected_dtype, scale):
    method forward (line 381) | def forward(self, *test_inputs):

FILE: demo_utils/vae_block3.py
  class Resample (line 9) | class Resample(nn.Module):
    method __init__ (line 11) | def __init__(self, dim, mode):
    method forward (line 45) | def forward(self, x, feat_cache=None, feat_idx=[0]):
    method init_weight (line 106) | def init_weight(self, conv):
    method init_weight2 (line 118) | def init_weight2(self, conv):
  class VAEDecoderWrapper (line 130) | class VAEDecoderWrapper(nn.Module):
    method __init__ (line 131) | def __init__(self):
    method forward (line 147) | def forward(
  class VAEDecoder3d (line 187) | class VAEDecoder3d(nn.Module):
    method __init__ (line 188) | def __init__(self,
    method forward (line 242) | def forward(

FILE: demo_utils/vae_torch2trt.py
  function set_workspace (line 98) | def set_workspace(config, bytes_):
  function set_workspace (line 122) | def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
  class VAECalibrator (line 139) | class VAECalibrator(trt.IInt8EntropyCalibrator2):
    method __init__ (line 140) | def __init__(self, loader, cache="calibration.cache", max_batches=10):
    method get_batch_size (line 151) | def get_batch_size(self):
    method getBatchSize (line 154) | def getBatchSize(self):
    method get_batch (line 157) | def get_batch(self, names):
    method read_calibration_cache (line 202) | def read_calibration_cache(self):
    method readCalibrationCache (line 209) | def readCalibrationCache(self):
    method write_calibration_cache (line 212) | def write_calibration_cache(self, cache):
    method writeCalibrationCache (line 216) | def writeCalibrationCache(self, cache):

FILE: get_causal_ode_data_chunkwise.py
  function init_model (line 13) | def init_model(device):
  function main (line 32) | def main():

FILE: get_causal_ode_data_framewise.py
  function init_model (line 13) | def init_model(device):
  function main (line 32) | def main():

FILE: get_causal_ode_data_kv_optimized.py
  function normalize_generator_state_dict (line 31) | def normalize_generator_state_dict(state_dict: dict) -> dict:
  function init_model (line 47) | def init_model(
  function prepare_clean_latent (line 73) | def prepare_clean_latent(
  function main (line 92) | def main():

FILE: inference.py
  function encode (line 125) | def encode(self, videos: torch.Tensor) -> torch.Tensor:

FILE: long_video/app.py
  function _ensure_gpu (line 22) | def _ensure_gpu():
  function _load_pipeline (line 29) | def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use...
  function build_predict (line 78) | def build_predict(config_path: str, checkpoint_path: Optional[str], outp...
  function main (line 120) | def main():

FILE: long_video/inference.py
  function remove_fsdp_prefix (line 70) | def remove_fsdp_prefix(state_dict):
  function encode (line 114) | def encode(self, videos: torch.Tensor) -> torch.Tensor:

FILE: long_video/model/base.py
  class BaseModel (line 12) | class BaseModel(nn.Module):
    method __init__ (line 13) | def __init__(self, args, device):
    method _initialize_models (line 26) | def _initialize_models(self, args, device):
    method _get_timestep (line 48) | def _get_timestep(
  class RollingForcingModel (line 98) | class RollingForcingModel(BaseModel):
    method __init__ (line 99) | def __init__(self, args, device):
    method _run_generator (line 103) | def _run_generator(
    method _consistency_backward_simulation (line 182) | def _consistency_backward_simulation(
    method _initialize_inference_pipeline (line 214) | def _initialize_inference_pipeline(self):

FILE: long_video/model/causvid.py
  class CausVid (line 8) | class CausVid(BaseModel):
    method __init__ (line 9) | def __init__(self, args, device):
    method _compute_kl_grad (line 47) | def _compute_kl_grad(
    method compute_distribution_matching_loss (line 121) | def compute_distribution_matching_loss(
    method _run_generator (line 184) | def _run_generator(
    method generator_loss (line 255) | def generator_loss(
    method critic_loss (line 296) | def critic_loss(

FILE: long_video/model/diffusion.py
  class CausalDiffusion (line 8) | class CausalDiffusion(BaseModel):
    method __init__ (line 9) | def __init__(self, args, device):
    method _initialize_models (line 34) | def _initialize_models(self, args):
    method generator_loss (line 44) | def generator_loss(

FILE: long_video/model/dmd.py
  class DMD (line 9) | class DMD(RollingForcingModel):
    method __init__ (line 10) | def __init__(self, args, device):
    method _compute_kl_grad (line 54) | def _compute_kl_grad(
    method compute_distribution_matching_loss (line 128) | def compute_distribution_matching_loss(
    method generator_loss (line 196) | def generator_loss(
    method critic_loss (line 237) | def critic_loss(

FILE: long_video/model/gan.py
  class GAN (line 10) | class GAN(RollingForcingModel):
    method __init__ (line 11) | def __init__(self, args, device):
    method _run_cls_pred_branch (line 69) | def _run_cls_pred_branch(self,
    method generator_loss (line 90) | def generator_loss(
    method critic_loss (line 174) | def critic_loss(

FILE: long_video/model/ode_regression.py
  class ODERegression (line 9) | class ODERegression(BaseModel):
    method __init__ (line 10) | def __init__(self, args, device):
    method _initialize_models (line 45) | def _initialize_models(self, args, device):
    method _prepare_generator_input (line 59) | def _prepare_generator_input(self, ode_latent: torch.Tensor, tf=False,...
    method generator_loss (line 95) | def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: d...

FILE: long_video/model/sid.py
  class SiD (line 8) | class SiD(RollingForcingModel):
    method __init__ (line 9) | def __init__(self, args, device):
    method compute_distribution_matching_loss (line 47) | def compute_distribution_matching_loss(
    method generator_loss (line 147) | def generator_loss(
    method critic_loss (line 188) | def critic_loss(

FILE: long_video/pipeline/bidirectional_diffusion_inference.py
  class BidirectionalDiffusionInferencePipeline (line 10) | class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method inference (line 34) | def inference(
    method _initialize_sample_scheduler (line 89) | def _initialize_sample_scheduler(self, noise):

FILE: long_video/pipeline/bidirectional_inference.py
  class BidirectionalInferencePipeline (line 7) | class BidirectionalInferencePipeline(torch.nn.Module):
    method __init__ (line 8) | def __init__(
    method inference (line 33) | def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> t...

FILE: long_video/pipeline/causal_diffusion_inference.py
  class CausalDiffusionInferencePipeline (line 10) | class CausalDiffusionInferencePipeline(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method inference (line 49) | def inference(
    method _initialize_kv_cache (line 270) | def _initialize_kv_cache(self, batch_size, dtype, device):
    method _initialize_crossattn_cache (line 300) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
    method _initialize_sample_scheduler (line 321) | def _initialize_sample_scheduler(self, noise):

FILE: long_video/pipeline/rolling_forcing_inference.py
  class CausalInferencePipeline (line 7) | class CausalInferencePipeline(torch.nn.Module):
    method __init__ (line 8) | def __init__(
    method inference_rolling_forcing (line 45) | def inference_rolling_forcing(
    method _initialize_kv_cache (line 338) | def _initialize_kv_cache(self, batch_size, dtype, device):
    method _initialize_crossattn_cache (line 360) | def _initialize_crossattn_cache(self, batch_size, dtype, device):

FILE: long_video/pipeline/rolling_forcing_training.py
  class RollingForcingTrainingPipeline (line 8) | class RollingForcingTrainingPipeline:
    method __init__ (line 9) | def __init__(self,
    method generate_and_sync_list (line 41) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
    method generate_list (line 60) | def generate_list(self, num_blocks, num_denoising_steps, device):
    method inference_with_rolling_forcing (line 75) | def inference_with_rolling_forcing(
    method inference_with_self_forcing (line 256) | def inference_with_self_forcing(
    method _initialize_kv_cache (line 436) | def _initialize_kv_cache(self, batch_size, dtype, device):
    method _initialize_crossattn_cache (line 452) | def _initialize_crossattn_cache(self, batch_size, dtype, device):

FILE: long_video/train.py
  function main (line 8) | def main():

FILE: long_video/trainer/diffusion.py
  class Trainer (line 17) | class Trainer:
    method __init__ (line 18) | def __init__(self, config):
    method save (line 140) | def save(self):
    method train_one_step (line 163) | def train_one_step(self, batch):
    method generate_video (line 235) | def generate_video(self, pipeline, prompts, image=None):
    method train (line 248) | def train(self):

FILE: long_video/trainer/distillation.py
  class Trainer (line 20) | class Trainer:
    method __init__ (line 21) | def __init__(self, config):
    method save (line 176) | def save(self):
    method fwdbwd_one_step (line 203) | def fwdbwd_one_step(self, batch, train_generator):
    method generate_video (line 276) | def generate_video(self, pipeline, prompts, image=None):
    method train (line 306) | def train(self):

FILE: long_video/trainer/gan.py
  class Trainer (line 19) | class Trainer:
    method __init__ (line 20) | def __init__(self, config):
    method save (line 208) | def save(self):
    method fwdbwd_one_step (line 235) | def fwdbwd_one_step(self, batch, train_generator):
    method generate_video (line 324) | def generate_video(self, pipeline, prompts, image=None):
    method train (line 337) | def train(self):
    method all_gather_dict (line 457) | def all_gather_dict(self, target_dict):

FILE: long_video/trainer/ode.py
  class Trainer (line 19) | class Trainer:
    method __init__ (line 20) | def __init__(self, config):
    method save (line 118) | def save(self):
    method train_one_step (line 134) | def train_one_step(self):
    method train (line 225) | def train(self):

FILE: long_video/utils/dataset.py
  class TextDataset (line 12) | class TextDataset(Dataset):
    method __init__ (line 13) | def __init__(self, prompt_path, extended_prompt_path=None):
    method __len__ (line 24) | def __len__(self):
    method __getitem__ (line 27) | def __getitem__(self, idx):
  class ODERegressionLMDBDataset (line 37) | class ODERegressionLMDBDataset(Dataset):
    method __init__ (line 38) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
    method __len__ (line 45) | def __len__(self):
    method __getitem__ (line 48) | def __getitem__(self, idx):
  class ShardingLMDBDataset (line 72) | class ShardingLMDBDataset(Dataset):
    method __init__ (line 73) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
    method __len__ (line 96) | def __len__(self):
    method __getitem__ (line 99) | def __getitem__(self, idx):
  class TextImagePairDataset (line 127) | class TextImagePairDataset(Dataset):
    method __init__ (line 128) | def __init__(
    method __len__ (line 182) | def __len__(self):
    method __getitem__ (line 185) | def __getitem__(self, idx):
  function cycle (line 217) | def cycle(dl):

FILE: long_video/utils/distributed.py
  function fsdp_state_dict (line 11) | def fsdp_state_dict(model):
  function fsdp_wrap (line 23) | def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, w...
  function barrier (line 70) | def barrier():
  function launch_distributed_job (line 75) | def launch_distributed_job(backend: str = "nccl"):
  class EMA_FSDP (line 91) | class EMA_FSDP:
    method __init__ (line 92) | def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
    method _init_shadow (line 98) | def _init_shadow(self, fsdp_module):
    method update (line 105) | def update(self, fsdp_module):
    method state_dict (line 113) | def state_dict(self):
    method load_state_dict (line 116) | def load_state_dict(self, sd):
    method copy_to (line 119) | def copy_to(self, fsdp_module):

FILE: long_video/utils/lmdb.py
  function get_array_shape_from_lmdb (line 4) | def get_array_shape_from_lmdb(env, array_name):
  function store_arrays_to_lmdb (line 11) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
  function process_data_dict (line 30) | def process_data_dict(data_dict, seen_prompts):
  function retrieve_row_from_lmdb (line 56) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...

FILE: long_video/utils/loss.py
  class DenoisingLoss (line 5) | class DenoisingLoss(ABC):
    method __call__ (line 7) | def __call__(
  class X0PredLoss (line 27) | class X0PredLoss(DenoisingLoss):
    method __call__ (line 28) | def __call__(
  class VPredLoss (line 38) | class VPredLoss(DenoisingLoss):
    method __call__ (line 39) | def __call__(
  class NoisePredLoss (line 50) | class NoisePredLoss(DenoisingLoss):
    method __call__ (line 51) | def __call__(
  class FlowPredLoss (line 61) | class FlowPredLoss(DenoisingLoss):
    method __call__ (line 62) | def __call__(
  function get_denoising_loss (line 80) | def get_denoising_loss(loss_type: str) -> DenoisingLoss:

FILE: long_video/utils/misc.py
  function set_seed (line 6) | def set_seed(seed: int, deterministic: bool = False):
  function merge_dict_list (line 25) | def merge_dict_list(dict_list):

FILE: long_video/utils/scheduler.py
  class SchedulerInterface (line 5) | class SchedulerInterface(ABC):
    method add_noise (line 12) | def add_noise(
    method convert_x0_to_noise (line 26) | def convert_x0_to_noise(
    method convert_noise_to_x0 (line 52) | def convert_noise_to_x0(
    method convert_velocity_to_x0 (line 77) | def convert_velocity_to_x0(
  class FlowMatchScheduler (line 106) | class FlowMatchScheduler():
    method __init__ (line 108) | def __init__(self, num_inference_steps=100, num_train_timesteps=1000, ...
    method set_timesteps (line 118) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
    method step (line 143) | def step(self, model_output, timestep, sample, to_final=False):
    method add_noise (line 159) | def add_noise(self, original_samples, noise, timestep):
    method training_target (line 178) | def training_target(self, sample, noise, timestep):
    method training_weight (line 182) | def training_weight(self, timestep):

FILE: long_video/utils/wan_wrapper.py
  class WanTextEncoder (line 14) | class WanTextEncoder(torch.nn.Module):
    method __init__ (line 15) | def __init__(self) -> None:
    method device (line 33) | def device(self):
    method forward (line 37) | def forward(self, text_prompts: List[str]) -> dict:
  class WanVAEWrapper (line 53) | class WanVAEWrapper(torch.nn.Module):
    method __init__ (line 54) | def __init__(self):
    method encode_to_latent (line 73) | def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
    method decode_to_pixel (line 89) | def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = Fals...
  class WanDiffusionWrapper (line 115) | class WanDiffusionWrapper(torch.nn.Module):
    method __init__ (line 116) | def __init__(
    method enable_gradient_checkpointing (line 144) | def enable_gradient_checkpointing(self) -> None:
    method adding_cls_branch (line 147) | def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_di...
    method _convert_flow_pred_to_x0 (line 169) | def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch....
    method _convert_x0_to_flow_pred (line 196) | def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: tor...
    method forward (line 218) | def forward(
    method get_scheduler (line 293) | def get_scheduler(self) -> SchedulerInterface:
    method post_init (line 307) | def post_init(self):

FILE: long_video/wan/distributed/fsdp.py
  function shard_model (line 10) | def shard_model(

FILE: long_video/wan/distributed/xdit_context_parallel.py
  function pad_freqs (line 12) | def pad_freqs(original_tensor, target_len):
  function rope_apply (line 26) | def rope_apply(x, grid_sizes, freqs):
  function usp_dit_forward (line 66) | def usp_dit_forward(
  function usp_attn_forward (line 149) | def usp_attn_forward(self,

FILE: long_video/wan/image2video.py
  class WanI2V (line 29) | class WanI2V:
    method __init__ (line 31) | def __init__(
    method generate (line 129) | def generate(self,

FILE: long_video/wan/modules/attention.py
  function is_hopper_gpu (line 7) | def is_hopper_gpu():
  function flash_attention (line 32) | def flash_attention(
  function attention (line 139) | def attention(

FILE: long_video/wan/modules/causal_model.py
  function causal_rope_apply (line 27) | def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
  class CausalWanSelfAttention (line 58) | class CausalWanSelfAttention(nn.Module):
    method __init__ (line 60) | def __init__(self,
    method forward (line 87) | def forward(
  class CausalWanAttentionBlock (line 309) | class CausalWanAttentionBlock(nn.Module):
    method __init__ (line 311) | def __init__(self,
    method forward (line 349) | def forward(
  class CausalHead (line 405) | class CausalHead(nn.Module):
    method __init__ (line 407) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
    method forward (line 422) | def forward(self, x, e):
  class CausalWanModel (line 436) | class CausalWanModel(ModelMixin, ConfigMixin):
    method __init__ (line 448) | def __init__(self,
    method _set_gradient_checkpointing (line 569) | def _set_gradient_checkpointing(self, module, value=False):
    method _prepare_blockwise_causal_attn_mask (line 573) | def _prepare_blockwise_causal_attn_mask(
    method _prepare_teacher_forcing_mask (line 631) | def _prepare_teacher_forcing_mask(
    method _prepare_blockwise_causal_attn_mask_i2v (line 719) | def _prepare_blockwise_causal_attn_mask_i2v(
    method _forward_inference (line 779) | def _forward_inference(
    method _forward_train (line 912) | def _forward_train(
    method forward (line 1070) | def forward(
    method unpatchify (line 1080) | def unpatchify(self, x, grid_sizes):
    method init_weights (line 1105) | def init_weights(self):

FILE: long_video/wan/modules/clip.py
  function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
  class QuickGELU (line 41) | class QuickGELU(nn.Module):
    method forward (line 43) | def forward(self, x):
  class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
    method forward (line 49) | def forward(self, x):
  class SelfAttention (line 53) | class SelfAttention(nn.Module):
    method __init__ (line 55) | def __init__(self,
    method forward (line 74) | def forward(self, x):
  class SwiGLU (line 94) | class SwiGLU(nn.Module):
    method __init__ (line 96) | def __init__(self, dim, mid_dim):
    method forward (line 106) | def forward(self, x):
  class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
    method __init__ (line 114) | def __init__(self,
    method forward (line 146) | def forward(self, x):
  class AttentionPool (line 156) | class AttentionPool(nn.Module):
    method __init__ (line 158) | def __init__(self,
    method forward (line 186) | def forward(self, x):
  class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
    method __init__ (line 211) | def __init__(self,
    method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
  class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
    method __init__ (line 305) | def __init__(self, **kwargs):
    method forward (line 315) | def forward(self, ids):
  class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
    method __init__ (line 330) | def __init__(self,
    method forward (line 406) | def forward(self, imgs, txt_ids):
    method param_groups (line 418) | def param_groups(self):
  function _clip (line 434) | def _clip(pretrained=False,
  function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
  class CLIPModel (line 501) | class CLIPModel:
    method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
    method visual (line 527) | def visual(self, videos):

FILE: long_video/wan/modules/model.py
  function sinusoidal_embedding_1d (line 15) | def sinusoidal_embedding_1d(dim, position):
  function rope_params (line 29) | def rope_params(max_seq_len, dim, theta=10000):
  function rope_apply (line 40) | def rope_apply(x, grid_sizes, freqs):
  class WanRMSNorm (line 70) | class WanRMSNorm(nn.Module):
    method __init__ (line 72) | def __init__(self, dim, eps=1e-5):
    method forward (line 78) | def forward(self, x):
    method _norm (line 85) | def _norm(self, x):
  class WanLayerNorm (line 89) | class WanLayerNorm(nn.LayerNorm):
    method __init__ (line 91) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
    method forward (line 94) | def forward(self, x):
  class WanSelfAttention (line 102) | class WanSelfAttention(nn.Module):
    method __init__ (line 104) | def __init__(self,
    method forward (line 127) | def forward(self, x, seq_lens, grid_sizes, freqs):
  class WanT2VCrossAttention (line 159) | class WanT2VCrossAttention(WanSelfAttention):
    method forward (line 161) | def forward(self, x, context, context_lens, crossattn_cache=None):
  class WanGanCrossAttention (line 197) | class WanGanCrossAttention(WanSelfAttention):
    method forward (line 199) | def forward(self, x, context, crossattn_cache=None):
  class WanI2VCrossAttention (line 224) | class WanI2VCrossAttention(WanSelfAttention):
    method __init__ (line 226) | def __init__(self,
    method forward (line 240) | def forward(self, x, context, context_lens):
  class WanAttentionBlock (line 275) | class WanAttentionBlock(nn.Module):
    method __init__ (line 277) | def __init__(self,
    method forward (line 315) | def forward(
  class GanAttentionBlock (line 357) | class GanAttentionBlock(nn.Module):
    method __init__ (line 359) | def __init__(self,
    method forward (line 397) | def forward(
  class Head (line 439) | class Head(nn.Module):
    method __init__ (line 441) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
    method forward (line 456) | def forward(self, x, e):
  class MLPProj (line 469) | class MLPProj(torch.nn.Module):
    method __init__ (line 471) | def __init__(self, in_dim, out_dim):
    method forward (line 479) | def forward(self, image_embeds):
  class RegisterTokens (line 484) | class RegisterTokens(nn.Module):
    method __init__ (line 485) | def __init__(self, num_registers: int, dim: int):
    method forward (line 490) | def forward(self):
    method reset_parameters (line 493) | def reset_parameters(self):
  class WanModel (line 497) | class WanModel(ModelMixin, ConfigMixin):
    method __init__ (line 509) | def __init__(self,
    method _set_gradient_checkpointing (line 623) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 626) | def forward(
    method _forward (line 637) | def _forward(
    method _forward_classify (line 773) | def _forward_classify(
    method unpatchify (line 876) | def unpatchify(self, x, grid_sizes, c=None):
    method init_weights (line 901) | def init_weights(self):

FILE: long_video/wan/modules/t5.py
  function fp16_clamp (line 20) | def fp16_clamp(x):
  function init_weights (line 27) | def init_weights(m):
  class GELU (line 46) | class GELU(nn.Module):
    method forward (line 48) | def forward(self, x):
  class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
    method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
    method forward (line 61) | def forward(self, x):
  class T5Attention (line 69) | class T5Attention(nn.Module):
    method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
    method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
  class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
    method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
    method forward (line 136) | def forward(self, x):
  class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
    method __init__ (line 146) | def __init__(self,
    method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
  class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
    method __init__ (line 180) | def __init__(self,
    method forward (line 206) | def forward(self,
  class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
    method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
    method forward (line 233) | def forward(self, lq, lk):
    method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
  class T5Encoder (line 267) | class T5Encoder(nn.Module):
    method __init__ (line 269) | def __init__(self,
    method forward (line 303) | def forward(self, ids, mask=None):
  class T5Decoder (line 315) | class T5Decoder(nn.Module):
    method __init__ (line 317) | def __init__(self,
    method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
  class T5Model (line 372) | class T5Model(nn.Module):
    method __init__ (line 374) | def __init__(self,
    method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
  function _t5 (line 415) | def _t5(name,
  function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
  class T5EncoderModel (line 472) | class T5EncoderModel:
    method __init__ (line 474) | def __init__(
    method __call__ (line 506) | def __call__(self, texts, device):

FILE: long_video/wan/modules/tokenizers.py
  function basic_clean (line 12) | def basic_clean(text):
  function whitespace_clean (line 18) | def whitespace_clean(text):
  function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
  class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
    method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
    method __call__ (line 49) | def __call__(self, sequence, **kwargs):
    method _clean (line 75) | def _clean(self, text):

FILE: long_video/wan/modules/vae.py
  class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
    method __init__ (line 22) | def __init__(self, *args, **kwargs):
    method forward (line 28) | def forward(self, x, cache_x=None):
  class RMS_norm (line 39) | class RMS_norm(nn.Module):
    method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
    method forward (line 51) | def forward(self, x):
  class Upsample (line 57) | class Upsample(nn.Upsample):
    method forward (line 59) | def forward(self, x):
  class Resample (line 66) | class Resample(nn.Module):
    method __init__ (line 68) | def __init__(self, dim, mode):
    method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
    method init_weight (line 162) | def init_weight(self, conv):
    method init_weight2 (line 174) | def init_weight2(self, conv):
  class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
    method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
    method __init__ (line 228) | def __init__(self, dim):
    method forward (line 240) | def forward(self, x):
  class Encoder3d (line 265) | class Encoder3d(nn.Module):
    method __init__ (line 267) | def __init__(self,
    method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class Decoder3d (line 369) | class Decoder3d(nn.Module):
    method __init__ (line 371) | def __init__(self,
    method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  function count_conv3d (line 475) | def count_conv3d(model):
  class WanVAE_ (line 483) | class WanVAE_(nn.Module):
    method __init__ (line 485) | def __init__(self,
    method forward (line 511) | def forward(self, x):
    method encode (line 517) | def encode(self, x, scale):
    method decode (line 545) | def decode(self, z, scale):
    method cached_decode (line 571) | def cached_decode(self, z, scale):
    method sample (line 595) | def sample(self, imgs, deterministic=False):
    method clear_cache (line 602) | def clear_cache(self):
  function _video_vae (line 612) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
  class WanVAE (line 639) | class WanVAE:
    method __init__ (line 641) | def __init__(self,
    method encode (line 667) | def encode(self, videos):
    method decode (line 677) | def decode(self, zs):

FILE: long_video/wan/modules/xlm_roberta.py
  class SelfAttention (line 10) | class SelfAttention(nn.Module):
    method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
    method forward (line 27) | def forward(self, x, mask):
  class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
    method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
    method forward (line 66) | def forward(self, x, mask):
  class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
    method __init__ (line 81) | def __init__(self,
    method forward (line 118) | def forward(self, ids):
  function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,

FILE: long_video/wan/text2video.py
  class WanT2V (line 26) | class WanT2V:
    method __init__ (line 28) | def __init__(
    method generate (line 110) | def generate(self,

FILE: long_video/wan/utils/fm_solvers.py
  function get_sampling_sigmas (line 22) | def get_sampling_sigmas(sampling_steps, shift):
  function retrieve_timesteps (line 29) | def retrieve_timesteps(
  class FlowDPMSolverMultistepScheduler (line 69) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 129) | def __init__(
    method step_index (line 202) | def step_index(self):
    method begin_index (line 209) | def begin_index(self):
    method set_begin_index (line 216) | def set_begin_index(self, begin_index: int = 0):
    method set_timesteps (line 226) | def set_timesteps(
    method _threshold_sample (line 292) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    method _sigma_to_t (line 330) | def _sigma_to_t(self, sigma):
    method _sigma_to_alpha_sigma_t (line 333) | def _sigma_to_alpha_sigma_t(self, sigma):
    method time_shift (line 337) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    method convert_model_output (line 341) | def convert_model_output(
    method dpm_solver_first_order_update (line 415) | def dpm_solver_first_order_update(
    method multistep_dpm_solver_second_order_update (line 486) | def multistep_dpm_solver_second_order_update(
    method multistep_dpm_solver_third_order_update (line 596) | def multistep_dpm_solver_third_order_update(
    method index_for_timestep (line 679) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 693) | def _init_step_index(self, timestep):
    method step (line 706) | def step(
    method scale_model_input (line 800) | def scale_model_input(self, sample: torch.Tensor, *args,
    method add_noise (line 815) | def add_noise(
    method __len__ (line 856) | def __len__(self):

FILE: long_video/wan/utils/fm_solvers_unipc.py
  class FlowUniPCMultistepScheduler (line 20) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 77) | def __init__(
    method step_index (line 135) | def step_index(self):
    method begin_index (line 142) | def begin_index(self):
    method set_begin_index (line 149) | def set_begin_index(self, begin_index: int = 0):
    method set_timesteps (line 160) | def set_timesteps(
    method _threshold_sample (line 230) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    method _sigma_to_t (line 269) | def _sigma_to_t(self, sigma):
    method _sigma_to_alpha_sigma_t (line 272) | def _sigma_to_alpha_sigma_t(self, sigma):
    method time_shift (line 276) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    method convert_model_output (line 279) | def convert_model_output(
    method multistep_uni_p_bh_update (line 350) | def multistep_uni_p_bh_update(
    method multistep_uni_c_bh_update (line 486) | def multistep_uni_c_bh_update(
    method index_for_timestep (line 628) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 643) | def _init_step_index(self, timestep):
    method step (line 655) | def step(self,
    method scale_model_input (line 741) | def scale_model_input(self, sample: torch.Tensor, *args,
    method add_noise (line 758) | def add_noise(
    method __len__ (line 799) | def __len__(self):

FILE: long_video/wan/utils/prompt_extend.py
  class PromptOutput (line 101) | class PromptOutput(object):
    method add_custom_field (line 108) | def add_custom_field(self, key: str, value) -> None:
  class PromptExpander (line 112) | class PromptExpander:
    method __init__ (line 114) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
    method extend_with_img (line 119) | def extend_with_img(self,
    method extend (line 128) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method decide_system_prompt (line 131) | def decide_system_prompt(self, tar_lang="ch"):
    method __call__ (line 138) | def __call__(self,
  class DashScopePromptExpander (line 157) | class DashScopePromptExpander(PromptExpander):
    method __init__ (line 159) | def __init__(self,
    method extend (line 196) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method extend_with_img (line 232) | def extend_with_img(self,
  class QwenPromptExpander (line 300) | class QwenPromptExpander(PromptExpander):
    method __init__ (line 309) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
    method extend (line 366) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method extend_with_img (line 397) | def extend_with_img(self,

FILE: long_video/wan/utils/qwen_vl_utils.py
  function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
  function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
  function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
  function smart_resize (line 54) | def smart_resize(height: int,
  function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
  function smart_nframes (line 133) | def smart_nframes(
  function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
  function is_decord_available (line 215) | def is_decord_available() -> bool:
  function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
  function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
  function fetch_video (line 274) | def fetch_video(
  function extract_vision_info (line 328) | def extract_vision_info(
  function process_vision_info (line 344) | def process_vision_info(

FILE: long_video/wan/utils/utils.py
  function rand_name (line 14) | def rand_name(length=8, suffix=''):
  function cache_video (line 23) | def cache_video(tensor,
  function cache_image (line 64) | def cache_image(tensor,
  function str2bool (line 94) | def str2bool(v):

FILE: model/base.py
  class BaseModel (line 12) | class BaseModel(nn.Module):
    method __init__ (line 13) | def __init__(self, args, device):
    method _initialize_models (line 40) | def _initialize_models(self, args, device):
    method _get_timestep (line 63) | def _get_timestep(
  class SelfForcingModel (line 113) | class SelfForcingModel(BaseModel):
    method __init__ (line 114) | def __init__(self, args, device):
    method _run_generator (line 118) | def _run_generator(
    method _consistency_backward_simulation (line 205) | def _consistency_backward_simulation(
    method _initialize_inference_pipeline (line 229) | def _initialize_inference_pipeline(self):
  class TeacherForcingModel (line 251) | class TeacherForcingModel(BaseModel):
    method __init__ (line 252) | def __init__(self, args, device):
    method _run_generator (line 256) | def _run_generator(
    method _consistency_backward_simulation_tf (line 344) | def _consistency_backward_simulation_tf(
    method _initialize_inference_pipeline_tf (line 372) | def _initialize_inference_pipeline_tf(self):
  class BidirectionalModel (line 391) | class BidirectionalModel(BaseModel):
    method __init__ (line 392) | def __init__(self, args, device):
    method _run_generator (line 396) | def _run_generator(
    method _consistency_backward_simulation_bidirectional (line 476) | def _consistency_backward_simulation_bidirectional(
    method _initialize_inference_pipeline_bidirectional (line 502) | def _initialize_inference_pipeline_bidirectional(self):

FILE: model/causvid.py
  class CausVid (line 8) | class CausVid(BaseModel):
    method __init__ (line 9) | def __init__(self, args, device):
    method _compute_kl_grad (line 47) | def _compute_kl_grad(
    method compute_distribution_matching_loss (line 121) | def compute_distribution_matching_loss(
    method _run_generator (line 184) | def _run_generator(
    method generator_loss (line 256) | def generator_loss(
    method critic_loss (line 297) | def critic_loss(

FILE: model/diffusion.py
  class CausalDiffusion (line 8) | class CausalDiffusion(BaseModel):
    method __init__ (line 9) | def __init__(self, args, device):
    method _initialize_models (line 35) | def _initialize_models(self, args, device):
    method generator_loss (line 48) | def generator_loss(

FILE: model/dmd.py
  class DMD (line 9) | class DMD(SelfForcingModel):
    method __init__ (line 10) | def __init__(self, args, device):
    method _compute_kl_grad (line 56) | def _compute_kl_grad(
    method compute_distribution_matching_loss (line 130) | def compute_distribution_matching_loss(
    method generator_loss (line 199) | def generator_loss(
    method critic_loss (line 240) | def critic_loss(
    method _prepare_generator_input (line 338) | def _prepare_generator_input(self, ode_latent: torch.Tensor, tf=False,...

FILE: model/gan.py
  class GAN (line 10) | class GAN(SelfForcingModel):
    method __init__ (line 11) | def __init__(self, args, device):
    method _run_cls_pred_branch (line 69) | def _run_cls_pred_branch(self,
    method generator_loss (line 90) | def generator_loss(
    method critic_loss (line 174) | def critic_loss(

FILE: model/naive_consistency.py
  class NaiveConsistency (line 9) | class NaiveConsistency(BaseModel):
    method __init__ (line 10) | def __init__(self, args, device):
    method _initialize_models (line 66) | def _initialize_models(self, args, device):
    method generator_loss (line 85) | def generator_loss(

FILE: model/ode_regression.py
  class ODERegression (line 9) | class ODERegression(BaseModel):
    method __init__ (line 10) | def __init__(self, args, device):
    method _initialize_models (line 45) | def _initialize_models(self, args, device):
    method _prepare_generator_input (line 59) | def _prepare_generator_input(self, ode_latent: torch.Tensor, tf=False,...
    method generator_loss (line 95) | def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: d...

FILE: model/sid.py
  class SiD (line 8) | class SiD(SelfForcingModel):
    method __init__ (line 9) | def __init__(self, args, device):
    method compute_distribution_matching_loss (line 47) | def compute_distribution_matching_loss(
    method generator_loss (line 147) | def generator_loss(
    method critic_loss (line 188) | def critic_loss(

FILE: pipeline/bidirectional_diffusion_inference.py
  class BidirectionalDiffusionInferencePipeline (line 10) | class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method inference (line 34) | def inference(
    method _initialize_sample_scheduler (line 89) | def _initialize_sample_scheduler(self, noise):

FILE: pipeline/bidirectional_inference.py
  class BidirectionalInferencePipeline (line 7) | class BidirectionalInferencePipeline(torch.nn.Module):
    method __init__ (line 8) | def __init__(
    method inference (line 33) | def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> t...

FILE: pipeline/bidirectional_training.py
  class BidirectionalTrainingPipeline (line 8) | class BidirectionalTrainingPipeline:
    method __init__ (line 9) | def __init__(self,
    method generate_and_sync_list (line 44) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
    method inference_with_trajectory (line 64) | def inference_with_trajectory(

FILE: pipeline/causal_diffusion_inference.py
  class CausalDiffusionInferencePipeline (line 10) | class CausalDiffusionInferencePipeline(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method inference (line 51) | def inference(
    method inference_for_cd (line 277) | def inference_for_cd(
    method inference_for_genuine_cd (line 474) | def inference_for_genuine_cd(
    method _initialize_kv_cache (line 605) | def _initialize_kv_cache(self, batch_size, dtype, device):
    method _initialize_crossattn_cache (line 635) | def _initialize_crossattn_cache(self, batch_size, dtype, device):
    method _initialize_sample_scheduler (line 656) | def _initialize_sample_scheduler(self, noise, sampling_steps=-1):

FILE: pipeline/causal_inference.py
  class CausalInferencePipeline (line 10) | class CausalInferencePipeline(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method inference (line 65) | def inference(
    method _initialize_kv_cache (line 337) | def _initialize_kv_cache(self, batch_size, dtype, device):
    method _initialize_crossattn_cache (line 359) | def _initialize_crossattn_cache(self, batch_size, dtype, device):

FILE: pipeline/self_forcing_training.py
  class SelfForcingTrainingPipeline (line 8) | class SelfForcingTrainingPipeline:
    method __init__ (line 9) | def __init__(self,
    method generate_and_sync_list (line 48) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
    method inference_with_trajectory (line 68) | def inference_with_trajectory(
    method _initialize_kv_cache (line 288) | def _initialize_kv_cache(self, batch_size, dtype, device):
    method _initialize_crossattn_cache (line 304) | def _initialize_crossattn_cache(self, batch_size, dtype, device):

FILE: pipeline/teacher_forcing_training.py
  class TeacherForcingTrainingPipeline (line 8) | class TeacherForcingTrainingPipeline:
    method __init__ (line 9) | def __init__(self,
    method generate_and_sync_list (line 44) | def generate_and_sync_list(self, num_blocks, num_denoising_steps, devi...
    method inference_with_trajectory (line 64) | def inference_with_trajectory(

FILE: train.py
  function main (line 9) | def main():

FILE: trainer/diffusion.py
  class Trainer (line 20) | class Trainer:
    method __init__ (line 21) | def __init__(self, config):
    method save (line 168) | def save(self):
    method train_one_step (line 191) | def train_one_step(self, batch):
    method train (line 261) | def train(self):

FILE: trainer/distillation.py
  class Trainer (line 16) | class Trainer:
    method __init__ (line 17) | def __init__(self, config):
    method save (line 188) | def save(self):
    method save_critic (line 212) | def save_critic(self):
    method fwdbwd_one_step (line 229) | def fwdbwd_one_step(self, batch, train_generator, clean_latent=None):
    method train (line 303) | def train(self):

FILE: trainer/gan.py
  class Trainer (line 19) | class Trainer:
    method __init__ (line 20) | def __init__(self, config):
    method save (line 208) | def save(self):
    method fwdbwd_one_step (line 235) | def fwdbwd_one_step(self, batch, train_generator):
    method generate_video (line 324) | def generate_video(self, pipeline, prompts, image=None):
    method train (line 337) | def train(self):
    method all_gather_dict (line 457) | def all_gather_dict(self, target_dict):

FILE: trainer/naive_cd.py
  class Trainer (line 20) | class Trainer:
    method __init__ (line 21) | def __init__(self, config):
    method save (line 183) | def save(self):
    method fwdbwd_one_step (line 206) | def fwdbwd_one_step(self, batch, clean_latent=None):
    method train (line 251) | def train(self):

FILE: trainer/ode.py
  class Trainer (line 19) | class Trainer:
    method __init__ (line 20) | def __init__(self, config):
    method save (line 132) | def save(self):
    method train_one_step (line 148) | def train_one_step(self, loss_scale=1.0):
    method train (line 217) | def train(self):

FILE: utils/create_lmdb_iterative.py
  function store_arrays_to_lmdb (line 10) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
  function get_array_shape_from_lmdb (line 27) | def get_array_shape_from_lmdb(env, array_name):
  function process_data_dict (line 34) | def process_data_dict(data_dict, seen_prompts):
  function retrieve_row_from_lmdb (line 59) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...
  function main (line 78) | def main():

FILE: utils/dataset.py
  class TextDataset (line 12) | class TextDataset(Dataset):
    method __init__ (line 13) | def __init__(self, prompt_path, extended_prompt_path=None):
    method __len__ (line 24) | def __len__(self):
    method __getitem__ (line 27) | def __getitem__(self, idx):
  class ODERegressionLMDBDataset (line 37) | class ODERegressionLMDBDataset(Dataset):
    method __init__ (line 38) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
    method __len__ (line 45) | def __len__(self):
    method __getitem__ (line 48) | def __getitem__(self, idx):
  class LatentLMDBDataset (line 75) | class LatentLMDBDataset(Dataset):
    method __init__ (line 76) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
    method __len__ (line 83) | def __len__(self):
    method __getitem__ (line 86) | def __getitem__(self, idx):
  class ShardingLMDBDataset (line 110) | class ShardingLMDBDataset(Dataset):
    method __init__ (line 111) | def __init__(self, data_path: str, max_pair: int = int(1e8)):
    method __len__ (line 134) | def __len__(self):
    method __getitem__ (line 137) | def __getitem__(self, idx):
  class TextImagePairDataset (line 166) | class TextImagePairDataset(Dataset):
    method __init__ (line 167) | def __init__(
    method __len__ (line 221) | def __len__(self):
    method __getitem__ (line 224) | def __getitem__(self, idx):
  function cycle (line 257) | def cycle(dl):

FILE: utils/distributed.py
  function fsdp_state_dict (line 11) | def fsdp_state_dict(model):
  function fsdp_wrap (line 23) | def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, w...
  function barrier (line 70) | def barrier():
  function launch_distributed_job (line 75) | def launch_distributed_job(backend: str = "nccl"):
  class EMA_FSDP (line 91) | class EMA_FSDP:
    method __init__ (line 92) | def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
    method _init_shadow (line 98) | def _init_shadow(self, fsdp_module):
    method update (line 103) | def update(self, fsdp_module):
    method state_dict (line 109) | def state_dict(self):
    method load_state_dict (line 112) | def load_state_dict(self, sd):
    method copy_to (line 115) | def copy_to(self, fsdp_module):
    method full_state_dict (line 121) | def full_state_dict(self, fsdp_module):

FILE: utils/lmdb_.py
  function get_array_shape_from_lmdb (line 4) | def get_array_shape_from_lmdb(env, array_name):
  function store_arrays_to_lmdb (line 11) | def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
  function process_data_dict (line 30) | def process_data_dict(data_dict, seen_prompts):
  function retrieve_row_from_lmdb (line 56) | def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape...

FILE: utils/loss.py
  class DenoisingLoss (line 5) | class DenoisingLoss(ABC):
    method __call__ (line 7) | def __call__(
  class X0PredLoss (line 27) | class X0PredLoss(DenoisingLoss):
    method __call__ (line 28) | def __call__(
  class VPredLoss (line 38) | class VPredLoss(DenoisingLoss):
    method __call__ (line 39) | def __call__(
  class NoisePredLoss (line 50) | class NoisePredLoss(DenoisingLoss):
    method __call__ (line 51) | def __call__(
  class FlowPredLoss (line 61) | class FlowPredLoss(DenoisingLoss):
    method __call__ (line 62) | def __call__(
  function get_denoising_loss (line 80) | def get_denoising_loss(loss_type: str) -> DenoisingLoss:

FILE: utils/merge_and_get_clean.py
  function read_shape (line 8) | def read_shape(env, name):
  function list_array_names (line 14) | def list_array_names(env):
  function ensure_empty_dir (line 24) | def ensure_empty_dir(path):
  function safe_mapsize (line 28) | def safe_mapsize(env):
  function get_bytes (line 32) | def get_bytes(txn, key):
  function latents_bytes_to_out (line 37) | def latents_bytes_to_out(row_bytes, in_row_shape, out_row_shape):
  function merge_many (line 50) | def merge_many(src_dirs_all, dst_dir):
  function rm_dirs (line 140) | def rm_dirs(dirs, desc="remove dirs"):

FILE: utils/merge_lmdb.py
  function read_shape (line 8) | def read_shape(env, name):
  function list_array_names (line 15) | def list_array_names(env):
  function ensure_empty_dir (line 26) | def ensure_empty_dir(path):
  function safe_mapsize (line 31) | def safe_mapsize(env):
  function get_bytes (line 35) | def get_bytes(txn, key):
  function latents_bytes_to_out (line 41) | def latents_bytes_to_out(row_bytes, in_row_shape, out_row_shape):
  function merge_many (line 47) | def merge_many(src_dirs_all, dst_dir):
  function rm_dirs (line 142) | def rm_dirs(dirs, desc="remove dirs"):

FILE: utils/misc.py
  function set_seed (line 6) | def set_seed(seed: int, deterministic: bool = False):
  function merge_dict_list (line 25) | def merge_dict_list(dict_list):

FILE: utils/ode_generation.py
  function merge_cfg_prompt_embeds (line 6) | def merge_cfg_prompt_embeds(
  function normalize_trajectory_indices (line 21) | def normalize_trajectory_indices(
  class CausalODETrajectoryGenerator (line 37) | class CausalODETrajectoryGenerator:
    method __init__ (line 38) | def __init__(
    method _make_kv_cache (line 55) | def _make_kv_cache(self, batch_size: int, device: torch.device) -> lis...
    method _make_crossattn_cache (line 81) | def _make_crossattn_cache(self, batch_size: int, device: torch.device)...
    method _batched_cfg_step (line 101) | def _batched_cfg_step(
    method _update_clean_cache (line 132) | def _update_clean_cache(
    method _generate_full (line 156) | def _generate_full(
    method _generate_blockwise_kv (line 196) | def _generate_blockwise_kv(
    method _assemble_selected_trajectory (line 267) | def _assemble_selected_trajectory(
    method generate (line 290) | def generate(

FILE: utils/scheduler.py
  class SchedulerInterface (line 5) | class SchedulerInterface(ABC):
    method add_noise (line 12) | def add_noise(
    method convert_x0_to_noise (line 26) | def convert_x0_to_noise(
    method convert_noise_to_x0 (line 52) | def convert_noise_to_x0(
    method convert_velocity_to_x0 (line 77) | def convert_velocity_to_x0(
  class FlowMatchScheduler (line 106) | class FlowMatchScheduler():
    method __init__ (line 108) | def __init__(self, num_inference_steps=100, num_train_timesteps=1000, ...
    method set_timesteps (line 118) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
    method step (line 143) | def step(self, model_output, timestep, sample, to_final=False):
    method add_noise (line 159) | def add_noise(self, original_samples, noise, timestep):
    method training_target (line 178) | def training_target(self, sample, noise, timestep):
    method training_weight (line 182) | def training_weight(self, timestep):

FILE: utils/wan_wrapper.py
  class WanTextEncoder (line 14) | class WanTextEncoder(torch.nn.Module):
    method __init__ (line 15) | def __init__(self) -> None:
    method device (line 33) | def device(self):
    method forward (line 37) | def forward(self, text_prompts: List[str]) -> dict:
  class WanVAEWrapper (line 53) | class WanVAEWrapper(torch.nn.Module):
    method __init__ (line 54) | def __init__(self):
    method encode_to_latent (line 73) | def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
    method decode_to_pixel (line 89) | def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = Fals...
  class WanDiffusionWrapper (line 115) | class WanDiffusionWrapper(torch.nn.Module):
    method __init__ (line 116) | def __init__(
    method enable_gradient_checkpointing (line 144) | def enable_gradient_checkpointing(self) -> None:
    method adding_cls_branch (line 147) | def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_di...
    method _convert_flow_pred_to_x0 (line 169) | def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch....
    method _convert_x0_to_flow_pred (line 196) | def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: tor...
    method forward (line 218) | def forward(
    method get_scheduler (line 294) | def get_scheduler(self) -> SchedulerInterface:
    method post_init (line 308) | def post_init(self):

FILE: wan/distributed/fsdp.py
  function shard_model (line 10) | def shard_model(

FILE: wan/distributed/xdit_context_parallel.py
  function pad_freqs (line 12) | def pad_freqs(original_tensor, target_len):
  function rope_apply (line 26) | def rope_apply(x, grid_sizes, freqs):
  function usp_dit_forward (line 66) | def usp_dit_forward(
  function usp_attn_forward (line 149) | def usp_attn_forward(self,

FILE: wan/image2video.py
  class WanI2V (line 29) | class WanI2V:
    method __init__ (line 31) | def __init__(
    method generate (line 129) | def generate(self,

FILE: wan/modules/attention.py
  function is_hopper_gpu (line 7) | def is_hopper_gpu():
  function flash_attention (line 32) | def flash_attention(
  function attention (line 139) | def attention(

FILE: wan/modules/causal_model.py
  function causal_rope_apply (line 31) | def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
  class CausalWanSelfAttention (line 62) | class CausalWanSelfAttention(nn.Module):
    method __init__ (line 64) | def __init__(self,
    method forward (line 90) | def forward(
  class CausalWanAttentionBlock (line 248) | class CausalWanAttentionBlock(nn.Module):
    method __init__ (line 250) | def __init__(self,
    method forward (line 288) | def forward(
  class CausalHead (line 343) | class CausalHead(nn.Module):
    method __init__ (line 345) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
    method forward (line 360) | def forward(self, x, e):
  class CausalWanModel (line 374) | class CausalWanModel(ModelMixin, ConfigMixin):
    method __init__ (line 386) | def __init__(self,
    method _set_gradient_checkpointing (line 507) | def _set_gradient_checkpointing(self, module, value=False):
    method _prepare_blockwise_causal_attn_mask (line 511) | def _prepare_blockwise_causal_attn_mask(
    method _prepare_teacher_forcing_mask (line 569) | def _prepare_teacher_forcing_mask(
    method _prepare_blockwise_causal_attn_mask_i2v (line 657) | def _prepare_blockwise_causal_attn_mask_i2v(
    method _forward_inference (line 717) | def _forward_inference(
    method _forward_train (line 848) | def _forward_train(
    method forward (line 1007) | def forward(
    method unpatchify (line 1018) | def unpatchify(self, x, grid_sizes):
    method init_weights (line 1043) | def init_weights(self):

FILE: wan/modules/clip.py
  function pos_interpolate (line 22) | def pos_interpolate(pos, seq_len):
  class QuickGELU (line 41) | class QuickGELU(nn.Module):
    method forward (line 43) | def forward(self, x):
  class LayerNorm (line 47) | class LayerNorm(nn.LayerNorm):
    method forward (line 49) | def forward(self, x):
  class SelfAttention (line 53) | class SelfAttention(nn.Module):
    method __init__ (line 55) | def __init__(self,
    method forward (line 74) | def forward(self, x):
  class SwiGLU (line 94) | class SwiGLU(nn.Module):
    method __init__ (line 96) | def __init__(self, dim, mid_dim):
    method forward (line 106) | def forward(self, x):
  class AttentionBlock (line 112) | class AttentionBlock(nn.Module):
    method __init__ (line 114) | def __init__(self,
    method forward (line 146) | def forward(self, x):
  class AttentionPool (line 156) | class AttentionPool(nn.Module):
    method __init__ (line 158) | def __init__(self,
    method forward (line 186) | def forward(self, x):
  class VisionTransformer (line 209) | class VisionTransformer(nn.Module):
    method __init__ (line 211) | def __init__(self,
    method forward (line 279) | def forward(self, x, interpolation=False, use_31_block=False):
  class XLMRobertaWithHead (line 303) | class XLMRobertaWithHead(XLMRoberta):
    method __init__ (line 305) | def __init__(self, **kwargs):
    method forward (line 315) | def forward(self, ids):
  class XLMRobertaCLIP (line 328) | class XLMRobertaCLIP(nn.Module):
    method __init__ (line 330) | def __init__(self,
    method forward (line 406) | def forward(self, imgs, txt_ids):
    method param_groups (line 418) | def param_groups(self):
  function _clip (line 434) | def _clip(pretrained=False,
  function clip_xlm_roberta_vit_h_14 (line 471) | def clip_xlm_roberta_vit_h_14(
  class CLIPModel (line 501) | class CLIPModel:
    method __init__ (line 503) | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
    method visual (line 527) | def visual(self, videos):

FILE: wan/modules/model.py
  function sinusoidal_embedding_1d (line 15) | def sinusoidal_embedding_1d(dim, position):
  function rope_params (line 29) | def rope_params(max_seq_len, dim, theta=10000):
  function rope_apply (line 40) | def rope_apply(x, grid_sizes, freqs):
  class WanRMSNorm (line 70) | class WanRMSNorm(nn.Module):
    method __init__ (line 72) | def __init__(self, dim, eps=1e-5):
    method forward (line 78) | def forward(self, x):
    method _norm (line 85) | def _norm(self, x):
  class WanLayerNorm (line 89) | class WanLayerNorm(nn.LayerNorm):
    method __init__ (line 91) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
    method forward (line 94) | def forward(self, x):
  class WanSelfAttention (line 102) | class WanSelfAttention(nn.Module):
    method __init__ (line 104) | def __init__(self,
    method forward (line 127) | def forward(self, x, seq_lens, grid_sizes, freqs):
  class WanT2VCrossAttention (line 159) | class WanT2VCrossAttention(WanSelfAttention):
    method forward (line 161) | def forward(self, x, context, context_lens, crossattn_cache=None):
  class WanGanCrossAttention (line 197) | class WanGanCrossAttention(WanSelfAttention):
    method forward (line 199) | def forward(self, x, context, crossattn_cache=None):
  class WanI2VCrossAttention (line 224) | class WanI2VCrossAttention(WanSelfAttention):
    method __init__ (line 226) | def __init__(self,
    method forward (line 240) | def forward(self, x, context, context_lens):
  class WanAttentionBlock (line 275) | class WanAttentionBlock(nn.Module):
    method __init__ (line 277) | def __init__(self,
    method forward (line 315) | def forward(
  class GanAttentionBlock (line 357) | class GanAttentionBlock(nn.Module):
    method __init__ (line 359) | def __init__(self,
    method forward (line 397) | def forward(
  class Head (line 439) | class Head(nn.Module):
    method __init__ (line 441) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
    method forward (line 456) | def forward(self, x, e):
  class MLPProj (line 469) | class MLPProj(torch.nn.Module):
    method __init__ (line 471) | def __init__(self, in_dim, out_dim):
    method forward (line 479) | def forward(self, image_embeds):
  class RegisterTokens (line 484) | class RegisterTokens(nn.Module):
    method __init__ (line 485) | def __init__(self, num_registers: int, dim: int):
    method forward (line 490) | def forward(self):
    method reset_parameters (line 493) | def reset_parameters(self):
  class WanModel (line 497) | class WanModel(ModelMixin, ConfigMixin):
    method __init__ (line 509) | def __init__(self,
    method _set_gradient_checkpointing (line 623) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 626) | def forward(
    method _forward (line 637) | def _forward(
    method _forward_classify (line 773) | def _forward_classify(
    method unpatchify (line 876) | def unpatchify(self, x, grid_sizes, c=None):
    method init_weights (line 901) | def init_weights(self):

FILE: wan/modules/t5.py
  function fp16_clamp (line 20) | def fp16_clamp(x):
  function init_weights (line 27) | def init_weights(m):
  class GELU (line 46) | class GELU(nn.Module):
    method forward (line 48) | def forward(self, x):
  class T5LayerNorm (line 53) | class T5LayerNorm(nn.Module):
    method __init__ (line 55) | def __init__(self, dim, eps=1e-6):
    method forward (line 61) | def forward(self, x):
  class T5Attention (line 69) | class T5Attention(nn.Module):
    method __init__ (line 71) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
    method forward (line 86) | def forward(self, x, context=None, mask=None, pos_bias=None):
  class T5FeedForward (line 123) | class T5FeedForward(nn.Module):
    method __init__ (line 125) | def __init__(self, dim, dim_ffn, dropout=0.1):
    method forward (line 136) | def forward(self, x):
  class T5SelfAttention (line 144) | class T5SelfAttention(nn.Module):
    method __init__ (line 146) | def __init__(self,
    method forward (line 170) | def forward(self, x, mask=None, pos_bias=None):
  class T5CrossAttention (line 178) | class T5CrossAttention(nn.Module):
    method __init__ (line 180) | def __init__(self,
    method forward (line 206) | def forward(self,
  class T5RelativeEmbedding (line 221) | class T5RelativeEmbedding(nn.Module):
    method __init__ (line 223) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
    method forward (line 233) | def forward(self, lq, lk):
    method _relative_position_bucket (line 245) | def _relative_position_bucket(self, rel_pos):
  class T5Encoder (line 267) | class T5Encoder(nn.Module):
    method __init__ (line 269) | def __init__(self,
    method forward (line 303) | def forward(self, ids, mask=None):
  class T5Decoder (line 315) | class T5Decoder(nn.Module):
    method __init__ (line 317) | def __init__(self,
    method forward (line 351) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
  class T5Model (line 372) | class T5Model(nn.Module):
    method __init__ (line 374) | def __init__(self,
    method forward (line 408) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
  function _t5 (line 415) | def _t5(name,
  function umt5_xxl (line 456) | def umt5_xxl(**kwargs):
  class T5EncoderModel (line 472) | class T5EncoderModel:
    method __init__ (line 474) | def __init__(
    method __call__ (line 506) | def __call__(self, texts, device):

FILE: wan/modules/tokenizers.py
  function basic_clean (line 12) | def basic_clean(text):
  function whitespace_clean (line 18) | def whitespace_clean(text):
  function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
  class HuggingfaceTokenizer (line 37) | class HuggingfaceTokenizer:
    method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
    method __call__ (line 49) | def __call__(self, sequence, **kwargs):
    method _clean (line 75) | def _clean(self, text):

FILE: wan/modules/vae.py
  class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
    method __init__ (line 22) | def __init__(self, *args, **kwargs):
    method forward (line 28) | def forward(self, x, cache_x=None):
  class RMS_norm (line 39) | class RMS_norm(nn.Module):
    method __init__ (line 41) | def __init__(self, dim, channel_first=True, images=True, bias=False):
    method forward (line 51) | def forward(self, x):
  class Upsample (line 57) | class Upsample(nn.Upsample):
    method forward (line 59) | def forward(self, x):
  class Resample (line 66) | class Resample(nn.Module):
    method __init__ (line 68) | def __init__(self, dim, mode):
    method forward (line 101) | def forward(self, x, feat_cache=None, feat_idx=[0]):
    method init_weight (line 162) | def init_weight(self, conv):
    method init_weight2 (line 174) | def init_weight2(self, conv):
  class ResidualBlock (line 186) | class ResidualBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, in_dim, out_dim, dropout=0.0):
    method forward (line 202) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class AttentionBlock (line 223) | class AttentionBlock(nn.Module):
    method __init__ (line 228) | def __init__(self, dim):
    method forward (line 240) | def forward(self, x):
  class Encoder3d (line 265) | class Encoder3d(nn.Module):
    method __init__ (line 267) | def __init__(self,
    method forward (line 318) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class Decoder3d (line 369) | class Decoder3d(nn.Module):
    method __init__ (line 371) | def __init__(self,
    method forward (line 423) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  function count_conv3d (line 475) | def count_conv3d(model):
  class WanVAE_ (line 483) | class WanVAE_(nn.Module):
    method __init__ (line 485) | def __init__(self,
    method forward (line 511) | def forward(self, x):
    method encode (line 517) | def encode(self, x, scale):
    method decode (line 545) | def decode(self, z, scale):
    method cached_decode (line 571) | def cached_decode(self, z, scale):
    method sample (line 595) | def sample(self, imgs, deterministic=False):
    method clear_cache (line 602) | def clear_cache(self):
  function _video_vae (line 612) | def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
  class WanVAE (line 639) | class WanVAE:
    method __init__ (line 641) | def __init__(self,
    method encode (line 667) | def encode(self, videos):
    method decode (line 677) | def decode(self, zs):

FILE: wan/modules/xlm_roberta.py
  class SelfAttention (line 10) | class SelfAttention(nn.Module):
    method __init__ (line 12) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
    method forward (line 27) | def forward(self, x, mask):
  class AttentionBlock (line 49) | class AttentionBlock(nn.Module):
    method __init__ (line 51) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
    method forward (line 66) | def forward(self, x, mask):
  class XLMRoberta (line 76) | class XLMRoberta(nn.Module):
    method __init__ (line 81) | def __init__(self,
    method forward (line 118) | def forward(self, ids):
  function xlm_roberta_large (line 146) | def xlm_roberta_large(pretrained=False,

FILE: wan/text2video.py
  class WanT2V (line 26) | class WanT2V:
    method __init__ (line 28) | def __init__(
    method generate (line 110) | def generate(self,

FILE: wan/utils/fm_solvers.py
  function get_sampling_sigmas (line 22) | def get_sampling_sigmas(sampling_steps, shift):
  function retrieve_timesteps (line 29) | def retrieve_timesteps(
  class FlowDPMSolverMultistepScheduler (line 69) | class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 129) | def __init__(
    method step_index (line 202) | def step_index(self):
    method begin_index (line 209) | def begin_index(self):
    method set_begin_index (line 216) | def set_begin_index(self, begin_index: int = 0):
    method set_timesteps (line 226) | def set_timesteps(
    method _threshold_sample (line 292) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    method _sigma_to_t (line 330) | def _sigma_to_t(self, sigma):
    method _sigma_to_alpha_sigma_t (line 333) | def _sigma_to_alpha_sigma_t(self, sigma):
    method time_shift (line 337) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    method convert_model_output (line 341) | def convert_model_output(
    method dpm_solver_first_order_update (line 415) | def dpm_solver_first_order_update(
    method multistep_dpm_solver_second_order_update (line 486) | def multistep_dpm_solver_second_order_update(
    method multistep_dpm_solver_third_order_update (line 596) | def multistep_dpm_solver_third_order_update(
    method index_for_timestep (line 679) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 693) | def _init_step_index(self, timestep):
    method step (line 706) | def step(
    method scale_model_input (line 800) | def scale_model_input(self, sample: torch.Tensor, *args,
    method add_noise (line 815) | def add_noise(
    method __len__ (line 856) | def __len__(self):

FILE: wan/utils/fm_solvers_unipc.py
  class FlowUniPCMultistepScheduler (line 20) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 77) | def __init__(
    method step_index (line 135) | def step_index(self):
    method begin_index (line 142) | def begin_index(self):
    method set_begin_index (line 149) | def set_begin_index(self, begin_index: int = 0):
    method set_timesteps (line 160) | def set_timesteps(
    method _threshold_sample (line 230) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    method _sigma_to_t (line 269) | def _sigma_to_t(self, sigma):
    method _sigma_to_alpha_sigma_t (line 272) | def _sigma_to_alpha_sigma_t(self, sigma):
    method time_shift (line 276) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    method convert_model_output (line 279) | def convert_model_output(
    method multistep_uni_p_bh_update (line 350) | def multistep_uni_p_bh_update(
    method multistep_uni_c_bh_update (line 486) | def multistep_uni_c_bh_update(
    method index_for_timestep (line 628) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method _init_step_index (line 643) | def _init_step_index(self, timestep):
    method step (line 655) | def step(self,
    method scale_model_input (line 741) | def scale_model_input(self, sample: torch.Tensor, *args,
    method add_noise (line 758) | def add_noise(
    method __len__ (line 799) | def __len__(self):

FILE: wan/utils/prompt_extend.py
  class PromptOutput (line 101) | class PromptOutput(object):
    method add_custom_field (line 108) | def add_custom_field(self, key: str, value) -> None:
  class PromptExpander (line 112) | class PromptExpander:
    method __init__ (line 114) | def __init__(self, model_name, is_vl=False, device=0, **kwargs):
    method extend_with_img (line 119) | def extend_with_img(self,
    method extend (line 128) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method decide_system_prompt (line 131) | def decide_system_prompt(self, tar_lang="ch"):
    method __call__ (line 138) | def __call__(self,
  class DashScopePromptExpander (line 157) | class DashScopePromptExpander(PromptExpander):
    method __init__ (line 159) | def __init__(self,
    method extend (line 196) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method extend_with_img (line 232) | def extend_with_img(self,
  class QwenPromptExpander (line 300) | class QwenPromptExpander(PromptExpander):
    method __init__ (line 309) | def __init__(self, model_name=None, device=0, is_vl=False, **kwargs):
    method extend (line 366) | def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
    method extend_with_img (line 397) | def extend_with_img(self,

FILE: wan/utils/qwen_vl_utils.py
  function round_by_factor (line 39) | def round_by_factor(number: int, factor: int) -> int:
  function ceil_by_factor (line 44) | def ceil_by_factor(number: int, factor: int) -> int:
  function floor_by_factor (line 49) | def floor_by_factor(number: int, factor: int) -> int:
  function smart_resize (line 54) | def smart_resize(height: int,
  function fetch_image (line 85) | def fetch_image(ele: dict[str, str | Image.Image],
  function smart_nframes (line 133) | def smart_nframes(
  function _read_video_torchvision (line 177) | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
  function is_decord_available (line 215) | def is_decord_available() -> bool:
  function _read_video_decord (line 221) | def _read_video_decord(ele: dict,) -> torch.Tensor:
  function get_video_reader_backend (line 261) | def get_video_reader_backend() -> str:
  function fetch_video (line 274) | def fetch_video(
  function extract_vision_info (line 328) | def extract_vision_info(
  function process_vision_info (line 344) | def process_vision_info(
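
The `*_by_factor` helpers and `smart_resize` snap image sizes to a multiple of a patch factor while keeping the total pixel count inside a budget. Below is a minimal sketch, assuming the behavior of the upstream qwen-vl-utils code. The constants `IMAGE_FACTOR`, `MIN_PIXELS`, `MAX_PIXELS` and `MAX_RATIO` are assumed defaults and are not taken from this repository.

```python
import math

# Assumed defaults (mirroring upstream qwen-vl-utils; not verified against this repo).
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: float, factor: int) -> int:
    # Nearest multiple of `factor`.
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    # Smallest multiple of `factor` that is >= number.
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    # Largest multiple of `factor` that is <= number.
    return math.floor(number / factor) * factor


def smart_resize(height: int, width: int, factor: int = IMAGE_FACTOR,
                 min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS):
    """Return (h, w), both divisible by `factor`, with h * w kept near [min_pixels, max_pixels]."""
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(f"aspect ratio must be <= {MAX_RATIO}")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: scale both sides down by the same factor, rounding down.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: scale both sides up by the same factor, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
```

For example, a 480×640 frame snaps to the nearest multiples of 28 and is otherwise left alone because it already fits the pixel budget.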

FILE: wan/utils/utils.py
  function rand_name (line 14) | def rand_name(length=8, suffix=''):
  function cache_video (line 23) | def cache_video(tensor,
  function cache_image (line 64) | def cache_image(tensor,
  function str2bool (line 94) | def str2bool(v):
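
`rand_name` and `str2bool` are small CLI and file-naming helpers. Below is a minimal sketch, assuming they behave like their upstream Wan2.1 counterparts:

- `rand_name` returns a random hex stem with an optional extension.
- `str2bool` is an `argparse` type converter.

The exact accepted spellings in this repository's copy are an assumption.

```python
import argparse
import binascii
import os


def rand_name(length: int = 8, suffix: str = "") -> str:
    # `length` random bytes rendered as hex, so the stem has 2 * length characters.
    name = binascii.b2a_hex(os.urandom(length)).decode("utf-8")
    if suffix:
        if not suffix.startswith("."):
            suffix = "." + suffix
        name += suffix
    return name


def str2bool(v):
    # Suitable as `type=str2bool` in argparse; passes real bools through unchanged.
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ("yes", "true", "t", "y", "1"):
        return True
    if v_lower in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected (True/False)")
```

Typical use is `parser.add_argument("--offload", type=str2bool, default=False)`, where `--offload` is a hypothetical flag. With this converter, `--offload false` parses to `False` instead of the truthy string `"false"`.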
Condensed preview: 154 files, each showing path, character count, and a content snippet (full structured content: 1,501K chars).
[
  {
    "path": ".gitignore",
    "chars": 104,
    "preview": "__pycache__\n*.egg-info\n\nwan_models\ncheckpoints\noutput\ndataset\nprompts/vidprom_filtered_extended.txt\nlogs"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 21236,
    "preview": "<div align=\"center\">\n\n## Causal Forcing & Causal Forcing++\n### Autoregressive Diffusion Distillation Done Right for High"
  },
  {
    "path": "configs/ar_diffusion_tf_chunkwise.yaml",
    "chars": 1113,
    "preview": "generator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: size\nfake_score_fsdp_wrap_strategy: size\nreal_name: Wa"
  },
  {
    "path": "configs/ar_diffusion_tf_framewise.yaml",
    "chars": 1113,
    "preview": "generator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: size\nfake_score_fsdp_wrap_strategy: size\nreal_name: Wa"
  },
  {
    "path": "configs/causal_cd_chunkwise.yaml",
    "chars": 1364,
    "preview": "generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: "
  },
  {
    "path": "configs/causal_cd_framewise.yaml",
    "chars": 1364,
    "preview": "generator_ckpt: checkpoints/framewise/ar_diffusion.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy: "
  },
  {
    "path": "configs/causal_forcing_dmd_chunkwise.yaml",
    "chars": 1315,
    "preview": "generator_ckpt: checkpoints/chunkwise/causal_ode.pt # 🔥 or checkpoints/chunkwise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
  },
  {
    "path": "configs/causal_forcing_dmd_framewise.yaml",
    "chars": 1314,
    "preview": "generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
  },
  {
    "path": "configs/causal_forcing_dmd_framewise_1step.yaml",
    "chars": 1522,
    "preview": "generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
  },
  {
    "path": "configs/causal_forcing_dmd_framewise_2step.yaml",
    "chars": 1528,
    "preview": "generator_ckpt: checkpoints/framewise/causal_ode.pt # 🔥 or checkpoints/framewise/causal_cd.pt\ngenerator_fsdp_wrap_strate"
  },
  {
    "path": "configs/causal_ode_chunkwise.yaml",
    "chars": 655,
    "preview": "generator_ckpt: checkpoints/chunkwise/ar_diffusion.pt\ngenerator_grad:\n  model: true\ndenoising_step_list:\n- 1000\n- 750\n- "
  },
  {
    "path": "configs/causal_ode_framewise.yaml",
    "chars": 657,
    "preview": "generator_ckpt: checkpoints/framewise/ar_diffusion.pt\ngenerator_grad:\n  model: true\ndenoising_step_list:\n- 1000\n- 750\n- "
  },
  {
    "path": "configs/default_config.yaml",
    "chars": 403,
    "preview": "independent_first_frame: false\nwarp_denoising_step: false\nweight_decay: 0.01\nsame_step_across_blocks: true\ndiscriminator"
  },
  {
    "path": "demo.py",
    "chars": 24414,
    "preview": "\"\"\"\nDemo for Causal-Forcing.\n\"\"\"\n\nimport os\nimport re\nimport random\nimport time\nimport base64\nimport argparse\nimport has"
  },
  {
    "path": "demo_utils/constant.py",
    "chars": 1352,
    "preview": "\nimport torch\n\n\nZERO_VAE_CACHE = [\n    torch.zeros(1, 16, 2, 60, 104),\n    torch.zeros(1, 384, 2, 60, 104),\n    torch.ze"
  },
  {
    "path": "demo_utils/memory.py",
    "chars": 4417,
    "preview": "# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils\n# Apache-2.0 License\n# By lllyasviel\n\nimport "
  },
  {
    "path": "demo_utils/taehv.py",
    "chars": 14157,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nTiny AutoEncoder for Hunyuan Video\n(DNN for encoding / decoding videos to Hunyuan Video's lat"
  },
  {
    "path": "demo_utils/utils.py",
    "chars": 17547,
    "preview": "# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils\n# Apache-2.0 License\n# By lllyasviel\n\nimport "
  },
  {
    "path": "demo_utils/vae.py",
    "chars": 15414,
    "preview": "from typing import List\nfrom einops import rearrange\nimport tensorrt as trt\nimport torch\nimport torch.nn as nn\n\nfrom dem"
  },
  {
    "path": "demo_utils/vae_block3.py",
    "chars": 11058,
    "preview": "from typing import List\nfrom einops import rearrange\nimport torch\nimport torch.nn as nn\n\nfrom wan.modules.vae import Att"
  },
  {
    "path": "demo_utils/vae_torch2trt.py",
    "chars": 11889,
    "preview": "# ---- INT8 (optional) ----\nfrom demo_utils.vae import (\n    VAEDecoderWrapperSingle,                         # main nn."
  },
  {
    "path": "get_causal_ode_data_chunkwise.py",
    "chars": 4371,
    "preview": "from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\nfrom utils.scheduler import FlowMatchSc"
  },
  {
    "path": "get_causal_ode_data_framewise.py",
    "chars": 4371,
    "preview": "from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\nfrom utils.scheduler import FlowMatchSc"
  },
  {
    "path": "get_causal_ode_data_kv_optimized.py",
    "chars": 5832,
    "preview": "import argparse\nimport math\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom tqdm import tqdm\n\nfrom utils.d"
  },
  {
    "path": "inference.py",
    "chars": 8317,
    "preview": "import argparse\nimport argparse\nimport torch\nimport os\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom torchv"
  },
  {
    "path": "long_video/LICENSE",
    "chars": 12242,
    "preview": "Tencent is pleased to support the community by making RollingForcing available.\n\nCopyright (C)  2025 Tencent.  All right"
  },
  {
    "path": "long_video/README.md",
    "chars": 1279,
    "preview": "Builing on [Rolling Forcing](https://github.com/TencentARC/RollingForcing), we implemented minute-level long video gener"
  },
  {
    "path": "long_video/app.py",
    "chars": 6831,
    "preview": "import os\nimport argparse\nimport time\nfrom typing import Optional\n\nimport torch\nfrom torchvision.io import write_video\nf"
  },
  {
    "path": "long_video/configs/default_config.yaml",
    "chars": 403,
    "preview": "independent_first_frame: false\nwarp_denoising_step: false\nweight_decay: 0.01\nsame_step_across_blocks: true\ndiscriminator"
  },
  {
    "path": "long_video/configs/rolling_forcing_dmd.yaml",
    "chars": 1168,
    "preview": "generator_ckpt: ../checkpoints/chunkwise/causal_ode.pt\ngenerator_fsdp_wrap_strategy: size\nreal_score_fsdp_wrap_strategy:"
  },
  {
    "path": "long_video/inference.py",
    "chars": 7755,
    "preview": "import argparse\nimport torch\nimport os\nfrom omegaconf import OmegaConf\nfrom collections import OrderedDict\nfrom tqdm imp"
  },
  {
    "path": "long_video/model/__init__.py",
    "chars": 278,
    "preview": "from .diffusion import CausalDiffusion\nfrom .causvid import CausVid\nfrom .dmd import DMD\nfrom .gan import GAN\nfrom .sid "
  },
  {
    "path": "long_video/model/base.py",
    "chars": 11031,
    "preview": "from typing import Tuple\nfrom einops import rearrange\nfrom torch import nn\nimport torch.distributed as dist\nimport torch"
  },
  {
    "path": "long_video/model/causvid.py",
    "chars": 17252,
    "preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\n\n\nclass CausVid("
  },
  {
    "path": "long_video/model/diffusion.py",
    "chars": 5641,
    "preview": "from typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_wrapper import WanDiffusionWrappe"
  },
  {
    "path": "long_video/model/dmd.py",
    "chars": 15496,
    "preview": "from pipeline import RollingForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\ni"
  },
  {
    "path": "long_video/model/gan.py",
    "chars": 14244,
    "preview": "import copy\nfrom pipeline import RollingForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Tuple"
  },
  {
    "path": "long_video/model/ode_regression.py",
    "chars": 5763,
    "preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_w"
  },
  {
    "path": "long_video/model/sid.py",
    "chars": 12650,
    "preview": "from pipeline import RollingForcingTrainingPipeline\nfrom typing import Optional, Tuple\nimport torch\n\nfrom model.base imp"
  },
  {
    "path": "long_video/pipeline/__init__.py",
    "chars": 568,
    "preview": "from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline\nfrom .bidirectional_inference imp"
  },
  {
    "path": "long_video/pipeline/bidirectional_diffusion_inference.py",
    "chars": 4146,
    "preview": "from tqdm import tqdm\nfrom typing import List\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMultistepSched"
  },
  {
    "path": "long_video/pipeline/bidirectional_inference.py",
    "chars": 3109,
    "preview": "from typing import List\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\n\n"
  },
  {
    "path": "long_video/pipeline/causal_diffusion_inference.py",
    "chars": 16465,
    "preview": "from tqdm import tqdm\nfrom typing import List, Optional\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMult"
  },
  {
    "path": "long_video/pipeline/rolling_forcing_inference.py",
    "chars": 17659,
    "preview": "from typing import List, Optional\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVA"
  },
  {
    "path": "long_video/pipeline/rolling_forcing_training.py",
    "chars": 22725,
    "preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
  },
  {
    "path": "long_video/prompts/example_prompts.txt",
    "chars": 6517,
    "preview": "A cinematic scene from a classic western movie, featuring a rugged man riding a powerful horse through the vast Gobi Des"
  },
  {
    "path": "long_video/requirements.txt",
    "chars": 551,
    "preview": "torch==2.5.1\ntorchvision==0.20.1\ntorchaudio==2.5.1\nopencv-python>=4.9.0.80\ndiffusers==0.31.0\ntransformers>=4.49.0\ntokeni"
  },
  {
    "path": "long_video/train.py",
    "chars": 1612,
    "preview": "import argparse\nimport os\nfrom omegaconf import OmegaConf\n\nfrom trainer import DiffusionTrainer, GANTrainer, ODETrainer,"
  },
  {
    "path": "long_video/trainer/__init__.py",
    "chars": 297,
    "preview": "from .diffusion import Trainer as DiffusionTrainer\nfrom .gan import Trainer as GANTrainer\nfrom .ode import Trainer as OD"
  },
  {
    "path": "long_video/trainer/diffusion.py",
    "chars": 10384,
    "preview": "import gc\nimport logging\n\nfrom model import CausalDiffusion\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom ut"
  },
  {
    "path": "long_video/trainer/distillation.py",
    "chars": 15969,
    "preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.dataset import TextDataset\nfro"
  },
  {
    "path": "long_video/trainer/gan.py",
    "chars": 19999,
    "preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.distributed import EMA_FSDP, f"
  },
  {
    "path": "long_video/trainer/ode.py",
    "chars": 9791,
    "preview": "import gc\nimport logging\nfrom utils.dataset import ODERegressionLMDBDataset, cycle\nfrom model import ODERegression\nfrom "
  },
  {
    "path": "long_video/utils/dataset.py",
    "chars": 7346,
    "preview": "from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb\nfrom torch.utils.data import Dataset\nimport num"
  },
  {
    "path": "long_video/utils/distributed.py",
    "chars": 4571,
    "preview": "from datetime import timedelta\nfrom functools import partial\nimport os\nimport torch\nimport torch.distributed as dist\nfro"
  },
  {
    "path": "long_video/utils/lmdb.py",
    "chars": 2045,
    "preview": "import numpy as np\n\n\ndef get_array_shape_from_lmdb(env, array_name):\n    with env.begin() as txn:\n        image_shape = "
  },
  {
    "path": "long_video/utils/loss.py",
    "chars": 2467,
    "preview": "from abc import ABC, abstractmethod\nimport torch\n\n\nclass DenoisingLoss(ABC):\n    @abstractmethod\n    def __call__(\n     "
  },
  {
    "path": "long_video/utils/misc.py",
    "chars": 1155,
    "preview": "import numpy as np\nimport random\nimport torch\n\n\ndef set_seed(seed: int, deterministic: bool = False):\n    \"\"\"\n    Helper"
  },
  {
    "path": "long_video/utils/scheduler.py",
    "chars": 7979,
    "preview": "from abc import abstractmethod, ABC\nimport torch\n\n\nclass SchedulerInterface(ABC):\n    \"\"\"\n    Base class for diffusion n"
  },
  {
    "path": "long_video/utils/wan_wrapper.py",
    "chars": 12465,
    "preview": "import types\nfrom typing import List, Optional\nimport torch\nfrom torch import nn\n\nfrom utils.scheduler import SchedulerI"
  },
  {
    "path": "long_video/wan/README.md",
    "chars": 92,
    "preview": "Code in this folder is modified from https://github.com/Wan-Video/Wan2.1\nApache-2.0 License "
  },
  {
    "path": "long_video/wan/__init__.py",
    "chars": 107,
    "preview": "from . import configs, distributed, modules\nfrom .image2video import WanI2V\nfrom .text2video import WanT2V\n"
  },
  {
    "path": "long_video/wan/configs/__init__.py",
    "chars": 1011,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .wan_t2v_14B import t2v_14B\nfrom .wan_t2v_"
  },
  {
    "path": "long_video/wan/configs/shared_config.py",
    "chars": 650,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n# -"
  },
  {
    "path": "long_video/wan/configs/wan_i2v_14B.py",
    "chars": 972,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
  },
  {
    "path": "long_video/wan/configs/wan_t2v_14B.py",
    "chars": 743,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
  },
  {
    "path": "long_video/wan/configs/wan_t2v_1_3B.py",
    "chars": 760,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
  },
  {
    "path": "long_video/wan/distributed/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "long_video/wan/distributed/fsdp.py",
    "chars": 1077,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom functools import partial\n\nimport torch\nfro"
  },
  {
    "path": "long_video/wan/distributed/xdit_context_parallel.py",
    "chars": 5899,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
  },
  {
    "path": "long_video/wan/image2video.py",
    "chars": 13203,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
  },
  {
    "path": "long_video/wan/modules/__init__.py",
    "chars": 365,
    "preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
  },
  {
    "path": "long_video/wan/modules/attention.py",
    "chars": 5641,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n    import flash_attn_interf"
  },
  {
    "path": "long_video/wan/modules/causal_model.py",
    "chars": 45180,
    "preview": "from wan.modules.attention import attention\nfrom wan.modules.model import (\n    WanRMSNorm,\n    rope_apply,\n    WanLayer"
  },
  {
    "path": "long_video/wan/modules/clip.py",
    "chars": 16835,
    "preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
  },
  {
    "path": "long_video/wan/modules/model.py",
    "chars": 30757,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn"
  },
  {
    "path": "long_video/wan/modules/t5.py",
    "chars": 16910,
    "preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
  },
  {
    "path": "long_video/wan/modules/tokenizers.py",
    "chars": 2431,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
  },
  {
    "path": "long_video/wan/modules/vae.py",
    "chars": 23735,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
  },
  {
    "path": "long_video/wan/modules/xlm_roberta.py",
    "chars": 4865,
    "preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
  },
  {
    "path": "long_video/wan/text2video.py",
    "chars": 10241,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
  },
  {
    "path": "long_video/wan/utils/__init__.py",
    "chars": 339,
    "preview": "from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,\n                         retrieve_timeste"
  },
  {
    "path": "long_video/wan/utils/fm_solvers.py",
    "chars": 40232,
    "preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
  },
  {
    "path": "long_video/wan/utils/fm_solvers_unipc.py",
    "chars": 32645,
    "preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
  },
  {
    "path": "long_video/wan/utils/prompt_extend.py",
    "chars": 30344,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
  },
  {
    "path": "long_video/wan/utils/qwen_vl_utils.py",
    "chars": 13044,
    "preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
  },
  {
    "path": "long_video/wan/utils/utils.py",
    "chars": 3239,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
  },
  {
    "path": "model/__init__.py",
    "chars": 351,
    "preview": "from .diffusion import CausalDiffusion\nfrom .causvid import CausVid\nfrom .dmd import DMD\nfrom .gan import GAN\nfrom .sid "
  },
  {
    "path": "model/base.py",
    "chars": 26975,
    "preview": "from typing import Tuple\nfrom einops import rearrange\nfrom torch import nn\nimport torch.distributed as dist\nimport torch"
  },
  {
    "path": "model/causvid.py",
    "chars": 17342,
    "preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\n\n\nclass CausVid("
  },
  {
    "path": "model/diffusion.py",
    "chars": 5814,
    "preview": "from typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_wrapper import WanDiffusionWrappe"
  },
  {
    "path": "model/dmd.py",
    "chars": 17367,
    "preview": "from pipeline import SelfForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\nimpo"
  },
  {
    "path": "model/gan.py",
    "chars": 14232,
    "preview": "import copy\nfrom pipeline import SelfForcingTrainingPipeline\nimport torch.nn.functional as F\nfrom typing import Tuple\nim"
  },
  {
    "path": "model/naive_consistency.py",
    "chars": 6240,
    "preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\nimport random\nfrom model.base import BaseModel\nfro"
  },
  {
    "path": "model/ode_regression.py",
    "chars": 5763,
    "preview": "import torch.nn.functional as F\nfrom typing import Tuple\nimport torch\n\nfrom model.base import BaseModel\nfrom utils.wan_w"
  },
  {
    "path": "model/sid.py",
    "chars": 12638,
    "preview": "from pipeline import SelfForcingTrainingPipeline\nfrom typing import Optional, Tuple\nimport torch\n\nfrom model.base import"
  },
  {
    "path": "pipeline/__init__.py",
    "chars": 759,
    "preview": "from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline\nfrom .bidirectional_inference imp"
  },
  {
    "path": "pipeline/bidirectional_diffusion_inference.py",
    "chars": 4146,
    "preview": "from tqdm import tqdm\nfrom typing import List\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMultistepSched"
  },
  {
    "path": "pipeline/bidirectional_inference.py",
    "chars": 3094,
    "preview": "from typing import List\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper\n\n"
  },
  {
    "path": "pipeline/bidirectional_training.py",
    "chars": 7801,
    "preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
  },
  {
    "path": "pipeline/causal_diffusion_inference.py",
    "chars": 32056,
    "preview": "from tqdm import tqdm\nfrom typing import List, Optional\nimport torch\n\nfrom wan.utils.fm_solvers import FlowDPMSolverMult"
  },
  {
    "path": "pipeline/causal_inference.py",
    "chars": 17437,
    "preview": "from typing import List, Optional\nimport time\nimport torch\n\nfrom utils.wan_wrapper import WanDiffusionWrapper, WanTextEn"
  },
  {
    "path": "pipeline/self_forcing_training.py",
    "chars": 16256,
    "preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
  },
  {
    "path": "pipeline/teacher_forcing_training.py",
    "chars": 10927,
    "preview": "from utils.wan_wrapper import WanDiffusionWrapper\nfrom utils.scheduler import SchedulerInterface\nfrom typing import List"
  },
  {
    "path": "prompts/demos.txt",
    "chars": 50302,
    "preview": "Across a snowfield under pink dawn, a skier in a crimson down jacket, fur-lined hood, and mirrored goggles dominates the"
  },
  {
    "path": "prompts/i2v/target_crop_info_26-15.json",
    "chars": 870,
    "preview": "[\n  {\n    \"file_name\": \"000001.png\",\n    \"caption\": \"A cinematic closeup and detailed portrait of a reindeer standing in"
  },
  {
    "path": "requirements.txt",
    "chars": 445,
    "preview": "torch>=2.4.0\ntorchvision>=0.19.0\nopencv-python>=4.9.0.80\ndiffusers==0.31.0\ntransformers>=4.49.0\ntokenizers>=0.20.3\naccel"
  },
  {
    "path": "setup.py",
    "chars": 131,
    "preview": "from setuptools import setup, find_packages\nsetup(\n    name=\"causal_forcing\",\n    version=\"0.0.2\",\n    packages=find_pac"
  },
  {
    "path": "train.py",
    "chars": 1768,
    "preview": "import argparse\nimport os\nfrom omegaconf import OmegaConf\nimport wandb\n\nfrom trainer import DiffusionTrainer, ODETrainer"
  },
  {
    "path": "trainer/__init__.py",
    "chars": 399,
    "preview": "from .diffusion import Trainer as DiffusionTrainer\nfrom .gan import Trainer as GANTrainer\nfrom .ode import Trainer as OD"
  },
  {
    "path": "trainer/diffusion.py",
    "chars": 11454,
    "preview": "import gc\nimport logging\n\nfrom model import CausalDiffusion\nfrom utils.dataset import cycle, LatentLMDBDataset\nfrom util"
  },
  {
    "path": "trainer/distillation.py",
    "chars": 15443,
    "preview": "import gc\nimport logging\nfrom utils.dataset import cycle\nfrom utils.dataset import TextDataset\nfrom utils.distributed im"
  },
  {
    "path": "trainer/gan.py",
    "chars": 20024,
    "preview": "import gc\nimport logging\n\nfrom utils.dataset import ShardingLMDBDataset, cycle\nfrom utils.distributed import EMA_FSDP, f"
  },
  {
    "path": "trainer/naive_cd.py",
    "chars": 11578,
    "preview": "import gc\nimport logging\nfrom utils.dataset import cycle\nfrom utils.dataset import LatentLMDBDataset\nfrom utils.distribu"
  },
  {
    "path": "trainer/ode.py",
    "chars": 9204,
    "preview": "import gc\nimport logging\nfrom utils.dataset import ODERegressionLMDBDataset, cycle\nfrom model import ODERegression\nfrom "
  },
  {
    "path": "utils/create_lmdb_iterative.py",
    "chars": 3693,
    "preview": "from tqdm import tqdm\nimport numpy as np\nimport argparse\nimport torch\nimport lmdb\nimport glob\nimport os\n\n\ndef store_arra"
  },
  {
    "path": "utils/dataset.py",
    "chars": 8497,
    "preview": "from utils.lmdb_ import get_array_shape_from_lmdb, retrieve_row_from_lmdb\nfrom torch.utils.data import Dataset\nimport nu"
  },
  {
    "path": "utils/distributed.py",
    "chars": 4994,
    "preview": "from datetime import timedelta\nfrom functools import partial\nimport os\nimport torch\nimport torch.distributed as dist\nfro"
  },
  {
    "path": "utils/lmdb_.py",
    "chars": 2045,
    "preview": "import numpy as np\n\n\ndef get_array_shape_from_lmdb(env, array_name):\n    with env.begin() as txn:\n        image_shape = "
  },
  {
    "path": "utils/loss.py",
    "chars": 2467,
    "preview": "from abc import ABC, abstractmethod\nimport torch\n\n\nclass DenoisingLoss(ABC):\n    @abstractmethod\n    def __call__(\n     "
  },
  {
    "path": "utils/merge_and_get_clean.py",
    "chars": 6069,
    "preview": "import os, shutil, lmdb, numpy as np\nfrom tqdm import tqdm\n\nBASE = \"dataset\"\nBATCH = 512\nMAP_MULT = 2.2\n\ndef read_shape("
  },
  {
    "path": "utils/merge_lmdb.py",
    "chars": 5973,
    "preview": "import os, shutil, lmdb, numpy as np\nfrom tqdm import tqdm\n\nBASE = \"dataset\"\nBATCH = 512\nMAP_MULT = 2.2\n\ndef read_shape("
  },
  {
    "path": "utils/misc.py",
    "chars": 1155,
    "preview": "import numpy as np\nimport random\nimport torch\n\n\ndef set_seed(seed: int, deterministic: bool = False):\n    \"\"\"\n    Helper"
  },
  {
    "path": "utils/ode_generation.py",
    "chars": 11719,
    "preview": "from typing import Dict, Iterable, Optional\n\nimport torch\n\n\ndef merge_cfg_prompt_embeds(\n    conditional_dict: dict,\n   "
  },
  {
    "path": "utils/scheduler.py",
    "chars": 7979,
    "preview": "from abc import abstractmethod, ABC\nimport torch\n\n\nclass SchedulerInterface(ABC):\n    \"\"\"\n    Base class for diffusion n"
  },
  {
    "path": "utils/wan_wrapper.py",
    "chars": 12591,
    "preview": "import types\nfrom typing import List, Optional\nimport torch\nfrom torch import nn\n\nfrom utils.scheduler import SchedulerI"
  },
  {
    "path": "wan/README.md",
    "chars": 92,
    "preview": "Code in this folder is modified from https://github.com/Wan-Video/Wan2.1\nApache-2.0 License "
  },
  {
    "path": "wan/__init__.py",
    "chars": 107,
    "preview": "from . import configs, distributed, modules\nfrom .image2video import WanI2V\nfrom .text2video import WanT2V\n"
  },
  {
    "path": "wan/configs/__init__.py",
    "chars": 1011,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom .wan_t2v_14B import t2v_14B\nfrom .wan_t2v_"
  },
  {
    "path": "wan/configs/shared_config.py",
    "chars": 650,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\n# -"
  },
  {
    "path": "wan/configs/wan_i2v_14B.py",
    "chars": 972,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nfrom easydict import EasyDict\n\nfro"
  },
  {
    "path": "wan/configs/wan_t2v_14B.py",
    "chars": 743,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
  },
  {
    "path": "wan/configs/wan_t2v_1_3B.py",
    "chars": 760,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom easydict import EasyDict\n\nfrom .shared_con"
  },
  {
    "path": "wan/distributed/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "wan/distributed/fsdp.py",
    "chars": 1077,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nfrom functools import partial\n\nimport torch\nfro"
  },
  {
    "path": "wan/distributed/xdit_context_parallel.py",
    "chars": 5899,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\nimport torch.cuda.amp as amp\nfrom "
  },
  {
    "path": "wan/image2video.py",
    "chars": 13203,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
  },
  {
    "path": "wan/modules/__init__.py",
    "chars": 365,
    "preview": "from .attention import flash_attention\nfrom .model import WanModel\nfrom .t5 import T5Decoder, T5Encoder, T5EncoderModel,"
  },
  {
    "path": "wan/modules/attention.py",
    "chars": 5641,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n    import flash_attn_interf"
  },
  {
    "path": "wan/modules/causal_model.py",
    "chars": 42290,
    "preview": "from wan.modules.attention import attention\nfrom wan.modules.model import (\n    WanRMSNorm,\n    rope_apply,\n    WanLayer"
  },
  {
    "path": "wan/modules/clip.py",
    "chars": 16835,
    "preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
  },
  {
    "path": "wan/modules/model.py",
    "chars": 30757,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\n\nimport torch\nimport torch.nn as nn"
  },
  {
    "path": "wan/modules/t5.py",
    "chars": 16910,
    "preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
  },
  {
    "path": "wan/modules/tokenizers.py",
    "chars": 2431,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
  },
  {
    "path": "wan/modules/vae.py",
    "chars": 23735,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.cuda."
  },
  {
    "path": "wan/modules/xlm_roberta.py",
    "chars": 4865,
    "preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
  },
  {
    "path": "wan/text2video.py",
    "chars": 10241,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport gc\nimport logging\nimport math\nimport os\n"
  },
  {
    "path": "wan/utils/__init__.py",
    "chars": 339,
    "preview": "from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,\n                         retrieve_timeste"
  },
  {
    "path": "wan/utils/fm_solvers.py",
    "chars": 40232,
    "preview": "# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep"
  },
  {
    "path": "wan/utils/fm_solvers_unipc.py",
    "chars": 32645,
    "preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
  },
  {
    "path": "wan/utils/prompt_extend.py",
    "chars": 30344,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport json\nimport math\nimport os\nimport random"
  },
  {
    "path": "wan/utils/qwen_vl_utils.py",
    "chars": 13044,
    "preview": "# Copied from https://github.com/kq-chen/qwen-vl-utils\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights re"
  },
  {
    "path": "wan/utils/utils.py",
    "chars": 3239,
    "preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport argparse\nimport binascii\nimport os\nimpor"
  }
]

About this extraction

This page contains the full source code of the thu-ml/Causal-Forcing GitHub repository, extracted and formatted as plain text. The extraction includes 154 files (1.4 MB), approximately 339.4k tokens, and a symbol index with 1104 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.