Repository: SkyworkAI/SkyReels-V2
Branch: main
Commit: 9351d1315220
Files: 36
Total size: 288.8 KB
Directory structure:
gitextract_4tjgy5ni/
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── README.md
├── generate_video.py
├── generate_video_df.py
├── requirements.txt
├── skycaptioner_v1/
│   ├── README.md
│   ├── examples/
│   │   ├── test.csv
│   │   └── test_result.csv
│   ├── infer_fusion_caption.sh
│   ├── infer_struct_caption.sh
│   ├── requirements.txt
│   └── scripts/
│       ├── gradio_fusion_caption.py
│       ├── gradio_struct_caption.py
│       ├── utils.py
│       ├── vllm_fusion_caption.py
│       └── vllm_struct_caption.py
└── skyreels_v2_infer/
    ├── __init__.py
    ├── distributed/
    │   ├── __init__.py
    │   └── xdit_context_parallel.py
    ├── modules/
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── clip.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── transformer.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    ├── pipelines/
    │   ├── __init__.py
    │   ├── diffusion_forcing_pipeline.py
    │   ├── image2video_pipeline.py
    │   ├── prompt_enhancer.py
    │   └── text2video_pipeline.py
    └── scheduler/
        ├── __init__.py
        └── fm_solvers_unipc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
checkpoint/*
checkpoint
results/*
.DS_Store
*.png
*.jpg
*.mp4
*.log*
*.json
scripts/transformer/*
compile_cache
scripts/.gradio/*
*.pkl
# *.csv
*.jsonl
out/*
model/
run.sh
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/asottile/reorder-python-imports.git
    rev: v3.8.3
    hooks:
      - id: reorder-python-imports
        name: Reorder Python imports
        types: [file, python]
  - repo: https://github.com/psf/black.git
    rev: 22.8.0
    hooks:
      - id: black
        additional_dependencies: ['click==8.0.4']
        args: [--line-length=120]
        types: [file, python]
  - repo: https://github.com/pre-commit/pre-commit-hooks.git
    rev: v4.3.0
    hooks:
      - id: check-byte-order-marker
        types: [file, python]
      - id: trailing-whitespace
        types: [file, python]
      - id: end-of-file-fixer
        types: [file, python]
================================================
FILE: LICENSE.txt
================================================
---
language:
- en
- zh
license: other
tasks:
- text-generation
---
<!-- markdownlint-disable first-line-h1 -->
<!-- markdownlint-disable html -->
# <span id="Terms">Terms and Conditions</span>
## Declaration
We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and regulatory filing. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
## License Agreement
Community use of the Skywork model must comply with the [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, no additional application is required, but you must carefully read and strictly abide by the terms and conditions of the [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
[Skywork Community License (Chinese)]: https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
[skywork-opensource@kunlun-inc.com]: mailto:skywork-opensource@kunlun-inc.com
================================================
FILE: README.md
================================================
<p align="center">
<img src="assets/logo2.png" alt="SkyReels Logo" width="50%">
</p>
<h1 align="center">SkyReels V2: Infinite-Length Film Generative Model</h1>
<p align="center">
📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9" target="_blank">Hugging Face</a> · 🤖 <a href="https://www.modelscope.cn/collections/SkyReels-V2-f665650130b144" target="_blank">ModelScope</a>
</p>
---
Welcome to the **SkyReels V2** repository! Here, you'll find the model weights and inference code for our infinite-length film generative models. To the best of our knowledge, it is the first open-source video generative model employing an **AutoRegressive Diffusion-Forcing architecture** that achieves **SOTA performance** among publicly available models.
## 🔥🔥🔥 News!!
* Jan 29, 2026: 🎉 We launched the API for the SkyReels-V3 models on the [apifree.ai](https://www.apifree.ai/explore).
* Jan 29, 2026: 🎉 We release the inference code and model weights of [SkyReels-V3](https://github.com/SkyworkAI/SkyReels-V3).
* Jun 1, 2025: 🎉 We published the technical report, [SkyReels-Audio: Omni Audio-Conditioned Talking Portraits in Video Diffusion Transformers](https://arxiv.org/pdf/2506.00830).
* May 16, 2025: 🔥 We release the inference code for [video extension](#ve) and [start/end frame control](#se) in diffusion forcing model.
* Apr 24, 2025: 🔥 We release the 720P models, [SkyReels-V2-DF-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P) and [SkyReels-V2-I2V-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P). The former facilitates infinite-length autoregressive video generation, and the latter focuses on Image2Video synthesis.
* Apr 21, 2025: 👋 We release the inference code and model weights of the [SkyReels-V2](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) series models and the video captioning model [SkyCaptioner-V1](https://huggingface.co/Skywork/SkyCaptioner-V1).
* Apr 3, 2025: 🔥 We also release [SkyReels-A2](https://github.com/SkyworkAI/SkyReels-A2). This is an open-source controllable video generation framework capable of assembling arbitrary visual elements.
* Feb 18, 2025: 🔥 We released [SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1). This is an open-source and effective framework for portrait image animation.
* Feb 18, 2025: 🔥 We released [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
## 🎥 Demos
<table>
<tr>
<td align="center">
<video src="https://github.com/user-attachments/assets/f6f9f9a7-5d5f-433c-9d73-d8d593b7ad25" width="100%"></video>
</td>
<td align="center">
<video src="https://github.com/user-attachments/assets/0eb13415-f4d9-4aaf-bcd3-3031851109b9" width="100%"></video>
</td>
<td align="center">
<video src="https://github.com/user-attachments/assets/dcd16603-5bf4-4786-8e4d-1ed23889d07a" width="100%"></video>
</td>
</tr>
</table>
The demos above showcase 30-second videos generated using our SkyReels-V2 Diffusion Forcing model.
## 📑 TODO List
- [x] <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a>
- [x] Checkpoints of the 14B and 1.3B Models Series
- [x] Single-GPU & Multi-GPU Inference Code
- [x] <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a>: A Video Captioning Model
- [x] Prompt Enhancer
- [x] Diffusers integration
- [ ] Checkpoints of the 5B Models Series
- [ ] Checkpoints of the Camera Director Models
- [ ] Checkpoints of the Step & Guidance Distill Model
## 🚀 Quickstart
#### Installation
```shell
# clone the repository.
git clone https://github.com/SkyworkAI/SkyReels-V2
cd SkyReels-V2
# Install dependencies. Test environment uses Python 3.10.12.
pip install -r requirements.txt
```
#### Model Download
You can download our models from Hugging Face or ModelScope:
<table>
<thead>
<tr>
<th>Type</th>
<th>Model Variant</th>
<th>Recommended Height/Width/Frame</th>
<th>Link</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="5">Diffusion Forcing</td>
<td>1.3B-540P</td>
<td>544 * 960 * 97f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-1.3B-540P">ModelScope</a></td>
</tr>
<tr>
<td>5B-540P</td>
<td>544 * 960 * 97f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>5B-720P</td>
<td>720 * 1280 * 121f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>14B-540P</td>
<td>544 * 960 * 97f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-14B-540P">ModelScope</a></td>
</tr>
<tr>
<td>14B-720P</td>
<td>720 * 1280 * 121f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-14B-720P">ModelScope</a></td>
</tr>
<tr>
<td rowspan="5">Text-to-Video</td>
<td>1.3B-540P</td>
<td>544 * 960 * 97f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>5B-540P</td>
<td>544 * 960 * 97f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>5B-720P</td>
<td>720 * 1280 * 121f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>14B-540P</td>
<td>544 * 960 * 97f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-T2V-14B-540P">ModelScope</a></td>
</tr>
<tr>
<td>14B-720P</td>
<td>720 * 1280 * 121f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-T2V-14B-720P">ModelScope</a></td>
</tr>
<tr>
<td rowspan="5">Image-to-Video</td>
<td>1.3B-540P</td>
<td>544 * 960 * 97f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-1.3B-540P">ModelScope</a></td>
</tr>
<tr>
<td>5B-540P</td>
<td>544 * 960 * 97f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>5B-720P</td>
<td>720 * 1280 * 121f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>14B-540P</td>
<td>544 * 960 * 97f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-14B-540P">ModelScope</a></td>
</tr>
<tr>
<td>14B-720P</td>
<td>720 * 1280 * 121f</td>
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-14B-720P">ModelScope</a></td>
</tr>
<tr>
<td rowspan="3">Camera Director</td>
<td>5B-540P</td>
<td>544 * 960 * 97f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>5B-720P</td>
<td>720 * 1280 * 121f</td>
<td>Coming Soon</td>
</tr>
<tr>
<td>14B-720P</td>
<td>720 * 1280 * 121f</td>
<td>Coming Soon</td>
</tr>
</tbody>
</table>
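To keep the recommended settings from the table consistent across scripts, one option is a small lookup keyed on the model ID. The sketch below is hypothetical: `recommended_settings`, `RESOLUTION_SETTINGS`, and `MODEL_ID_PATTERN` are illustrative names, and it assumes model IDs follow the `Skywork/SkyReels-V2-<type>-<size>-<resolution>` pattern used in the links above. The `frames` value corresponds to `--base_num_frames` for Diffusion Forcing models and to `--num_frames` for T2V/I2V models.

```python
import re

# Recommended height/width/frames per resolution, copied from the model table.
RESOLUTION_SETTINGS = {
    "540P": {"height": 544, "width": 960, "frames": 97},
    "720P": {"height": 720, "width": 1280, "frames": 121},
}

# Hypothetical pattern for IDs such as "Skywork/SkyReels-V2-DF-14B-540P".
MODEL_ID_PATTERN = re.compile(r"SkyReels-V2-(DF|T2V|I2V)-([\d.]+B)-(540P|720P)")


def recommended_settings(model_id: str) -> dict:
    """Return the table's recommended settings for a SkyReels-V2 model ID."""
    match = MODEL_ID_PATTERN.search(model_id)
    if match is None:
        raise ValueError(f"Unrecognized SkyReels-V2 model ID: {model_id!r}")
    model_type, model_size, resolution = match.groups()
    return {
        "model_type": model_type,
        "model_size": model_size,
        "resolution": resolution,
        **RESOLUTION_SETTINGS[resolution],
    }


if __name__ == "__main__":
    print(recommended_settings("Skywork/SkyReels-V2-DF-14B-720P"))
```

Because `re.search` is used, `-Diffusers` repository names such as `Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers` also match.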
After downloading, pass the model ID or local model path to `--model_id` in your generation commands.
#### Single GPU Inference
- **Diffusion Forcing for Long Video Generation**
The <a href="https://arxiv.org/abs/2407.01392">**Diffusion Forcing**</a> model can generate videos of infinite length. It supports both **text-to-video (T2V)** and **image-to-video (I2V)** tasks and can run inference in either synchronous or asynchronous mode. Below are two example scripts for long video generation. If you want to adjust inference parameters such as video duration or inference mode, read the Note below first.
Synchronous generation for a 10s video:
```shell
model_id=Skywork/SkyReels-V2-DF-14B-540P
# synchronous inference
python3 generate_video_df.py \
--model_id ${model_id} \
--resolution 540P \
--ar_step 0 \
--base_num_frames 97 \
--num_frames 257 \
--overlap_history 17 \
--prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
--addnoise_condition 20 \
--offload \
--teacache \
--use_ret_steps \
--teacache_thresh 0.3
```
Asynchronous generation for a 30s video:
```shell
model_id=Skywork/SkyReels-V2-DF-14B-540P
# asynchronous inference
python3 generate_video_df.py \
--model_id ${model_id} \
--resolution 540P \
--ar_step 5 \
--causal_block_size 5 \
--base_num_frames 97 \
--num_frames 737 \
--overlap_history 17 \
--prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
--addnoise_condition 20 \
--offload
```
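To illustrate why asynchronous mode (`--ar_step 5`) needs more denoising iterations than synchronous mode (`--ar_step 0`), here is a simplified sketch of a staggered schedule. It is an assumption-based illustration, not the pipeline's actual scheduler: it assumes the latent frames of one window are split into blocks of `causal_block_size`, that block `b` starts `ar_step` iterations after block `b - 1`, and that every block needs the same number of denoising steps. `staggered_schedule` is a hypothetical name.

```python
def staggered_schedule(num_steps: int, num_blocks: int, ar_step: int) -> list[list[int]]:
    """For each denoising iteration, list how many steps each block has completed.

    Simplified illustration: block b lags block b - 1 by ar_step iterations,
    and ar_step=0 reduces to denoising all blocks together (synchronous mode).
    """
    total_iterations = num_steps + ar_step * (num_blocks - 1)
    schedule = []
    for it in range(1, total_iterations + 1):
        # Progress of block b, clamped to the range [0, num_steps]
        row = [min(max(it - ar_step * b, 0), num_steps) for b in range(num_blocks)]
        schedule.append(row)
    return schedule


if __name__ == "__main__":
    # base_num_frames=97 -> (97 - 1) // 4 + 1 = 25 latent frames -> 5 blocks of 5
    num_blocks = 25 // 5
    sync = staggered_schedule(num_steps=30, num_blocks=num_blocks, ar_step=0)
    async_ = staggered_schedule(num_steps=30, num_blocks=num_blocks, ar_step=5)
    print(len(sync), len(async_))  # 30 50
    print(async_[0])  # [1, 0, 0, 0, 0]
```

Under these assumptions, 30 steps over 5 blocks take 50 iterations per window instead of 30, in line with the Note below that asynchronous inference is slower.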
Text-to-video with `diffusers`:
```py
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-14B-540P-Diffusers",
vae=vae,
torch_dtype=torch.bfloat16
)
flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline = pipeline.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
output = pipeline(
prompt=prompt,
num_inference_steps=30,
height=544, # 720 for 720P
width=960, # 1280 for 720P
num_frames=97,
base_num_frames=97, # 121 for 720P
ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
causal_block_size=5, # Number of frames in each block for asynchronous processing
overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
addnoise_condition=20, # Improves consistency in long video generation
).frames[0]
export_to_video(output, "T2V.mp4", fps=24, quality=8)
```
Image-to-video with `diffusers`:
```py
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
def aspect_ratio_resize(image, pipeline, max_area=720 * 1280):
    # Fit the image into max_area while keeping its aspect ratio, snapping
    # height/width to multiples of the VAE stride times the patch size
    aspect_ratio = image.height / image.width
    mod_value = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    return image, height, width
def center_crop_resize(image, height, width):
    # Scale the image so that it covers the first frame's dimensions
    resize_ratio = max(width / image.width, height / image.height)
    new_width = round(image.width * resize_ratio)
    new_height = round(image.height * resize_ratio)
    image = image.resize((new_width, new_height))
    # Crop the center region; torchvision expects the size as [height, width]
    image = TF.center_crop(image, [height, width])
    return image, height, width
first_frame, height, width = aspect_ratio_resize(first_frame, pipeline)
if last_frame.size != first_frame.size:
    last_frame, _, _ = center_crop_resize(last_frame, height, width)
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
```
> **Note**:
> - If you want to run the **image-to-video (I2V)** task, add `--image ${image_path}` to your command. It is also better to use a **text-to-video (T2V)**-style prompt that includes some description of the first-frame image.
> - For long video generation, you only need to change `--num_frames`, e.g., `--num_frames 257` for a 10s video, `--num_frames 377` for 15s, `--num_frames 737` for 30s, and `--num_frames 1457` for 60s. These values do not correspond exactly to the nominal durations; instead, they are aligned with certain training parameters, which may yield better results. When you use asynchronous inference with causal_block_size > 1, set `--num_frames` carefully.
> - Use `--ar_step 5` to enable asynchronous inference. For asynchronous inference, `--causal_block_size 5` is recommended; it should not be set for synchronous generation. REMEMBER that the number of frame latents fed into the model in every iteration MUST be divisible by causal_block_size. This includes the base frame latent count (e.g., (97-1)//4+1=25 for base_num_frames=97) and the latent count of the last iteration (e.g., (237-97-(97-17)*1+17-1)//4+1=20 for base_num_frames=97, num_frames=237, overlap_history=17). If you find these values too hard to calculate, just use our recommended settings above :). Asynchronous inference takes more steps to diffuse the whole sequence, so it is SLOWER than synchronous mode. In our experiments, asynchronous inference may improve instruction following and visual consistency.
> - To reduce peak VRAM, lower `--base_num_frames`, e.g., to 77 or 57, while keeping the same `--num_frames` you want to generate. This may slightly reduce video quality, so it should not be set too small.
> - `--addnoise_condition` helps smooth long video generation by adding some noise to the clean condition. Too much noise can also cause inconsistency. 20 is the recommended value; you may try larger values, but we recommend not exceeding 50.
> - Generating a 540P video using the 1.3B model requires approximately 14.7GB peak VRAM, while the same resolution video using the 14B model demands around 51.2GB peak VRAM.
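The divisibility rule for asynchronous inference can be checked before launching a run. The sketch below reproduces the arithmetic from the note: the first window uses `base_num_frames`, each later full window adds `base_num_frames - overlap_history` new frames, and the last window holds the overlapped history plus the remaining frames. The helper names are hypothetical, and the pipeline's internal bookkeeping may differ in edge cases.

```python
def latent_frames(pixel_frames: int) -> int:
    # The VAE compresses time by 4x while keeping the first frame
    return (pixel_frames - 1) // 4 + 1


def window_latent_counts(num_frames: int, base_num_frames: int, overlap_history: int) -> list[int]:
    """Latent frames fed to the model in each iteration, per the note's arithmetic."""
    if num_frames <= base_num_frames:
        return [latent_frames(num_frames)]
    counts = [latent_frames(base_num_frames)]
    stride = base_num_frames - overlap_history  # new frames added by each full window
    remaining = num_frames - base_num_frames
    while remaining > stride:
        counts.append(latent_frames(base_num_frames))
        remaining -= stride
    # Last window: the overlapped history plus the frames still missing
    counts.append(latent_frames(remaining + overlap_history))
    return counts


def fits_causal_blocks(num_frames: int, base_num_frames: int, overlap_history: int, causal_block_size: int) -> bool:
    counts = window_latent_counts(num_frames, base_num_frames, overlap_history)
    return all(count % causal_block_size == 0 for count in counts)


if __name__ == "__main__":
    print(window_latent_counts(237, 97, 17))  # [25, 25, 20]
    print(fits_causal_blocks(737, 97, 17, 5))  # True
    print(fits_causal_blocks(250, 97, 17, 5))  # False (last window has 23 latents)
```

The same helper shows the effect of lowering `--base_num_frames` to save VRAM: the video is split into more, smaller windows, each of which still has to satisfy the divisibility rule.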
- **<span id="ve">Video Extension</span>**
```shell
model_id=Skywork/SkyReels-V2-DF-14B-540P
# video extension
python3 generate_video_df.py \
--model_id ${model_id} \
--resolution 540P \
--ar_step 0 \
--base_num_frames 97 \
--num_frames 120 \
--overlap_history 17 \
--prompt "${prompt}" \
--addnoise_condition 20 \
--offload \
--use_ret_steps \
--teacache \
--teacache_thresh 0.3 \
--video_path ${video_path}
```
> **Note**:
> - When performing video extension, you need to pass the `--video_path ${video_path}` parameter to specify the video to be extended.
- **<span id="se">Start/End Frame Control</span>**
```shell
model_id=Skywork/SkyReels-V2-DF-14B-540P
# start/end frame control
python3 generate_video_df.py \
--model_id ${model_id} \
--resolution 540P \
--ar_step 0 \
--base_num_frames 97 \
--num_frames 97 \
--overlap_history 17 \
--prompt "${prompt}" \
--addnoise_condition 20 \
--offload \
--use_ret_steps \
--teacache \
--teacache_thresh 0.3 \
--image ${image} \
--end_image ${end_image}
```
> **Note**:
> - When controlling the start and end frames, you need to pass the `--image ${image}` parameter to control the generation of the start frame and the `--end_image ${end_image}` parameter to control the generation of the end frame.
Video extension with `diffusers`:
```py
import torch
from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_video
model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
video = load_video("input_video.mp4")
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0,
num_inference_steps=30, num_frames=257, base_num_frames=97,  # add ar_step=5, causal_block_size=5 for asynchronous mode
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
# Total frames will be the number of frames of given video + 257
```
- **Text To Video & Image To Video**
```shell
# run Text-to-Video Generation
model_id=Skywork/SkyReels-V2-T2V-14B-540P
python3 generate_video.py \
--model_id ${model_id} \
--resolution 540P \
--num_frames 97 \
--guidance_scale 6.0 \
--shift 8.0 \
--fps 24 \
--prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
--offload \
--teacache \
--use_ret_steps \
--teacache_thresh 0.3
```
> **Note**:
> - When using an **image-to-video (I2V)** model, you must provide an input image using the `--image ${image_path}` parameter. `--guidance_scale 5.0` and `--shift 3.0` are recommended for I2V models.
> - Generating a 540P video using the 1.3B model requires approximately 14.7GB peak VRAM, while the same resolution video using the 14B model demands around 43.4GB peak VRAM.
T2V models with `diffusers`:
```py
import torch
from diffusers import (
SkyReelsV2Pipeline,
UniPCMultistepScheduler,
AutoencoderKLWan,
)
from diffusers.utils import export_to_video
# Load the pipeline
# Available models:
# - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
vae = AutoencoderKLWan.from_pretrained(
"Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
subfolder="vae",
torch_dtype=torch.float32,
)
pipe = SkyReelsV2Pipeline.from_pretrained(
"Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
vae=vae,
torch_dtype=torch.bfloat16,
)
flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe = pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
output = pipe(
prompt=prompt,
num_inference_steps=50,
height=544,  # 720 for 720P models
width=960,  # 1280 for 720P models
guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
num_frames=97,
).frames[0]
export_to_video(output, "video.mp4", fps=24, quality=8)
```
I2V models with `diffusers`:
```py
import torch
from diffusers import (
SkyReelsV2ImageToVideoPipeline,
UniPCMultistepScheduler,
AutoencoderKLWan,
)
from diffusers.utils import export_to_video
from PIL import Image
# Load the pipeline
# Available models:
# - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
vae = AutoencoderKLWan.from_pretrained(
"Skywork/SkyReels-V2-I2V-14B-720P-Diffusers",
subfolder="vae",
torch_dtype=torch.float32,
)
pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained(
"Skywork/SkyReels-V2-I2V-14B-720P-Diffusers",
vae=vae,
torch_dtype=torch.bfloat16,
)
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe = pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
image = Image.open("path/to/image.png")
output = pipe(
image=image,
prompt=prompt,
num_inference_steps=50,
height=544,  # 720 for 720P models
width=960,  # 1280 for 720P models
guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V
num_frames=97,
).frames[0]
export_to_video(output, "video.mp4", fps=24, quality=8)
```
- **Prompt Enhancer**
The prompt enhancer is implemented based on <a href="https://huggingface.co/Qwen/Qwen2.5-32B-Instruct">Qwen2.5-32B-Instruct</a> and is enabled via the `--prompt_enhancer` parameter. It works best for short prompts; for long prompts, it may generate an excessively lengthy prompt that can lead to over-saturation in the generated video. Note that peak GPU memory is 64 GB or more when you use `--prompt_enhancer`. If you want to obtain the enhanced prompt separately, you can run the prompt_enhancer script on its own for testing, as follows:
```shell
cd skyreels_v2_infer/pipelines
python3 prompt_enhancer.py --prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface."
```
> **Note**:
> - `--prompt_enhancer` is not allowed when using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter.
**Advanced Configuration Options**
Below are the key parameters you can customize for video generation:
| Parameter | Recommended Value | Description |
|:----------------------:|:---------:|:-----------------------------------------:|
| --prompt | | Text description for generating your video |
| --image | | Path to input image for image-to-video generation |
| --resolution | 540P or 720P | Output video resolution (select based on model type) |
| --num_frames | 97 or 121 | Total frames to generate (**97 for 540P models**, **121 for 720P models**) |
| --inference_steps | 50 | Number of denoising steps |
| --fps | 24 | Frames per second in the output video |
| --shift | 8.0 or 5.0 | Flow matching scheduler parameter (**8.0 for T2V**, **5.0 for I2V**) |
| --guidance_scale | 6.0 or 5.0 | Controls text adherence strength (**6.0 for T2V**, **5.0 for I2V**) |
| --seed | | Fixed seed for reproducible results (omit for random generation) |
| --offload | True | Offloads model components to CPU to reduce VRAM usage (recommended) |
| --use_usp | True | Enables multi-GPU acceleration with xDiT USP |
| --outdir | ./video_out | Directory where generated videos will be saved |
| --prompt_enhancer | True | Expands the prompt into a more detailed description |
| --teacache | False | Enables TeaCache for faster inference |
| --teacache_thresh | 0.2 | Higher values give more speedup but lower quality |
| --use_ret_steps | False | Enables retention steps for TeaCache |
**Diffusion Forcing Additional Parameters**
| Parameter | Recommended Value | Description |
|:----------------------:|:---------:|:-----------------------------------------:|
| --ar_step | 0 | Controls asynchronous inference (0 for synchronous mode) |
| --base_num_frames | 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) |
| --overlap_history | 17 | Number of frames to overlap for smooth transitions in long videos |
| --addnoise_condition | 20 | Improves consistency in long video generation |
| --causal_block_size | 5 | Recommended when using asynchronous inference (--ar_step > 0) |
| --video_path | | Path to input video for video extension |
| --end_image | | Path to input image for end frame control |
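The `--addnoise_condition` option smooths long video generation by adding some noise to the clean condition frames. A minimal sketch of that idea follows, assuming a linear blend in which the option value divided by 1000 gives the noise strength. This scaling is an assumption, not taken from the repository, and `noise_condition` is a hypothetical helper operating on a NumPy array of latents.

```python
import numpy as np


def noise_condition(clean_latents: np.ndarray, addnoise_condition: int, rng: np.random.Generator) -> np.ndarray:
    """Blend Gaussian noise into clean conditioning latents.

    Assumption: strength = addnoise_condition / 1000, so the recommended value
    of 20 mixes in 2% noise; the real pipeline's formula may differ.
    """
    strength = addnoise_condition / 1000.0
    noise = rng.standard_normal(clean_latents.shape)
    return (1.0 - strength) * clean_latents + strength * noise


if __name__ == "__main__":
    latents = np.ones((5, 4, 8, 8))  # toy stand-in for conditioning latents
    for value in (0, 20, 50):
        noisy = noise_condition(latents, value, np.random.default_rng(0))
        print(value, float(np.abs(noisy - latents).mean()))
```

Larger values perturb the condition more, which is consistent with the advice to stay at or below 50 to avoid inconsistency.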
#### Multi-GPU inference using xDiT USP
We use [xDiT](https://github.com/xdit-project/xDiT) USP to accelerate inference. For example, to generate a video with 2 GPUs, you can use the following command:
- **Diffusion Forcing**
```shell
model_id=Skywork/SkyReels-V2-DF-14B-540P
# diffusion forcing synchronous inference
torchrun --nproc_per_node=2 generate_video_df.py \
--model_id ${model_id} \
--resolution 540P \
--ar_step 0 \
--base_num_frames 97 \
--num_frames 257 \
--overlap_history 17 \
--prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
--addnoise_condition 20 \
--use_usp \
--offload \
--seed 42
```
- **Text To Video & Image To Video**
```shell
# run Text-to-Video Generation
model_id=Skywork/SkyReels-V2-T2V-14B-540P
torchrun --nproc_per_node=2 generate_video.py \
--model_id ${model_id} \
--resolution 540P \
--num_frames 97 \
--guidance_scale 6.0 \
--shift 8.0 \
--fps 24 \
--offload \
--prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
--use_usp \
--seed 42
```
> **Note**:
> - When using an **image-to-video (I2V)** model, you must provide an input image using the `--image ${image_path}` parameter. `--guidance_scale 5.0` and `--shift 3.0` are recommended for I2V models.
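The single-GPU and xDiT USP commands above differ mainly in the launcher (`python3` vs. `torchrun --nproc_per_node=N`) and the `--use_usp` flag. If you drive runs from Python, building the argument list directly also keeps a multi-word prompt as a single argument without relying on shell quoting. `build_df_command` is a hypothetical helper that only emits flags appearing in the examples above.

```python
import shlex


def build_df_command(model_id: str, prompt: str, num_gpus: int = 1, **options) -> list[str]:
    """Assemble a generate_video_df.py command as an argument list.

    True options become bare flags (e.g. offload=True -> --offload);
    False/None options are skipped; anything else becomes "--name value".
    """
    if num_gpus > 1:
        cmd = ["torchrun", f"--nproc_per_node={num_gpus}", "generate_video_df.py", "--use_usp"]
    else:
        cmd = ["python3", "generate_video_df.py"]
    cmd += ["--model_id", model_id, "--prompt", prompt]
    for name, value in options.items():
        if value is True:
            cmd.append(f"--{name}")
        elif value is not False and value is not None:
            cmd += [f"--{name}", str(value)]
    return cmd


if __name__ == "__main__":
    cmd = build_df_command(
        "Skywork/SkyReels-V2-DF-14B-540P",
        "A graceful white swan swimming in a serene lake at dawn.",
        num_gpus=2,
        resolution="540P",
        ar_step=0,
        base_num_frames=97,
        num_frames=257,
        overlap_history=17,
        addnoise_condition=20,
        offload=True,
        seed=42,
    )
    # Print a correctly quoted shell command; from the repository root,
    # subprocess.run(cmd, check=True) would launch it directly.
    print(shlex.join(cmd))
```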
## Contents
- [Abstract](#abstract)
- [Methodology of SkyReels-V2](#methodology-of-skyreels-v2)
- [Key Contributions of SkyReels-V2](#key-contributions-of-skyreels-v2)
- [Video Captioner](#video-captioner)
- [Reinforcement Learning](#reinforcement-learning)
- [Diffusion Forcing](#diffusion-forcing)
- [High-Quality Supervised Fine-Tuning(SFT)](#high-quality-supervised-fine-tuning-sft)
- [Performance](#performance)
- [Acknowledgements](#acknowledgements)
- [Citation](#citation)
---
## Abstract
Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation.
To address these limitations, we introduce SkyReels-V2, the world's first infinite-length film generative model using a Diffusion Forcing framework. Our approach synergizes Multi-modal Large Language Models (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing techniques to achieve comprehensive optimization. Beyond its technical innovations, SkyReels-V2 enables multiple practical applications, including Story Generation, Image-to-Video Synthesis, Camera Director functionality, and multi-subject consistent video generation through our <a href="https://github.com/SkyworkAI/SkyReels-A2">SkyReels-A2</a> system.
## Methodology of SkyReels-V2
The SkyReels-V2 methodology consists of several interconnected components. It starts with a comprehensive data processing pipeline that prepares training data of varying quality. At its core is the Video Captioner architecture, which provides detailed annotations for video content. The system employs a multi-task pretraining strategy to build fundamental video generation capabilities. Post-training optimization includes Reinforcement Learning to enhance motion quality, Diffusion Forcing Training to generate extended videos, and High-quality Supervised Fine-Tuning (SFT) stages for visual refinement. The model runs on optimized computational infrastructure for efficient training and inference. SkyReels-V2 supports multiple applications, including Story Generation, Image-to-Video Synthesis, Camera Director functionality, and Elements-to-Video Generation.
<p align="center">
<img src="assets/main_pipeline.jpg" alt="mainpipeline" width="100%">
</p>
## Key Contributions of SkyReels-V2
#### Video Captioner
<a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> serves as our video captioning model for data annotation. It is trained on captions produced by the base model <a href="https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct">Qwen2.5-VL-72B-Instruct</a> and by sub-expert captioners over a balanced video dataset: a carefully curated set of approximately 2 million videos selected to ensure conceptual balance and annotation quality. Built upon the <a href="https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct">Qwen2.5-VL-7B-Instruct</a> foundation model, <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> is fine-tuned to improve performance on domain-specific video captioning tasks. To compare its performance with state-of-the-art (SOTA) models, we conducted a manual assessment of accuracy across different captioning fields on a test set of 1,000 samples. <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> achieves the highest average accuracy among the compared models, with especially large gains in the shot-related fields.
<p align="center">
<table align="center">
<thead>
<tr>
<th>model</th>
<th><a href="https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct">Qwen2.5-VL-7B-Ins.</a></th>
<th><a href="https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct">Qwen2.5-VL-72B-Ins.</a></th>
<th><a href="https://huggingface.co/omni-research/Tarsier2-Recap-7b">Tarsier2-Recap-7b</a></th>
<th><a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a></th>
</tr>
</thead>
<tbody>
<tr>
<td>Avg accuracy</td>
<td>51.4%</td>
<td>58.7%</td>
<td>49.4%</td>
<td><strong>76.3%</strong></td>
</tr>
<tr>
<td>shot type</td>
<td>76.8%</td>
<td>82.5%</td>
<td>60.2%</td>
<td><strong>93.7%</strong></td>
</tr>
<tr>
<td>shot angle</td>
<td>60.0%</td>
<td>73.7%</td>
<td>52.4%</td>
<td><strong>89.8%</strong></td>
</tr>
<tr>
<td>shot position</td>
<td>28.4%</td>
<td>32.7%</td>
<td>23.6%</td>
<td><strong>83.1%</strong></td>
</tr>
<tr>
<td>camera motion</td>
<td>62.0%</td>
<td>61.2%</td>
<td>45.3%</td>
<td><strong>85.3%</strong></td>
</tr>
<tr>
<td>expression</td>
<td>43.6%</td>
<td>51.5%</td>
<td>54.3%</td>
<td><strong>68.8%</strong></td>
</tr>
<tr>
<td colspan="5" style="text-align: center; border-bottom: 1px solid #ddd; padding: 8px;"></td>
</tr>
<tr>
<td>TYPES_type</td>
<td>43.5%</td>
<td>49.7%</td>
<td>47.6%</td>
<td><strong>82.5%</strong></td>
</tr>
<tr>
<td>TYPES_sub_type</td>
<td>38.9%</td>
<td>44.9%</td>
<td>45.9%</td>
<td><strong>75.4%</strong></td>
</tr>
<tr>
<td>appearance</td>
<td>40.9%</td>
<td>52.0%</td>
<td>45.6%</td>
<td><strong>59.3%</strong></td>
</tr>
<tr>
<td>action</td>
<td>32.4%</td>
<td>52.0%</td>
<td><strong>69.8%</strong></td>
<td>68.8%</td>
</tr>
<tr>
<td>position</td>
<td>35.4%</td>
<td>48.6%</td>
<td>45.5%</td>
<td><strong>57.5%</strong></td>
</tr>
<tr>
<td>is_main_subject</td>
<td>58.5%</td>
<td>68.7%</td>
<td>69.7%</td>
<td><strong>80.9%</strong></td>
</tr>
<tr>
<td>environment</td>
<td>70.4%</td>
<td><strong>72.7%</strong></td>
<td>61.4%</td>
<td>70.5%</td>
</tr>
<tr>
<td>lighting</td>
<td>77.1%</td>
<td><strong>80.0%</strong></td>
<td>21.2%</td>
<td>76.5%</td>
</tr>
</tbody>
</table>
</p>
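The sketch below recomputes the summary row from the table, assuming `Avg accuracy` is the unweighted mean of the 13 per-field accuracies. The text does not define the average; this assumption reproduces the listed values for SkyCaptioner-V1 (76.3%), Qwen2.5-VL-7B-Ins. (51.4%), and Tarsier2-Recap-7b (49.4%), while the Qwen2.5-VL-72B-Ins. column comes out at 59.2% rather than the listed 58.7%. It also reports SkyCaptioner-V1's margin over the best baseline in each field. The names `FIELDS`, `average_accuracy`, and `margin_over_best_baseline` are illustrative.

```python
from statistics import mean

# Per-field accuracies (%) copied from the table above.
# Column order: Qwen2.5-VL-7B-Ins., Qwen2.5-VL-72B-Ins., Tarsier2-Recap-7b, SkyCaptioner-V1
FIELDS = {
    "shot type": (76.8, 82.5, 60.2, 93.7),
    "shot angle": (60.0, 73.7, 52.4, 89.8),
    "shot position": (28.4, 32.7, 23.6, 83.1),
    "camera motion": (62.0, 61.2, 45.3, 85.3),
    "expression": (43.6, 51.5, 54.3, 68.8),
    "TYPES_type": (43.5, 49.7, 47.6, 82.5),
    "TYPES_sub_type": (38.9, 44.9, 45.9, 75.4),
    "appearance": (40.9, 52.0, 45.6, 59.3),
    "action": (32.4, 52.0, 69.8, 68.8),
    "position": (35.4, 48.6, 45.5, 57.5),
    "is_main_subject": (58.5, 68.7, 69.7, 80.9),
    "environment": (70.4, 72.7, 61.4, 70.5),
    "lighting": (77.1, 80.0, 21.2, 76.5),
}


def average_accuracy(column: int) -> float:
    """Unweighted mean over all fields for one model column (assumed definition)."""
    return mean(scores[column] for scores in FIELDS.values())


def margin_over_best_baseline(field: str) -> float:
    """SkyCaptioner-V1 minus the best of the three baselines, in percentage points."""
    *baselines, ours = FIELDS[field]
    return round(ours - max(baselines), 1)


if __name__ == "__main__":
    print(f"SkyCaptioner-V1 average: {average_accuracy(3):.1f}%")
    for field in FIELDS:
        print(f"{field:>16}: {margin_over_best_baseline(field):+.1f} pts")
```

Negative margins (for example `action`) mark the fields where a baseline still leads.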
#### Reinforcement Learning
Inspired by the success of reinforcement learning in large language models (LLMs), we propose enhancing the performance of the generative model through Reinforcement Learning. Specifically, we focus on motion quality, because we find that the main drawbacks of our generative model are:
- it does not handle large, deformable motions well;
- the generated videos may violate physical laws.

To avoid degrading other metrics, such as text alignment and video quality, we ensure that each preference data pair has comparable text alignment and video quality, so that only the motion quality varies. This requirement makes preference annotations harder to obtain, because human annotation is inherently costly. To address this challenge, we propose a semi-automatic pipeline that combines automatically generated motion pairs with human annotation results. This hybrid approach both increases the data scale and improves alignment with human preferences through curated quality control. Using this enhanced dataset, we first train a specialized reward model to capture generic motion-quality differences between paired samples. The learned reward function then guides sample selection for Direct Preference Optimization (DPO), improving the motion quality of the generative model.
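The text does not specify the reward-model objective or the rule used to select DPO pairs. As an illustration only, the sketch below uses the Bradley-Terry pairwise loss, a common choice for learning a reward from preference pairs, and a simple reward-guided selection rule: for each prompt, keep the highest- and lowest-scoring candidates as the (chosen, rejected) pair when their reward gap exceeds a threshold. The function names and the `min_gap` parameter are hypothetical and are not taken from SkyReels-V2.

```python
import math
from typing import Dict, Optional, Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bradley_terry_loss(reward_chosen: float, reward_rejected: float) -> float:
    """Pairwise reward-model loss: -log(sigmoid(r_chosen - r_rejected))."""
    return softplus(-(reward_chosen - reward_rejected))


def select_dpo_pair(
    rewards: Dict[str, float], min_gap: float = 0.5
) -> Optional[Tuple[str, str]]:
    """Return (chosen, rejected) sample ids for one prompt, or None if ambiguous.

    `rewards` maps candidate video ids (all generated from the same prompt)
    to motion-quality scores from a trained reward model.
    """
    if len(rewards) < 2:
        return None
    ranked = sorted(rewards, key=rewards.get, reverse=True)
    chosen, rejected = ranked[0], ranked[-1]
    if rewards[chosen] - rewards[rejected] < min_gap:
        return None  # gap too small to trust; skip this prompt
    return chosen, rejected


if __name__ == "__main__":
    print(bradley_terry_loss(2.0, 0.5))
    print(select_dpo_pair({"clip_a": 0.9, "clip_b": 0.1, "clip_c": 0.5}))
```

The gap threshold discards prompts where the reward model cannot clearly separate the candidates; keeping text alignment and visual quality comparable within a pair, as the paragraph requires, would need additional filters not shown here.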
#### Diffusion Forcing
We introduce the Diffusion Forcing Transformer to unlock our model’s ability to generate long videos. Diffusion Forcing is a training and sampling strategy in which each token is assigned an independent noise level, so tokens can be denoised according to arbitrary per-token schedules. Conceptually, this works as a form of partial masking: a token with zero noise is fully unmasked, while a fully noised token is completely masked. Diffusion Forcing trains the model to "unmask" any combination of variably noised tokens, using the cleaner tokens as conditioning information to guide the recovery of the noisier ones. Building on this, our Diffusion Forcing Transformer can extend a video indefinitely by conditioning each new segment on the last frames of the previous one. Note that synchronous full-sequence diffusion is a special case of Diffusion Forcing in which all tokens share the same noise level; this relationship allows us to fine-tune the Diffusion Forcing Transformer from a full-sequence diffusion model.
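The NumPy sketch below illustrates two ideas from the paragraph under stated assumptions. First, each latent frame block gets its own noise level `t` in [0, 1] and is noised with a flow-matching style interpolation `x_t = (1 - t) * x_0 + t * noise`; this formula is an assumption, since the text does not give one. Second, a staggered schedule makes block `b` start denoising `ar_step` iterations after block `b - 1`, so earlier, cleaner blocks can condition later, noisier ones. With `ar_step = 0` all blocks share one noise level, which is the synchronous full-sequence special case. The iteration count `num_steps + (num_blocks - 1) * ar_step` matches the step count that `generate_video_df.py` passes to TeaCache, but the schedule itself is illustrative rather than the pipeline's exact implementation.

```python
import numpy as np


def add_per_token_noise(x0: np.ndarray, noise: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Noise each frame block of x0 (shape [num_blocks, ...]) with its own level t[b]."""
    t = t.reshape(-1, *([1] * (x0.ndim - 1)))  # one noise level per block, broadcast over the rest
    return (1.0 - t) * x0 + t * noise


def staggered_schedule(num_blocks: int, num_steps: int, ar_step: int) -> np.ndarray:
    """Noise level of every block at every iteration, shape [num_iters + 1, num_blocks].

    Row 0 is pure noise (1.0) and the last row is fully denoised (0.0).
    Block b lags block b - 1 by `ar_step` denoising iterations.
    """
    num_iters = num_steps + (num_blocks - 1) * ar_step
    levels = np.empty((num_iters + 1, num_blocks))
    for it in range(num_iters + 1):
        for b in range(num_blocks):
            steps_done = min(max(it - b * ar_step, 0), num_steps)
            levels[it, b] = 1.0 - steps_done / num_steps
    return levels


if __name__ == "__main__":
    sched = staggered_schedule(num_blocks=4, num_steps=4, ar_step=2)
    print(sched)  # earlier blocks are never noisier than later ones
    x0 = np.zeros((4, 2, 2))
    xt = add_per_token_noise(x0, np.ones_like(x0), sched[3])
    print(xt[:, 0, 0])  # with x0 = 0 and noise = 1, this equals the per-block levels at iteration 3
```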
#### High-Quality Supervised Fine-Tuning (SFT)
We implement two sequential high-quality supervised fine-tuning (SFT) stages, at 540p and 720p resolution respectively. The first SFT stage is conducted immediately after pretraining and before the reinforcement learning (RL) stage. It serves as a concept-balancing stage: it builds on the foundation model's pretraining, which used only 24 fps video data, while removing the FPS embedding components to streamline the architecture. Trained on high-quality, concept-balanced samples, this phase provides a well-initialized set of parameters for the subsequent training stages. After the diffusion forcing stage, we run a second, high-resolution SFT at 720p, using the same loss formulation and higher-quality concept-balanced datasets obtained through manual filtering. This final refinement phase focuses on increasing resolution to further enhance overall video quality.
## Performance
To comprehensively evaluate our proposed method, we construct SkyReels-Bench for human assessment and leverage the open-source <a href="https://github.com/Vchitect/VBench">V-Bench</a> for automated evaluation. This allows us to compare our model with state-of-the-art (SOTA) baselines, including both open-source and proprietary models.
#### Human Evaluation
For human evaluation, we design SkyReels-Bench with 1,020 text prompts, systematically assessing four dimensions: Instruction Adherence, Motion Quality, Consistency, and Visual Quality. This benchmark is designed to evaluate both text-to-video (T2V) and image-to-video (I2V) generation models, providing a comprehensive assessment across different generation paradigms. To ensure fairness, all models were evaluated under default settings with consistent resolutions, and no post-generation filtering was applied.
- Text To Video Models
<p align="center">
<table align="center">
<thead>
<tr>
<th>Model Name</th>
<th>Average</th>
<th>Instruction Adherence</th>
<th>Consistency</th>
<th>Visual Quality</th>
<th>Motion Quality</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://runwayml.com/research/introducing-gen-3-alpha">Runway-Gen3 Alpha</a></td>
<td>2.53</td>
<td>2.19</td>
<td>2.57</td>
<td>3.23</td>
<td>2.11</td>
</tr>
<tr>
<td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
<td>2.82</td>
<td>2.64</td>
<td>2.81</td>
<td>3.20</td>
<td>2.61</td>
</tr>
<tr>
<td><a href="https://klingai.com">Kling-1.6 STD Mode</a></td>
<td>2.99</td>
<td>2.77</td>
<td>3.05</td>
<td>3.39</td>
<td><strong>2.76</strong></td>
</tr>
<tr>
<td><a href="https://hailuoai.video">Hailuo-01</a></td>
<td>3.0</td>
<td>2.8</td>
<td>3.08</td>
<td>3.29</td>
<td>2.74</td>
</tr>
<tr>
<td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
<td>3.12</td>
<td>2.91</td>
<td>3.31</td>
<td><strong>3.54</strong></td>
<td>2.71</td>
</tr>
<tr>
<td>SkyReels-V2</td>
<td><strong>3.14</strong></td>
<td><strong>3.15</strong></td>
<td><strong>3.35</strong></td>
<td>3.34</td>
<td>2.74</td>
</tr>
</tbody>
</table>
</p>
The evaluation shows that our model achieves significant gains in **instruction adherence (3.15)** over the baseline methods, while remaining competitive in **motion quality (2.74)** without sacrificing **consistency (3.35)**.
- Image To Video Models
<p align="center">
<table align="center">
<thead>
<tr>
<th>Model</th>
<th>Average</th>
<th>Instruction Adherence</th>
<th>Consistency</th>
<th>Visual Quality</th>
<th>Motion Quality</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
<td>2.84</td>
<td>2.97</td>
<td>2.95</td>
<td>2.87</td>
<td>2.56</td>
</tr>
<tr>
<td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
<td>2.85</td>
<td>3.10</td>
<td>2.81</td>
<td>3.00</td>
<td>2.48</td>
</tr>
<tr>
<td><a href="https://hailuoai.video">Hailuo-01</a></td>
<td>3.05</td>
<td>3.31</td>
<td>2.58</td>
<td>3.55</td>
<td>2.74</td>
</tr>
<tr>
<td><a href="https://klingai.com">Kling-1.6 Pro Mode</a></td>
<td>3.4</td>
<td>3.56</td>
<td>3.03</td>
<td>3.58</td>
<td>3.41</td>
</tr>
<tr>
<td><a href="https://runwayml.com/research/introducing-runway-gen-4">Runway-Gen4</a></td>
<td>3.39</td>
<td>3.75</td>
<td>3.2</td>
<td>3.4</td>
<td>3.37</td>
</tr>
<tr>
<td>SkyReels-V2-DF</td>
<td>3.24</td>
<td>3.64</td>
<td>3.21</td>
<td>3.18</td>
<td>2.93</td>
</tr>
<tr>
<td>SkyReels-V2-I2V</td>
<td>3.29</td>
<td>3.42</td>
<td>3.18</td>
<td>3.56</td>
<td>3.01</td>
</tr>
</tbody>
</table>
</p>
Our results show that both **SkyReels-V2-I2V (3.29)** and **SkyReels-V2-DF (3.24)** achieve state-of-the-art performance among open-source models, significantly outperforming HunyuanVideo-13B (2.84) and Wan2.1-14B (2.85) across all quality dimensions. With an average score of 3.29, SkyReels-V2-I2V performs comparably to the proprietary models Kling-1.6 Pro Mode (3.4) and Runway-Gen4 (3.39).
#### VBench
To objectively compare SkyReels-V2 against other leading open-source text-to-video models, we conduct comprehensive evaluations on the public benchmark <a href="https://github.com/Vchitect/VBench">V-Bench</a>. Our evaluation specifically uses the benchmark’s longer prompt version. For a fair comparison with the baseline models, we strictly follow their recommended inference settings.
<p align="center">
<table align="center">
<thead>
<tr>
<th>Model</th>
<th>Total Score</th>
<th>Quality Score</th>
<th>Semantic Score</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://github.com/hpcaitech/Open-Sora">OpenSora 2.0</a></td>
<td>81.5 %</td>
<td>82.1 %</td>
<td>78.2 %</td>
</tr>
<tr>
<td><a href="https://github.com/THUDM/CogVideo">CogVideoX1.5-5B</a></td>
<td>80.3 %</td>
<td>80.9 %</td>
<td>77.9 %</td>
</tr>
<tr>
<td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
<td>82.7 %</td>
<td>84.4 %</td>
<td>76.2 %</td>
</tr>
<tr>
<td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
<td>83.7 %</td>
<td>84.2 %</td>
<td><strong>81.4 %</strong></td>
</tr>
<tr>
<td>SkyReels-V2</td>
<td><strong>83.9 %</strong></td>
<td><strong>84.7 %</strong></td>
<td>80.8 %</td>
</tr>
</tbody>
</table>
</p>
The VBench results show that SkyReels-V2 achieves the highest **total score (83.9%)** and **quality score (84.7%)** among all compared models, including HunyuanVideo-13B and Wan2.1-14B. Its semantic score (80.8%) is slightly lower than that of Wan2.1-14B (81.4%), even though SkyReels-V2 outperforms Wan2.1-14B in human evaluation; we attribute this gap primarily to V-Bench’s insufficient evaluation of semantic adherence in shot-level scenarios.
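A lower semantic score can still yield a higher total because of how VBench aggregates its scores. As we understand VBench's scoring script, Total Score is a weighted average with Quality weighted 4 and Semantic weighted 1; treat these weights as an assumption to verify against the VBench repository. Applied to the rounded figures in the table, the sketch gives 83.92% for SkyReels-V2 and 83.64% for Wan2.1-14B. Published totals were presumably computed from unrounded scores, so recomputed values can differ in the last digit.

```python
QUALITY_WEIGHT = 4   # assumed VBench weighting; verify against the VBench repository
SEMANTIC_WEIGHT = 1


def vbench_total(quality: float, semantic: float) -> float:
    """Weighted average of VBench Quality and Semantic scores (both in %)."""
    return (QUALITY_WEIGHT * quality + SEMANTIC_WEIGHT * semantic) / (QUALITY_WEIGHT + SEMANTIC_WEIGHT)


if __name__ == "__main__":
    for name, quality, semantic in [("SkyReels-V2", 84.7, 80.8), ("Wan2.1-14B", 84.2, 81.4)]:
        print(f"{name}: {vbench_total(quality, semantic):.2f}%")
```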
## Acknowledgements
We would like to thank the contributors of the <a href="https://github.com/Wan-Video/Wan2.1">Wan 2.1</a>, <a href="https://github.com/xdit-project/xDiT">xDiT</a>, and <a href="https://qwenlm.github.io/blog/qwen2.5/">Qwen 2.5</a> repositories for their open research and contributions.
## Citation
```bibtex
@misc{chen2025skyreelsv2infinitelengthfilmgenerative,
title={SkyReels-V2: Infinite-length Film Generative Model},
author={Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
year={2025},
eprint={2504.13074},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.13074},
}
```
================================================
FILE: generate_video.py
================================================
import argparse
import gc
import os
import random
import time

import imageio
import torch
from diffusers.utils import load_image

from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import Image2VideoPipeline
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop
from skyreels_v2_infer.pipelines import Text2VideoPipeline

MODEL_ID_CONFIG = {
    "text2video": [
        "Skywork/SkyReels-V2-T2V-14B-540P",
        "Skywork/SkyReels-V2-T2V-14B-720P",
    ],
    "image2video": [
        "Skywork/SkyReels-V2-I2V-1.3B-540P",
        "Skywork/SkyReels-V2-I2V-14B-540P",
        "Skywork/SkyReels-V2-I2V-14B-720P",
    ],
}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="video_out")
    parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-T2V-14B-540P")
    parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
    parser.add_argument("--num_frames", type=int, default=97)
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--inference_steps", type=int, default=30)
    parser.add_argument("--use_usp", action="store_true")
    parser.add_argument("--offload", action="store_true")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--prompt",
        type=str,
        default="A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface.",
    )
    parser.add_argument("--prompt_enhancer", action="store_true")
    parser.add_argument("--teacache", action="store_true")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="Higher speedup results in worse quality: 0.1 gives a 2.0x speedup, 0.2 gives a 3.0x speedup.",
    )
    parser.add_argument(
        "--use_ret_steps",
        action="store_true",
        help="Using retention steps results in faster generation and better generation quality.",
    )
    args = parser.parse_args()

    args.model_id = download_model(args.model_id)
    print("model_id:", args.model_id)

    # In USP mode every rank must use the same seed, so it has to be given explicitly.
    assert (args.use_usp and args.seed is not None) or (not args.use_usp), "USP mode requires --seed"
    if args.seed is None:
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    if args.resolution == "540P":
        height = 544
        width = 960
    elif args.resolution == "720P":
        height = 720
        width = 1280
    else:
        raise ValueError(f"Invalid resolution: {args.resolution}")

    image = load_image(args.image).convert("RGB") if args.image else None
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    local_rank = 0
    if args.use_usp:
        assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        import torch.distributed as dist

        dist.init_process_group("nccl")
        local_rank = dist.get_rank()
        torch.cuda.set_device(dist.get_rank())

        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )

    prompt_input = args.prompt
    if args.prompt_enhancer and args.image is None:
        print("init prompt enhancer")
        prompt_enhancer = PromptEnhancer()
        prompt_input = prompt_enhancer(prompt_input)
        print(f"enhanced prompt: {prompt_input}")
        # Free the enhancer's GPU memory before loading the video pipeline.
        del prompt_enhancer
        gc.collect()
        torch.cuda.empty_cache()

    if image is None:
        assert "T2V" in args.model_id, f"check model_id:{args.model_id}"
        print("init text2video pipeline")
        pipe = Text2VideoPipeline(
            model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
        )
    else:
        assert "I2V" in args.model_id, f"check model_id:{args.model_id}"
        print("init img2video pipeline")
        pipe = Image2VideoPipeline(
            model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
        )
        args.image = load_image(args.image)
        image_width, image_height = args.image.size
        # Swap to portrait dimensions if the input image is taller than it is wide.
        if image_height > image_width:
            height, width = width, height
        args.image = resizecrop(args.image, height, width)

    if args.teacache:
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=args.inference_steps,
            teacache_thresh=args.teacache_thresh,
            use_ret_steps=args.use_ret_steps,
            ckpt_dir=args.model_id,
        )

    kwargs = {
        "prompt": prompt_input,
        "negative_prompt": negative_prompt,
        "num_frames": args.num_frames,
        "num_inference_steps": args.inference_steps,
        "guidance_scale": args.guidance_scale,
        "shift": args.shift,
        "generator": torch.Generator(device="cuda").manual_seed(args.seed),
        "height": height,
        "width": width,
    }
    if image is not None:
        kwargs["image"] = args.image.convert("RGB")

    save_dir = os.path.join("result", args.outdir)
    os.makedirs(save_dir, exist_ok=True)

    with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
        print(f"infer kwargs:{kwargs}")
        video_frames = pipe(**kwargs)[0]

    # Only rank 0 writes the output video.
    if local_rank == 0:
        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
        output_path = os.path.join(save_dir, video_out_file)
        imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"])
================================================
FILE: generate_video_df.py
================================================
import argparse
import gc
import os
import random
import time

import imageio
import torch
from diffusers.utils import load_image
from moviepy.editor import VideoFileClip

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines.image2video_pipeline import resizecrop


def get_video_num_frames_moviepy(video_path):
    """Return ((width, height), frame_count) for a video by iterating over its frames."""
    with VideoFileClip(video_path) as clip:
        num_frames = 0
        for _ in clip.iter_frames():
            num_frames += 1
        return clip.size, num_frames


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="diffusion_forcing")
    parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
    parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
    parser.add_argument("--num_frames", type=int, default=97)
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--end_image", type=str, default=None)
    parser.add_argument("--video_path", type=str, default="")
    parser.add_argument("--ar_step", type=int, default=0)
    parser.add_argument("--causal_attention", action="store_true")
    parser.add_argument("--causal_block_size", type=int, default=1)
    parser.add_argument("--base_num_frames", type=int, default=97)
    parser.add_argument("--overlap_history", type=int, default=None)
    parser.add_argument("--addnoise_condition", type=int, default=0)
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--inference_steps", type=int, default=30)
    parser.add_argument("--use_usp", action="store_true")
    parser.add_argument("--offload", action="store_true")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
    )
    parser.add_argument("--prompt_enhancer", action="store_true")
    parser.add_argument("--teacache", action="store_true")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="Higher speedup results in worse quality: 0.1 gives a 2.0x speedup, 0.2 gives a 3.0x speedup.",
    )
    parser.add_argument(
        "--use_ret_steps",
        action="store_true",
        help="Using retention steps results in faster generation and better generation quality.",
    )
    args = parser.parse_args()

    args.model_id = download_model(args.model_id)
    print("model_id:", args.model_id)

    # In USP mode every rank must use the same seed, so it has to be given explicitly.
    assert (args.use_usp and args.seed is not None) or (not args.use_usp), "USP mode requires --seed"
    if args.seed is None:
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    if args.resolution == "540P":
        height = 544
        width = 960
    elif args.resolution == "720P":
        height = 720
        width = 1280
    else:
        raise ValueError(f"Invalid resolution: {args.resolution}")

    num_frames = args.num_frames
    fps = args.fps

    if num_frames > args.base_num_frames:
        assert (
            args.overlap_history is not None
        ), 'Specify "overlap_history" to enable long video generation; 17 or 37 is recommended.'
    if args.addnoise_condition > 60:
        print(
            f'You have set "addnoise_condition" to {args.addnoise_condition}. This value is too large and can '
            "cause inconsistency in long video generation; 20 is recommended."
        )

    guidance_scale = args.guidance_scale
    shift = args.shift
    # Chinese negative prompt. English translation: bright tones, overexposed, static, blurred details,
    # subtitles, style, works, paintings, images, still, overall gray, worst quality, low quality,
    # JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces,
    # deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs,
    # many people in the background, walking backwards.
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    save_dir = os.path.join("result", args.outdir)
    os.makedirs(save_dir, exist_ok=True)

    local_rank = 0
    if args.use_usp:
        assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        import torch.distributed as dist

        dist.init_process_group("nccl")
        local_rank = dist.get_rank()
        torch.cuda.set_device(dist.get_rank())

        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )

    prompt_input = args.prompt
    if args.prompt_enhancer and args.image is None:
        print("init prompt enhancer")
        prompt_enhancer = PromptEnhancer()
        prompt_input = prompt_enhancer(prompt_input)
        print(f"enhanced prompt: {prompt_input}")
        # Free the enhancer's GPU memory before loading the video pipeline.
        del prompt_enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = DiffusionForcingPipeline(
        args.model_id,
        dit_path=args.model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=args.use_usp,
        offload=args.offload,
    )

    if args.causal_attention:
        pipe.transformer.set_ar_attention(args.causal_block_size)

    if args.teacache:
        if args.ar_step > 0:
            # Asynchronous denoising staggers the latent blocks by ar_step steps each,
            # which adds (num_blocks - 1) * ar_step steps to the schedule.
            num_blocks = ((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size
            num_steps = args.inference_steps + (num_blocks - 1) * args.ar_step
            print("num_steps:", num_steps)
        else:
            num_steps = args.inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=args.teacache_thresh,
            use_ret_steps=args.use_ret_steps,
            ckpt_dir=args.model_id,
        )

    print(f"prompt:{prompt_input}")
    print(f"guidance_scale:{guidance_scale}")

    if os.path.exists(args.video_path):
        # Video extension: the last overlap_history frames of the input video condition the new frames.
        assert args.overlap_history is not None, 'Specify "overlap_history" when extending a video.'
        (v_width, v_height), input_num_frames = get_video_num_frames_moviepy(args.video_path)
        assert input_num_frames >= args.overlap_history, "The input video is too short."

        # Swap to portrait dimensions if the input video is taller than it is wide.
        if v_height > v_width:
            width, height = height, width

        video_frames = pipe.extend_video(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            prefix_video_path=args.video_path,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=args.inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(args.seed),
            overlap_history=args.overlap_history,
            addnoise_condition=args.addnoise_condition,
            base_num_frames=args.base_num_frames,
            ar_step=args.ar_step,
            causal_block_size=args.causal_block_size,
            fps=fps,
        )[0]
    else:
        if args.image:
            args.image = load_image(args.image)
            image_width, image_height = args.image.size
            # Swap to portrait dimensions if the input image is taller than it is wide.
            if image_height > image_width:
                height, width = width, height
            args.image = resizecrop(args.image, height, width)
        if args.end_image:
            args.end_image = load_image(args.end_image)
            args.end_image = resizecrop(args.end_image, height, width)
        image = args.image.convert("RGB") if args.image else None
        end_image = args.end_image.convert("RGB") if args.end_image else None

        with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
            video_frames = pipe(
                prompt=prompt_input,
                negative_prompt=negative_prompt,
                image=image,
                end_image=end_image,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=args.inference_steps,
                shift=shift,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(args.seed),
                overlap_history=args.overlap_history,
                addnoise_condition=args.addnoise_condition,
                base_num_frames=args.base_num_frames,
                ar_step=args.ar_step,
                causal_block_size=args.causal_block_size,
                fps=fps,
            )[0]

    # Only rank 0 writes the output video.
    if local_rank == 0:
        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
        output_path = os.path.join(save_dir, video_out_file)
        imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
================================================
FILE: requirements.txt
================================================
torch==2.5.1
torchvision==0.20.1
opencv-python==4.10.0.84
diffusers>=0.31.0
transformers==4.49.0
tokenizers==0.21.1
accelerate==1.6.0
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
numpy>=1.23.5,<2
xfuser
================================================
FILE: skycaptioner_v1/README.md
================================================
# SkyCaptioner-V1: A Structural Video Captioning Model
<p align="center">
📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/Skywork/SkyCaptioner-V1" target="_blank">Hugging Face</a> · 🤖 <a href="https://modelscope.cn/collections/SkyReels-V2-f665650130b144">ModelScope</a> · 🚀 <a href="https://huggingface.co/spaces/Skywork/SkyCaptioner-V1">Demo</a>
</p>
---
Welcome to the SkyCaptioner-V1 repository! Here, you'll find the structural video captioning model weights and inference code for our video captioner that labels the video data efficiently and comprehensively.
## 🔥🔥🔥 News!!
* May 07, 2025: 🚀 Added a web demo implementation based on Gradio and the [online demo](https://huggingface.co/spaces/Skywork/SkyCaptioner-V1) is now available!
* Apr 21, 2025: 👋 We release the [vllm](https://github.com/vllm-project/vllm) batch inference code for SkyCaptioner-V1 Model and caption fusion inference code.
* Apr 21, 2025: 👋 We release the first shot-aware video captioning model [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1). For more details, please check our [paper](https://arxiv.org/pdf/2504.13074).
## 📑 TODO List
- SkyCaptioner-V1
- [x] Checkpoints
- [x] Batch Inference Code
- [x] Caption Fusion Method
- [x] Web Demo (Gradio)
## 🌟 Overview
SkyCaptioner-V1 is a structural video captioning model designed to generate high-quality, structural descriptions for video data. It integrates specialized sub-expert models and multimodal large language models (MLLMs) with human annotations to address the limitations of general captioners in capturing professional film-related details. Key aspects include:
1. **Structural Representation**: Combines general video descriptions (from MLLMs) with sub-expert captioners (e.g., shot types, shot angles, shot positions, and camera motions) and human annotations.
2. **Knowledge Distillation**: Distills expertise from sub-expert captioners into a unified model.
3. **Application Flexibility**: Generates dense captions for text-to-video (T2V) and concise prompts for image-to-video (I2V) tasks.
## 🔑 Key Features
### Structural Captioning Framework
Our Video Captioning model captures multi-dimensional details:
* **Subjects**: Appearance, action, expression, position, and hierarchical categorization.
* **Shot Metadata**: Shot type (e.g., close-up, long shot), shot angle, shot position, camera motion, environment, lighting, etc.
### Sub-Expert Integration
* **Shot Captioner**: Classifies shot type, angle, and position with high precision.
* **Expression Captioner**: Analyzes facial expressions, emotion intensity, and temporal dynamics.
* **Camera Motion Captioner**: Tracks 6DoF camera movements and composite motion types.
### Training Pipeline
* Trained on \~2M high-quality, concept-balanced videos curated from 10M raw samples.
* Fine-tuned on Qwen2.5-VL-7B-Instruct with a global batch size of 512 across 32 A800 GPUs.
* Optimized using AdamW (learning rate: 1e-5) for 2 epochs.
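A quick worked check of what these training settings imply, assuming exactly 2M videos (the text says approximately 2M) and no gradient accumulation (not stated): 16 videos per GPU per step and about 3,907 optimizer steps per epoch.

```python
import math

num_videos = 2_000_000  # "~2M" in the text; the exact count is an assumption
global_batch = 512
num_gpus = 32
epochs = 2

per_gpu_batch = global_batch // num_gpus                 # 16, assuming no gradient accumulation
steps_per_epoch = math.ceil(num_videos / global_batch)   # the last, partial batch counts as a step
total_steps = steps_per_epoch * epochs

print(per_gpu_batch, steps_per_epoch, total_steps)  # 16 3907 7814
```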
### Dynamic Caption Fusion
* Adapts output length based on application (T2V/I2V).
* Employs an LLM to fuse the structural fields into a natural, fluent caption for downstream tasks.
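A hypothetical sketch of how structural fields might be flattened into a fusion request for an LLM, with a more concise variant for I2V. The field names, the decision to drop subject appearance for I2V (on the assumption that the input image already shows it), and the instruction wording are all illustrative assumptions, not the format used by `vllm_fusion_caption.py`.

```python
from typing import Dict, List

# Hypothetical field groups; the real structural caption schema may differ.
SHOT_FIELDS = ["shot_type", "shot_angle", "shot_position", "camera_motion"]
SCENE_FIELDS = ["environment", "lighting"]
SUBJECT_FIELDS_T2V = ["appearance", "action", "expression", "position"]
SUBJECT_FIELDS_I2V = ["action", "expression", "position"]  # assumption: the image conveys appearance


def build_fusion_prompt(caption: Dict, task: str = "t2v") -> str:
    """Flatten a structural caption into one instruction asking an LLM for fluent prose."""
    if task not in ("t2v", "i2v"):
        raise ValueError(f"unknown task: {task}")
    subject_fields = SUBJECT_FIELDS_T2V if task == "t2v" else SUBJECT_FIELDS_I2V
    lines: List[str] = []
    for i, subject in enumerate(caption.get("subjects", []), start=1):
        parts = [f"{key}: {subject[key]}" for key in subject_fields if subject.get(key)]
        if parts:
            lines.append(f"Subject {i}: " + "; ".join(parts))
    lines += [f"{key}: {caption[key]}" for key in SHOT_FIELDS + SCENE_FIELDS if caption.get(key)]
    length = "a detailed paragraph" if task == "t2v" else "one or two concise sentences"
    return (
        f"Rewrite the following structured video description as {length} "
        "of natural, fluent prose:\n" + "\n".join(lines)
    )


if __name__ == "__main__":
    example = {
        "subjects": [{"appearance": "a woman in a red coat", "action": "walks toward the camera"}],
        "shot_type": "medium shot",
        "camera_motion": "dolly out",
        "lighting": "soft daylight",
    }
    print(build_fusion_prompt(example, task="i2v"))
```

The returned string would then be sent to whichever LLM performs the fusion; that call is deliberately left out.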
## 📊 Benchmark Results
SkyCaptioner-V1 demonstrates significant improvements over existing models in key film-specific captioning tasks, particularly in **shot-language understanding** and **domain-specific precision**. The differences stem from its structural architecture and expert-guided training:
1. **Superior Shot-Language Understanding**:
* Our captioner outperforms Qwen2.5-VL-72B by +11.2% in shot type, +16.1% in shot angle, and +50.4% in shot position accuracy, because SkyCaptioner-V1’s specialized shot classifiers outperform generalist MLLMs, which lack film-domain fine-tuning.
* +40.0% accuracy in camera motion vs. Tarsier2-recap-7B (85.3% vs. 45.3% in the table below):
Its 6DoF motion analysis and active learning pipeline address ambiguities in composite motions (e.g., tracking + panning) that challenge generic captioners.
2. **High Domain-Specific Precision**:
* Expression accuracy: 68.8% vs. 54.3% (Tarsier2-recap-7B), leveraging temporal-aware S2D frameworks to capture dynamic facial changes.
<p align="center">
<table align="center">
<thead>
<tr>
<th>Metric</th>
<th>Qwen2.5-VL-7B-Ins.</th>
<th>Qwen2.5-VL-72B-Ins.</th>
<th>Tarsier2-recap-7B</th>
<th>SkyCaptioner-V1</th>
</tr>
</thead>
<tbody>
<tr>
<td>Avg accuracy</td>
<td>51.4%</td>
<td>58.7%</td>
<td>49.4%</td>
<td><strong>76.3%</strong></td>
</tr>
<tr>
<td>shot type</td>
<td>76.8%</td>
<td>82.5%</td>
<td>60.2%</td>
<td><strong>93.7%</strong></td>
</tr>
<tr>
<td>shot angle</td>
<td>60.0%</td>
<td>73.7%</td>
<td>52.4%</td>
<td><strong>89.8%</strong></td>
</tr>
<tr>
<td>shot position</td>
<td>28.4%</td>
<td>32.7%</td>
<td>23.6%</td>
<td><strong>83.1%</strong></td>
</tr>
<tr>
<td>camera motion</td>
<td>62.0%</td>
<td>61.2%</td>
<td>45.3%</td>
<td><strong>85.3%</strong></td>
</tr>
<tr>
<td>expression</td>
<td>43.6%</td>
<td>51.5%</td>
<td>54.3%</td>
<td><strong>68.8%</strong></td>
</tr>
<tr>
<td>TYPES_type</td>
<td>43.5%</td>
<td>49.7%</td>
<td>47.6%</td>
<td><strong>82.5%</strong></td>
</tr>
<tr>
<td>TYPES_sub_type</td>
<td>38.9%</td>
<td>44.9%</td>
<td>45.9%</td>
<td><strong>75.4%</strong></td>
</tr>
<tr>
<td>appearance</td>
<td>40.9%</td>
<td>52.0%</td>
<td>45.6%</td>
<td><strong>59.3%</strong></td>
</tr>
<tr>
<td>action</td>
<td>32.4%</td>
<td>52.0%</td>
<td><strong>69.8%</strong></td>
<td>68.8%</td>
</tr>
<tr>
<td>position</td>
<td>35.4%</td>
<td>48.6%</td>
<td>45.5%</td>
<td><strong>57.5%</strong></td>
</tr>
<tr>
<td>is_main_subject</td>
<td>58.5%</td>
<td>68.7%</td>
<td>69.7%</td>
<td><strong>80.9%</strong></td>
</tr>
<tr>
<td>environment</td>
<td>70.4%</td>
<td><strong>72.7%</strong></td>
<td>61.4%</td>
<td>70.5%</td>
</tr>
<tr>
<td>lighting</td>
<td>77.1%</td>
<td><strong>80.0%</strong></td>
<td>21.2%</td>
<td>76.5%</td>
</tr>
</tbody>
</table>
</p>
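The shot-language gains quoted above are absolute percentage-point differences read off this table. The short script below recomputes, for every metric, SkyCaptioner-V1's gap to the strongest baseline; the numbers are copied verbatim from the table, and `gap_to_best_baseline` is only an illustrative helper.

```python
BASELINES = ("Qwen2.5-VL-7B-Ins.", "Qwen2.5-VL-72B-Ins.", "Tarsier2-recap-7B")

# Accuracy in %: (Qwen2.5-VL-7B, Qwen2.5-VL-72B, Tarsier2-recap-7B, SkyCaptioner-V1)
TABLE = {
    "Avg accuracy": (51.4, 58.7, 49.4, 76.3),
    "shot type": (76.8, 82.5, 60.2, 93.7),
    "shot angle": (60.0, 73.7, 52.4, 89.8),
    "shot position": (28.4, 32.7, 23.6, 83.1),
    "camera motion": (62.0, 61.2, 45.3, 85.3),
    "expression": (43.6, 51.5, 54.3, 68.8),
    "TYPES_type": (43.5, 49.7, 47.6, 82.5),
    "TYPES_sub_type": (38.9, 44.9, 45.9, 75.4),
    "appearance": (40.9, 52.0, 45.6, 59.3),
    "action": (32.4, 52.0, 69.8, 68.8),
    "position": (35.4, 48.6, 45.5, 57.5),
    "is_main_subject": (58.5, 68.7, 69.7, 80.9),
    "environment": (70.4, 72.7, 61.4, 70.5),
    "lighting": (77.1, 80.0, 21.2, 76.5),
}


def gap_to_best_baseline(row):
    """Return (best baseline name, SkyCaptioner-V1 minus that baseline in points)."""
    *baselines, ours = row
    best = max(range(len(baselines)), key=lambda i: baselines[i])
    return BASELINES[best], round(ours - baselines[best], 1)


for metric, row in TABLE.items():
    name, gap = gap_to_best_baseline(row)
    print(f"{metric:16s} {gap:+5.1f} vs {name}")
```

Negative gaps (action, environment, lighting) correspond to the rows where a baseline is bolded.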
## 📦 Model Downloads
Our SkyCaptioner-V1 model can be downloaded from [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1).
We use [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) as our caption fusion model to intelligently combine structured caption fields, producing either dense or sparse final captions depending on application requirements.
```shell
# download SkyCaptioner-V1
huggingface-cli download Skywork/SkyCaptioner-V1 --local-dir /path/to/your_local_model_path
# download Qwen2.5-32B-Instruct
huggingface-cli download Qwen/Qwen2.5-32B-Instruct --local-dir /path/to/your_local_model_path2
```
## 🛠️ Running Guide
Begin by cloning the repository:
```shell
git clone https://github.com/SkyworkAI/SkyReels-V2
cd SkyReels-V2/skycaptioner_v1
```
### Installation Guide for Linux
We recommend Python 3.10 and CUDA version 12.2 for the manual installation.
```shell
pip install -r requirements.txt
```
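After `pip install -r requirements.txt`, a quick stdlib-only check that the pinned packages actually resolved can save a failed model load later. This is a hypothetical helper, not part of the repository. Its version comparison only looks at the leading numeric components (so it ignores suffixes such as `+cu121` or `rc1`) and is not a full PEP 440 implementation; `packaging.version` is the robust choice if it is available.

```python
import re
import sys
from importlib import metadata
from typing import List, Optional, Tuple

# Copied from skycaptioner_v1/requirements.txt.
REQUIREMENTS = """\
decord==0.6.0
transformers>=4.49.0
vllm==0.8.4
"""


def parse_requirement(line: str) -> Optional[Tuple[str, str, str]]:
    """Split 'name==x.y' or 'name>=x.y' into (name, operator, version); None otherwise."""
    match = re.fullmatch(r"\s*([A-Za-z0-9_.\-]+)\s*(==|>=)\s*([0-9][0-9A-Za-z.]*)\s*", line)
    return match.groups() if match else None


def version_tuple(version: str) -> Tuple[int, ...]:
    """Leading numeric release components, e.g. '0.8.4+cu121' -> (0, 8, 4)."""
    parts = []
    for piece in version.split("."):
        digits = re.match(r"\d+", piece)
        if digits is None:
            break
        parts.append(int(digits.group()))
        if digits.end() != len(piece):
            break  # stop at suffixes such as '+cu121' or 'rc1'
    return tuple(parts)


def check(requirements: str) -> List[str]:
    """Return a list of problems; an empty list means every pin looks satisfied."""
    problems = []
    for line in requirements.splitlines():
        req = parse_requirement(line)
        if req is None:
            continue
        name, op, wanted = req
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        have, need = version_tuple(installed), version_tuple(wanted)
        ok = have == need if op == "==" else have >= need
        if not ok:
            problems.append(f"{name}: found {installed}, need {op}{wanted}")
    return problems


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 10):
        print(f"note: running Python {sys.version_info[0]}.{sys.version_info[1]}; 3.10 is recommended")
    print(check(REQUIREMENTS) or "all requirements satisfied")
```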
### Running Command
#### Get Structural Caption by SkyCaptioner-V1
```shell
export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
python scripts/vllm_struct_caption.py \
--model_path ${SkyCaptioner_V1_Model_PATH} \
--input_csv "./examples/test.csv" \
--out_csv "./examples/test_result.csv" \
--tp 1 \
--bs 4
```
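`scripts/vllm_struct_caption.py` reads a CSV with a single `path` column (see `examples/test.csv`), and each path is opened relative to the current working directory. A minimal sketch for building such a CSV from a folder of `.mp4` files; `input_csv_text`, `build_input_csv` and the folder layout are assumptions for illustration, not repository code.

```python
import csv
import io
from pathlib import Path
from typing import Iterable


def input_csv_text(paths: Iterable[str]) -> str:
    """Render a CSV with the single `path` column the struct-caption script reads."""
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(["path"])
    for p in paths:
        writer.writerow([p])  # csv quotes paths that contain commas or quotes
    return buf.getvalue()


def build_input_csv(video_dir: str, out_csv: str, pattern: str = "*.mp4") -> int:
    """Write every file under `video_dir` matching `pattern` to `out_csv`; return the count."""
    videos = sorted(str(p) for p in Path(video_dir).rglob(pattern))
    Path(out_csv).write_text(input_csv_text(videos))
    return len(videos)
```

For example, `build_input_csv("./examples/data", "./examples/my_videos.csv")` followed by `--input_csv "./examples/my_videos.csv"` in the command above (the output file name is hypothetical).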
#### T2V/I2V Caption Fusion by Qwen2.5-32B-Instruct Model
```shell
export LLM_MODEL_PATH="/path/to/your_local_model_path2"
python scripts/vllm_fusion_caption.py \
--model_path ${LLM_MODEL_PATH} \
--input_csv "./examples/test_result.csv" \
--out_csv "./examples/test_result_caption.csv" \
--bs 4 \
--tp 1 \
--task t2v
```
> **Note**:
> - To get an I2V caption, change `--task t2v` to `--task i2v` in the command above.
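In `scripts/vllm_fusion_caption.py` the camera motion is never sent to the LLM: it is turned into a sentence (trailing period, first letter capitalized) and appended to the LLM output, and the output CSV keeps the input columns plus a new `<task>_fusion_caption` column (for example `t2v_fusion_caption`). The sketch below mirrors that assembly step and reads the captions back; `compose_caption` and `read_fusion_captions` are hypothetical helpers, and unlike the script, `compose_caption` trims the trailing space left when the camera motion is empty.

```python
import csv
import io
from typing import Dict


def compose_caption(llm_text: str, camera_motion: str) -> str:
    """Append the camera-motion sentence to the fused caption."""
    sentence = camera_motion + "." if camera_motion else ""
    # str.capitalize() upper-cases the first character and lower-cases the rest.
    caption = llm_text.strip() + " " + sentence.capitalize()
    return caption.rstrip()


def read_fusion_captions(csv_text: str, task: str = "t2v") -> Dict[str, str]:
    """Map each video path to its fused caption from the fusion script's output CSV."""
    column = f"{task}_fusion_caption"
    return {row["path"]: row[column] for row in csv.DictReader(io.StringIO(csv_text))}
```

For instance, the camera motion `the camera moves toward zooms in` from the first example video becomes the sentence `The camera moves toward zooms in.` at the end of the fused caption.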
#### Gradio Web Demo
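Launch the Gradio web demo for SkyCaptioner-V1 (both demo scripts default to `--tensor_parallel_size 2`; pass `--tensor_parallel_size 1` to run on a single GPU):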
```shell
export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
python scripts/gradio_struct_caption.py \
--skycaptioner_model_path ${SkyCaptioner_V1_Model_PATH}
```
Launch the Gradio web demo for Caption Fusion:
```shell
export LLM_MODEL_PATH="/path/to/your_local_model_path2"
python scripts/gradio_fusion_caption.py \
--fusioncaptioner_model_path ${LLM_MODEL_PATH}
```
## Acknowledgements
We would like to thank the contributors of <a href="https://github.com/QwenLM/Qwen2.5-VL">Qwen2.5-VL</a>, <a href="https://github.com/bytedance/tarsier">tarsier2</a> and <a href="https://github.com/vllm-project/vllm">vllm</a> repositories, for their open research and contributions.
## Citation
```bibtex
@misc{chen2025skyreelsv2infinitelengthfilmgenerative,
author = {Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
title = {SkyReels-V2: Infinite-Length Film Generative Model},
year = {2025},
eprint={2504.13074},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.13074}
}
```
================================================
FILE: skycaptioner_v1/examples/test.csv
================================================
path
./examples/data/1.mp4
./examples/data/2.mp4
./examples/data/3.mp4
./examples/data/4.mp4
================================================
FILE: skycaptioner_v1/examples/test_result.csv
================================================
path,structural_caption
./examples/data/1.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Wearing winter sports gear, including a helmet and goggles."", ""action"": ""The video shows a snowy mountain landscape with a ski slope surrounded by dense trees and distant mountains. A skier is seen descending the slope, moving from the top left to the bottom right of the frame. The skier maintains a steady pace, navigating the curves of the slope. The background includes a ski lift with chairs moving along the cables, and the slope is marked with red and white lines indicating the path for skiers. The skier continues to descend, gradually getting closer to the bottom of the slope."", ""expression"": """", ""position"": ""Centered in the frame, moving downwards on the slope."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""White, covering the ground and trees."", ""action"": """", ""expression"": """", ""position"": ""Surrounding the skier, covering the entire visible area."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Plant"", ""sub_type"": ""Tree""}, ""appearance"": ""Tall, evergreen, covered in snow."", ""action"": """", ""expression"": """", ""position"": ""Scattered throughout the scene, both in the foreground and background."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""Snow-covered, with ski lifts and structures."", ""action"": """", ""expression"": """", ""position"": ""In the background, providing context to the location."", ""is_main_subject"": false}], ""shot_type"": ""long_shot"", ""shot_angle"": ""high_angle"", ""shot_position"": ""overlooking_view"", ""camera_motion"": ""the camera moves toward zooms in"", ""environment"": ""A snowy mountain landscape with a ski slope, trees, and a ski resort in the background."", ""lighting"": ""Bright daylight, casting shadows and highlighting the snow's texture.""}"
./examples/data/2.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Human"", ""sub_type"": ""Woman""}, ""appearance"": ""Long, straight black hair, wearing a sparkling choker necklace with a diamond-like texture, light-colored top, subtle makeup with pink lipstick, stud earrings."", ""action"": ""A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her. The background outside the car is green, indicating a possible outdoor setting."", ""expression"": ""The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor."", ""position"": ""Seated in the foreground of the car, facing slightly to the right."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Human"", ""sub_type"": ""Man""}, ""appearance"": ""Short hair, wearing a dark-colored suit with a white shirt."", ""action"": """", ""expression"": """", ""position"": ""Seated in the background of the car, facing the woman."", ""is_main_subject"": false}], ""shot_type"": ""close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""side_view"", ""camera_motion"": """", ""environment"": ""Interior of a car with dark upholstery."", ""lighting"": ""Soft and natural lighting, suggesting daytime.""}"
./examples/data/3.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Animal"", ""sub_type"": ""Insect""}, ""appearance"": ""The spider has a spherical, yellowish-green body with darker green stripes and spots. It has eight slender legs with visible joints and segments."", ""action"": ""A spider with a yellow and green body and black and white striped legs is hanging from its web in a natural setting with a blurred background of green and brown hues. The spider remains mostly still, with slight movements in its legs and body, indicating a gentle swaying motion."", ""expression"": """", ""position"": ""The spider is centrally positioned in the frame, hanging from a web."", ""is_main_subject"": true}], ""shot_type"": ""extreme_close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""front_view"", ""camera_motion"": """", ""environment"": ""The background consists of vertical, out-of-focus lines in shades of green and brown, suggesting a natural environment with vegetation."", ""lighting"": ""The lighting is soft and diffused, with no harsh shadows, indicating an overcast sky or a shaded area.""}"
./examples/data/4.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Football""}, ""appearance"": ""Wearing a dark-colored jersey, black shorts, and bright blue soccer shoes with white soles."", ""action"": ""A man is on a grassy field with orange cones placed in a line. He is wearing a gray shirt, black shorts, and black socks with blue shoes. The man starts by standing still with his feet apart, then begins to move forward while keeping his eyes on the cones. He continues to run forward, maintaining his focus on the cones, and his feet move in a coordinated manner to navigate around them. The background shows a clear sky, some trees, and a few buildings in the distance."", ""expression"": """", ""position"": ""Centered in the frame, with the soccer ball positioned between the cones."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Bright orange, conical shape."", ""action"": """", ""expression"": """", ""position"": ""Placed on the grass, creating a path for the soccer player."", ""is_main_subject"": false}], ""shot_type"": ""full_shot"", ""shot_angle"": ""low_angle"", ""shot_position"": ""front_view"", ""camera_motion"": ""use a tracking shot, the camera moves toward zooms in"", ""environment"": ""Outdoor sports field with well-maintained grass, trees, and a clear blue sky."", ""lighting"": ""Bright and natural, suggesting a sunny day.""}"
================================================
FILE: skycaptioner_v1/infer_fusion_caption.sh
================================================
export LLM_MODEL_PATH="/path/to/your_local_model_path2"
python scripts/vllm_fusion_caption.py \
--model_path ${LLM_MODEL_PATH} \
--input_csv "./examples/test_result.csv" \
--out_csv "./examples/test_result_caption.csv" \
--bs 4 \
--tp 1 \
--task t2v
================================================
FILE: skycaptioner_v1/infer_struct_caption.sh
================================================
export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
python scripts/vllm_struct_caption.py \
--model_path ${SkyCaptioner_V1_Model_PATH} \
--input_csv "./examples/test.csv" \
--out_csv "./examples/test_result.csv" \
--tp 1 \
--bs 32
================================================
FILE: skycaptioner_v1/requirements.txt
================================================
decord==0.6.0
transformers>=4.49.0
vllm==0.8.4
================================================
FILE: skycaptioner_v1/scripts/gradio_fusion_caption.py
================================================
import json
import argparse
import pandas as pd
import gradio as gr
from vllm import LLM, SamplingParams
from vllm_fusion_caption import StructuralCaptionDataset
parser = argparse.ArgumentParser()
parser.add_argument("--fusioncaptioner_model_path", required=True, type=str)
parser.add_argument("--tensor_parallel_size", type=int, default=2)
args = parser.parse_args()
example_input = """
{
"subjects": [
{
"TYPES": {
"type": "Human",
"sub_type": "Woman"
},
"appearance": "Long, straight black hair with bangs, wearing a sparkling choker necklace and a dark-colored top or dress with a visible strap over her shoulder.",
"action": "A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her.",
"expression": "The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor.",
"position": "Seated in the foreground of the car, facing slightly to the right.",
"is_main_subject": true
},
{
"TYPES": {
"type": "Human",
"sub_type": "Man"
},
"appearance": "Short hair, wearing a dark-colored suit with a white shirt.",
"action": "",
"expression": "",
"position": "Seated in the background of the car, facing the woman.",
"is_main_subject": false
}
],
"shot_type": "close_up",
"shot_angle": "eye_level",
"shot_position": "side_view",
"camera_motion": "",
"environment": "Interior of a car with a dark color scheme.",
"lighting": "Soft and natural lighting, suggesting daytime."
}
"""
class FusionCaptioner:
    def __init__(self, model_path, tensor_parallel_size):
        self.model = LLM(model=model_path,
                         gpu_memory_utilization=0.9,
                         max_model_len=4096,
                         tensor_parallel_size=tensor_parallel_size)
        self.sampling_params = SamplingParams(
            temperature=0.1,
            max_tokens=512,
            stop=['\n\n']
        )
        self.model_path = model_path
    def __call__(self, structural_caption, task='t2v'):
        # Normalize the input (dict or JSON string) to a compact JSON string.
        if isinstance(structural_caption, dict):
            structural_caption = json.dumps(structural_caption, ensure_ascii=False)
        else:
            structural_caption = json.dumps(json.loads(structural_caption), ensure_ascii=False)
        meta = pd.DataFrame([structural_caption], columns=['structural_caption'])
        print(f'structural_caption: {structural_caption}')
        print(f'task: {task}')
        dataset = StructuralCaptionDataset(meta, self.model_path, task)
        _, fusion_by_llm, text, original_text, camera_movement = dataset[0]
        if not fusion_by_llm:
            # No subjects left after cleaning: skip the LLM and return the placeholder plus camera motion.
            caption = original_text + " " + camera_movement
            return caption
        try:
            outputs = self.model.generate([text], self.sampling_params, use_tqdm=False)
            result = outputs[0].outputs[0].text.strip()
        except Exception as e:
            # Fall back to the placeholder string (not a list) so the concatenation below still works.
            print(f"LLM generation failed: {e}")
            result = original_text
        llm_caption = result + " " + camera_movement
        return llm_caption
def main():
    fusion_captioner = FusionCaptioner(args.fusioncaptioner_model_path, args.tensor_parallel_size)
    def fusion_caption(structural_caption, task):
        caption = fusion_captioner(structural_caption, task)
        return caption
    with gr.Blocks() as demo:
        gr.Markdown(
            """
<h1 style="text-align: center; font-size: 2em;">SkyCaptioner</h1>
""",
            elem_id="header"
        )
        with gr.Row():
            with gr.Column(visible=True):
                with gr.Row():
                    json_input = gr.Code(
                        label="Structural Caption",
                        language="json",
                        lines=25,
                        interactive=True
                    )
                with gr.Row():
                    task_input = gr.Radio(
                        label="Task",
                        choices=["t2v", "i2v"],
                        value="t2v",
                        interactive=True
                    )
            with gr.Column(visible=True):
                text_output = gr.Textbox(
                    label="Fusion Caption",
                    lines=25,
                    interactive=False,
                    autoscroll=True
                )
        gr.Button("Generate").click(
            fn=fusion_caption,
            inputs=[json_input, task_input],
            outputs=text_output
        )
        with gr.Row():
            gr.Examples(
                examples=[
                    [example_input, "t2v"],
                ],
                inputs=[json_input, task_input],
                label="Example Input"
            )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7863,
        share=False
    )
if __name__ == '__main__':
    main()
================================================
FILE: skycaptioner_v1/scripts/gradio_struct_caption.py
================================================
import json
import argparse
import pandas as pd
import gradio as gr
from vllm import LLM, SamplingParams
from vllm_struct_caption import VideoTextDataset
class StructCaptioner:
    def __init__(self, model_path, tensor_parallel_size):
        self.model = LLM(model=model_path,
                         gpu_memory_utilization=0.6,
                         max_model_len=31920,
                         tensor_parallel_size=tensor_parallel_size)
        self.model_path = model_path
        self.sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)
    def __call__(self, video_path):
        # Reuse the batch-inference dataset to sample frames and build the chat prompt.
        meta = pd.DataFrame([video_path], columns=['path'])
        dataset = VideoTextDataset(meta, self.model_path)
        item = dataset[0]['input']
        batch_user_inputs = [{
            'prompt': item['prompt'],
            'multi_modal_data': {'video': item['multi_modal_data']['video'][0]},
        }]
        outputs = self.model.generate(batch_user_inputs, self.sampling_params, use_tqdm=False)
        caption = outputs[0].outputs[0].text
        # Re-serialize the model's JSON output with indentation for display.
        caption = json.loads(caption)
        caption = json.dumps(caption, indent=4, ensure_ascii=False)
        return caption
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--skycaptioner_model_path", required=True, type=str)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)
    args = parser.parse_args()
    struct_captioner = StructCaptioner(args.skycaptioner_model_path, args.tensor_parallel_size)
    def generate_caption(video_path):
        caption = struct_captioner(video_path)
        return caption
    with gr.Blocks() as demo:
        gr.Markdown(
            """
<h1 style="text-align: center; font-size: 2em;">SkyCaptioner</h1>
""",
            elem_id="header"
        )
        with gr.Row():
            with gr.Column(visible=True, scale=0.5):
                with gr.Row():
                    video_input = gr.Video(
                        label="Upload Video",
                        interactive=True,
                        format="mp4",
                    )
            with gr.Column(visible=True):
                json_output = gr.Code(
                    label="Caption",
                    language="json",
                    lines=25,
                    interactive=False
                )
        gr.Button("Generate").click(
            fn=generate_caption,
            inputs=video_input,
            outputs=json_output
        )
        gr.Examples(
            examples=[
                ["./examples/data/1.mp4"],
                ["./examples/data/2.mp4"],
            ],
            inputs=video_input,
            label="Example Videos"
        )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7862,
        share=False
    )
if __name__ == '__main__':
    main()
================================================
FILE: skycaptioner_v1/scripts/utils.py
================================================
import numpy as np
import pandas as pd
def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
    """Write results into meta[column[0]], keeping the first result seen for each index.
    Returns only the rows that received a result, sorted by index."""
    flat_indices = np.array(indices_list)
    flat_results = np.array(result_list)
    # np.unique returns the sorted unique indices and the position of their first occurrence.
    unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
    meta = meta.loc[unique_indices]
    return meta
================================================
FILE: skycaptioner_v1/scripts/vllm_fusion_caption.py
================================================
import os
from pathlib import Path
import argparse
import glob
import time
import gc
from tqdm import tqdm
import torch
from transformers import AutoTokenizer
import pandas as pd
from vllm import LLM, SamplingParams
from torch.utils.data import DataLoader
import json
import random
from utils import result_writer
SYSTEM_PROMPT_I2V = """
You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
## Structured Input
{structured_input}
## Notes
1. If there has an empty field, just ignore it and do not mention it in the output.
2. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
3. If the action field is not empty, eliminate the irrelevant information in the action field that is not related to the timing action(such as wearings, background and environment information) to make a pure action field.
## Output Principles and Orders
1. First, eliminate the static information in the action field that is not related to the timing action, such as background or environment information.
2. Second, describe each subject with its pure action and expression if these fields exist.
## Output
Please directly output the final composed caption without any additional information.
"""
SYSTEM_PROMPT_T2V = """
You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
## Structured Input
{structured_input}
## Notes
1. According to the action field information, change its name field to the subject pronoun in the action.
2. If there has an empty field, just ignore it and do not mention it in the output.
3. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
## Output Principles and Orders
1. First, declare the shot_type, then declare the shot_angle and the shot_position fields.
2. Second, eliminate information in the action field that is not related to the timing action, such as background or environment information if action is not empty.
3. Third, describe each subject with its pure action, appearance, expression, position if these fields exist.
4. Finally, declare the environment and lighting if the environment and lighting fields are not empty.
## Output
Please directly output the final composed caption without any additional information.
"""
SHOT_TYPE_LIST = [
    'close-up shot',
    'extreme close-up shot',
    'medium shot',
    'long shot',
    'full shot',
]
class StructuralCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, input_csv, model_path, task=None):
        # Accept an in-memory DataFrame (Gradio demo) or a CSV path (batch script).
        if isinstance(input_csv, pd.DataFrame):
            self.meta = input_csv
        else:
            self.meta = pd.read_csv(input_csv)
        if task is None:
            # Falls back to the command-line --task, which only exists when run as a script.
            self.task = args.task
        else:
            self.task = task
        self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
    def __len__(self):
        return len(self.meta)
    def __getitem__(self, index):
        row = self.meta.iloc[index]
        real_index = self.meta.index[index]
        struct_caption = json.loads(row["structural_caption"])
        # Camera motion is not sent to the LLM; it is appended to the fused caption as a sentence.
        camera_movement = struct_caption.get('camera_motion', '')
        if camera_movement != '':
            camera_movement += '.'
        camera_movement = camera_movement.capitalize()
        fusion_by_llm = False
        cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task)
        if cleaned_struct_caption.get('num_subjects', 0) > 0:
            new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False)
            conversation = [
                {
                    "role": "system",
                    "content": self.system_prompt.format(structured_input=new_struct_caption),
                },
            ]
            text = self.tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True
            )
            fusion_by_llm = True
        else:
            # No subjects left after cleaning: this row skips the LLM.
            text = '-'
        return real_index, fusion_by_llm, text, '-', camera_movement
    def clean_struct_caption(self, struct_caption, task):
        raw_subjects = struct_caption.get('subjects', [])
        subjects = []
        for subject in raw_subjects:
            subject_type = subject.get("TYPES", {}).get('type', '')
            subject_sub_type = subject.get("TYPES", {}).get('sub_type', '')
            # Expressions are only meaningful for humans and animals.
            if subject_type not in ["Human", "Animal"]:
                subject['expression'] = ''
            if subject_type == 'Human' and subject_sub_type == 'Accessory':
                subject['expression'] = ''
            # Use the sub-type (e.g. "Woman") as the subject's name in the prompt.
            if subject_sub_type != '':
                subject['name'] = subject_sub_type
            if 'TYPES' in subject:
                del subject['TYPES']
            if 'is_main_subject' in subject:
                del subject['is_main_subject']
            subjects.append(subject)
        to_del_subject_ids = []
        for idx, subject in enumerate(subjects):
            action = subject.get('action', '').strip()
            subject['action'] = action
            # Randomly drop appearance/position (10% each), so repeated runs can differ.
            if random.random() > 0.9 and 'appearance' in subject:
                del subject['appearance']
            if random.random() > 0.9 and 'position' in subject:
                del subject['position']
            if task == 'i2v':
                # I2V: keep only name, action and expression in each subject
                dropped_keys = ['appearance', 'position']
                for key in dropped_keys:
                    if key in subject:
                        del subject[key]
                if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''):
                    to_del_subject_ids.append(idx)
        # delete the subjects according to the to_del_subject_ids
        for idx in sorted(to_del_subject_ids, reverse=True):
            del subjects[idx]
        shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
        # if shot_type not in SHOT_TYPE_LIST:
        #     struct_caption['shot_type'] = ''
        new_struct_caption = {
            'num_subjects': len(subjects),
            'subjects': subjects,
            'shot_type': struct_caption.get('shot_type', ''),
            'shot_angle': struct_caption.get('shot_angle', ''),
            'shot_position': struct_caption.get('shot_position', ''),
            'environment': struct_caption.get('environment', ''),
            'lighting': struct_caption.get('lighting', ''),
        }
        if task == 't2v' and random.random() > 0.9:
            del new_struct_caption['lighting']
        if task == 'i2v':
            drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position']
            for drop_key in drop_keys:
                del new_struct_caption[drop_key]
        return new_struct_caption
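def custom_collate_fn(batch):
    # Transpose the batch into plain Python lists; the default collate would turn indices and flags into tensors.
    real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch)
    return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements)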
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Caption Fusion by LLM")
    parser.add_argument("--input_csv", default="./examples/test_result.csv")
    parser.add_argument("--out_csv", default="./examples/test_result_caption.csv")
    parser.add_argument("--bs", type=int, default=4)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--model_path", required=True, type=str, help="LLM model path")
    parser.add_argument("--task", default='t2v', choices=['t2v', 'i2v'], help="t2v or i2v")
    args = parser.parse_args()
    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=512,
        stop=['\n\n']
    )
    # model_path = "/maindata/data/shared/public/Common-Models/Qwen2.5-32B-Instruct/"
    llm = LLM(
        model=args.model_path,
        gpu_memory_utilization=0.9,
        max_model_len=4096,
        tensor_parallel_size=args.tp
    )
    dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path, task=args.task)
    dataloader = DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=8,
        collate_fn=custom_collate_fn,
        shuffle=False,
        drop_last=False,
    )
    indices_list = []
    result_list = []
    for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader):
        llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], []
        for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements):
            if fusion_by_llm:
                llm_indices.append(idx)
                llm_texts.append(text)
                llm_original_texts.append(original_text)
                llm_camera_movements.append(camera_movement)
            else:
                # Rows without subjects skip the LLM.
                indices_list.append(idx)
                caption = original_text + " " + camera_movement
                result_list.append(caption)
        if len(llm_texts) > 0:
            try:
                outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False)
                results = []
                for output in outputs:
                    result = output.outputs[0].text.strip()
                    results.append(result)
                indices_list.extend(llm_indices)
            except Exception as e:
                print(f"Error at {llm_indices}: {str(e)}")
                indices_list.extend(llm_indices)
                results = llm_original_texts
            for result, camera_movement in zip(results, llm_camera_movements):
                # concat camera movement to fusion_caption
                llm_caption = result + " " + camera_movement
                result_list.append(llm_caption)
        torch.cuda.empty_cache()
        gc.collect()
    meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"])
    meta_new.to_csv(args.out_csv, index=False)
================================================
FILE: skycaptioner_v1/scripts/vllm_struct_caption.py
================================================
import torch
import decord
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoProcessor
from torch.utils.data import DataLoader
SYSTEM_PROMPT = "I need you to generate a structured and detailed caption for the provided video. The structured output and the requirements for each field are as shown in the following JSON content: {\"subjects\": [{\"appearance\": \"Main subject appearance description\", \"action\": \"Main subject action\", \"expression\": \"Main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Subject position in the video (Can be relative position to other objects or spatial description)\", \"TYPES\": {\"type\": \"Main category (e.g., Human)\", \"sub_type\": \"Sub-category (e.g., Man)\"}, \"is_main_subject\": true}, {\"appearance\": \"Non-main subject appearance description\", \"action\": \"Non-main subject action\", \"expression\": \"Non-main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Position of non-main subject 1\", \"TYPES\": {\"type\": \"Main category (e.g., Vehicles)\", \"sub_type\": \"Sub-category (e.g., Ship)\"}, \"is_main_subject\": false}], \"shot_type\": \"Shot type(Options: long_shot/full_shot/medium_shot/close_up/extreme_close_up/other)\", \"shot_angle\": \"Camera angle(Options: eye_level/high_angle/low_angle/other)\", \"shot_position\": \"Camera position(Options: front_view/back_view/side_view/over_the_shoulder/overhead_view/point_of_view/aerial_view/overlooking_view/other)\", \"camera_motion\": \"Camera movement description\", \"environment\": \"Video background/environment description\", \"lighting\": \"Lighting information in the video\"}"
class VideoTextDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, model_path):
        # Accept an in-memory DataFrame (Gradio demo) or a CSV path (batch script).
        if isinstance(csv_path, pd.DataFrame):
            self.meta = csv_path
        else:
            self.meta = pd.read_csv(csv_path)
        self._path = 'path'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.processor = AutoProcessor.from_pretrained(model_path)
    def __getitem__(self, index):
        row = self.meta.iloc[index]
        path = row[self._path]
        real_index = self.meta.index[index]
        # Decode on the CPU, resizing frames to 360 (width) x 420 (height).
        vr = decord.VideoReader(path, ctx=decord.cpu(0), width=360, height=420)
        start = 0
        end = len(vr)
        # avg_fps = vr.get_avg_fps()
        # Sample 16 evenly spaced frames across the whole clip.
        index = self.get_index(end-start, 16, st=start)
        frames = vr.get_batch(index).asnumpy()  # n h w c
        video_inputs = [torch.from_numpy(frames).permute(0, 3, 1, 2)]  # n c h w
        conversation = {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": row['path'],
                    "max_pixels": 360 * 420,  # 151200
                    "fps": 2.0,
                },
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT
                },
            ],
        }
        # Build the user input (prompt string) from the chat template
        user_input = self.processor.apply_chat_template(
            [conversation],
            tokenize=False,
            add_generation_prompt=True
        )
        results = dict()
        inputs = {
            'prompt': user_input,
            'multi_modal_data': {'video': video_inputs}
        }
        results["index"] = real_index
        results['input'] = inputs
        return results
def __len__(self):
return len(self.meta)
def get_index(self, video_size, num_frames, st=0):
seg_size = max(0., float(video_size - 1) / num_frames)
max_frame = int(video_size) - 1
seq = []
# Evenly spaced 0-based frame indices, offset by st
for i in range(num_frames):
start = int(np.round(seg_size * i))
# end = int(np.round(seg_size * (i + 1)))
idx = min(start, max_frame)
seq.append(idx+st)
return seq
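`get_index` selects `num_frames` indices spaced `(video_size - 1) / num_frames` apart, starting at frame 0. As a result, the last sampled frame usually falls short of the final frame, and clips shorter than `num_frames` get repeated indices. Below is a pure-Python sketch of the same arithmetic. The function name is hypothetical. Python's built-in `round` uses round-half-to-even, the same rule as `np.round`, so the two agree.

```python
def uniform_frame_indices(video_size, num_frames, st=0):
    """Pick num_frames evenly spaced 0-based frame indices (mirrors get_index)."""
    seg_size = max(0.0, float(video_size - 1) / num_frames)
    max_frame = int(video_size) - 1
    # round() rounds half to even, like np.round; indices never exceed the last frame
    return [min(int(round(seg_size * i)), max_frame) + st for i in range(num_frames)]


print(uniform_frame_indices(100, 16))
# [0, 6, 12, 19, 25, 31, 37, 43, 50, 56, 62, 68, 74, 80, 87, 93]
print(uniform_frame_indices(5, 16))  # a 5-frame clip: indices repeat
```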
def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
# Copy the accumulated indices and captions into flat lists
flat_indices = list(indices_list)
flat_results = list(result_list)
flat_indices = np.array(flat_indices)
flat_results = np.array(flat_results)
unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
meta = meta.loc[unique_indices]
return meta
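`result_writer` relies on `np.unique(..., return_index=True)`. That call returns the sorted unique indices together with the position of each index's first occurrence. If a dataset index appears more than once, only its first caption is written. `meta.loc[unique_indices]` then keeps only the captioned rows, in index order. Here is an equivalent pure-Python sketch. The helper name is hypothetical.

```python
def first_result_per_index(indices, results):
    """Keep the first caption seen for each dataset index, ordered by index."""
    first = {}
    for idx, res in zip(indices, results):
        first.setdefault(idx, res)  # later duplicates of the same index are ignored
    return [(idx, first[idx]) for idx in sorted(first)]


print(first_result_per_index([2, 0, 2, 1], ["c", "a", "c-dup", "b"]))
# [(0, 'a'), (1, 'b'), (2, 'c')]
```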
def worker_init_fn(worker_id):
# Set different seed for each worker
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
# Limit each worker to one intra-op thread to avoid CPU oversubscription
torch.set_num_threads(1)
def main():
parser = argparse.ArgumentParser(description="SkyCaptioner-V1 vllm batch inference")
parser.add_argument("--input_csv", default="./examples/test.csv")
parser.add_argument("--out_csv", default="./examples/test_result.csv")
parser.add_argument("--bs", type=int, default=4)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--model_path", required=True, type=str, help="skycaptioner-v1 model path")
args = parser.parse_args()
dataset = VideoTextDataset(csv_path=args.input_csv, model_path=args.model_path)
dataloader = DataLoader(
dataset,
batch_size=args.bs,
num_workers=4,
worker_init_fn=worker_init_fn,
persistent_workers=True,
timeout=180,
)
sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)
llm = LLM(model=args.model_path,
gpu_memory_utilization=0.6,
max_model_len=31920,
tensor_parallel_size=args.tp)
indices_list = []
caption_save = []
for video_batch in tqdm(dataloader):
indices = video_batch["index"]
inputs = video_batch["input"]
batch_user_inputs = []
for prompt, video in zip(inputs['prompt'], inputs['multi_modal_data']['video'][0]):
usi={'prompt':prompt, 'multi_modal_data':{'video':video}}
batch_user_inputs.append(usi)
outputs = llm.generate(batch_user_inputs, sampling_params, use_tqdm=False)
struct_outputs = [output.outputs[0].text for output in outputs]
indices_list.extend(indices.tolist())
caption_save.extend(struct_outputs)
meta_new = result_writer(indices_list, caption_save, dataset.meta, column=["structural_caption"])
meta_new.to_csv(args.out_csv, index=False)
print(f'Saved structural_caption to {args.out_csv}')
if __name__ == '__main__':
main()
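In `main()`, the DataLoader's default collate function merges the per-sample dicts from `__getitem__` into a single batch dict. The `prompt` strings become a list. The one-element `video` list becomes a list holding one tensor of shape [B, 16, C, H, W]. Stacking works because every clip is decoded to 16 frames at a fixed size (`width=360, height=420`). This is why the loop indexes `['video'][0]` and zips the result with the prompts. The sketch below shows the regrouping step. Plain lists stand in for the tensors, and the helper name is hypothetical.

```python
def split_collated_batch(batch):
    """Regroup one collated batch into per-sample request dicts for LLM.generate."""
    inputs = batch["input"]
    prompts = inputs["prompt"]                       # list of B prompt strings
    videos = inputs["multi_modal_data"]["video"][0]  # one stacked batch of B clips
    return [
        {"prompt": prompt, "multi_modal_data": {"video": video}}
        for prompt, video in zip(prompts, videos)
    ]


# Plain lists stand in for the collated tensors.
batch = {
    "index": [0, 1],
    "input": {"prompt": ["<prompt 0>", "<prompt 1>"], "multi_modal_data": {"video": [["clip0", "clip1"]]}},
}
requests = split_collated_batch(batch)
print([r["multi_modal_data"]["video"] for r in requests])  # ['clip0', 'clip1']
```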
================================================
FILE: skyreels_v2_infer/__init__.py
================================================
from .pipelines import DiffusionForcingPipeline
================================================
FILE: skyreels_v2_infer/distributed/__init__.py
================================================
================================================
FILE: skyreels_v2_infer/distributed/xdit_context_parallel.py
================================================
import numpy as np
import torch
import torch.amp as amp
from torch.backends.cuda import sdp_kernel
from xfuser.core.distributed import get_sequence_parallel_rank
from xfuser.core.distributed import get_sequence_parallel_world_size
from xfuser.core.distributed import get_sp_group
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.transformer import sinusoidal_embedding_1d
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
"""
x: [B, L, N, C].
grid_sizes: [3], the (F, H, W) grid shared by every sample in the batch.
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
grid = [grid_sizes.tolist()] * x.size(0)
for i, (f, h, w) in enumerate(grid):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
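`rope_apply` splits the `c = head_dim // 2` rotary channel pairs into three groups. The frame group gets `c - 2 * (c // 3)` pairs, and the height and width groups get `c // 3` each, so any remainder goes to the frame axis. Under sequence parallelism each rank holds only `s` tokens. `pad_freqs` therefore first extends the per-token frequency table to `s * sp_size` rows, and each rank then takes rows `[sp_rank * s, (sp_rank + 1) * s)`. Assuming `freqs` holds complex rotations, as the multiplication with `view_as_complex(x)` suggests, the padding rows of ones leave padded tokens unrotated. The sketch below reproduces both pieces of bookkeeping in pure Python. The helper names are hypothetical, and `-1` marks padding rows.

```python
def rope_channel_split(head_dim):
    """Split head_dim // 2 rotary pairs into (frame, height, width) groups, as rope_apply does."""
    c = head_dim // 2
    return [c - 2 * (c // 3), c // 3, c // 3]


def rank_rows(seq_len, local_len, sp_size, sp_rank):
    """Global token positions whose rotary freqs one rank uses; -1 marks padding rows."""
    # pad_freqs extends the table to local_len * sp_size rows; real positions are 0..seq_len-1
    padded = list(range(seq_len)) + [-1] * (local_len * sp_size - seq_len)
    return padded[sp_rank * local_len:(sp_rank + 1) * local_len]


print(rope_channel_split(128))  # [22, 21, 21]
print(rank_rows(seq_len=10, local_len=4, sp_size=3, sp_rank=2))  # [8, 9, -1, -1]
```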
def broadcast_should_calc(should_calc: bool) -> bool:
import torch.distributed as dist
device = torch.cuda.current_device()
int_should_calc = 1 if should_calc else 0
tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
dist.broadcast(tensor, src=0)
should_calc = tensor.item() == 1
return should_calc
def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
"""
x: Video latents of shape [B, C, T, H, W].
t: Timesteps of shape [B], or [B, F] (one per latent frame) for diffusion forcing.
context: Text embeddings of shape [B, L, C].
"""
if self.model_type == "i2v":
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = torch.cat([x, y], dim=1)
# embeddings
x = self.patch_embedding(x)
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
x = x.flatten(2).transpose(1, 2)
if self.flag_causal_attention:
frame_num = grid_sizes[0]
height = grid_sizes[1]
width = grid_sizes[2]
block_num = frame_num // self.num_frame_per_block
range_tensor = torch.arange(block_num).view(-1, 1)
range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
# time embeddings
with amp.autocast("cuda", dtype=torch.float32):
if t.dim() == 2:
b, f = t.shape
_flag_df = True
else:
_flag_df = False
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
if self.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).float()
if _flag_df:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
else:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
if _flag_df:
e = e.view(b, f, 1, 1, self.dim)
e0 = e0.view(b, f, 1, 1, 6, self.dim)
e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
e0 = e0.transpose(1, 2).contiguous()
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context = self.text_embedding(context)
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
if e0.ndim == 4:
e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
if self.enable_teacache:
modulated_inp = e0 if self.use_ref_steps else e
# teacache
if self.cnt % 2 == 0: # even -> conditional pass
self.is_even = True
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
.cpu()
.item()
)
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc_even = False
else:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = modulated_inp.clone()
else: # odd -> unconditional pass
self.is_even = False
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(
((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
.cpu()
.item()
)
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc_odd = False
else:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
if self.enable_teacache:
if self.is_even:
should_calc_even = broadcast_should_calc(should_calc_even)
if not should_calc_even:
x += self.previous_residual_even
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
ori_x.mul_(-1)
ori_x.add_(x)
self.previous_residual_even = ori_x
else:
should_calc_odd = broadcast_should_calc(should_calc_odd)
if not should_calc_odd:
x += self.previous_residual_odd
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
ori_x.mul_(-1)
ori_x.add_(x)
self.previous_residual_odd = ori_x
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
else:
# Context Parallel
for block in self.blocks:
x = block(x, **kwargs)
# head
if e.ndim == 3:
e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
x = self.head(x, e)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x.float()
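With `enable_teacache` set, each pass adds a rescaled relative L1 distance to a running sum. The distance is between the current and previous modulated input, computed as `mean|e_t - e_prev| / mean|e_prev|` and passed through `np.poly1d(self.coefficients)`. While the running sum stays below `teacache_thresh`, the blocks are skipped and the cached `previous_residual_*` is added instead. Once the sum crosses the threshold, the blocks run and the sum resets. Steps before `ret_steps`, or at or after `cutoff_steps`, always run. The conditional (even `cnt`) and unconditional (odd `cnt`) passes keep separate accumulators. In the context-parallel path, `broadcast_should_calc` sends rank 0's decision to all ranks so they stay in agreement.

The sketch below models a single stream of calls and takes the per-step distances as input. It is a simplification for illustration, not the repository's implementation, and the coefficients used in the example are hypothetical.

```python
def teacache_schedule(rel_l1, coefficients, thresh, ret_steps, cutoff_steps):
    """Per step: True = run the transformer blocks, False = reuse the cached residual."""
    decisions, acc = [], 0.0
    for step, dist in enumerate(rel_l1):
        if step < ret_steps or step >= cutoff_steps:
            decisions.append(True)  # warm-up / tail steps are always computed
            acc = 0.0
            continue
        # Evaluate the rescaling polynomial (highest degree first, like np.poly1d)
        scaled = 0.0
        for coef in coefficients:
            scaled = scaled * dist + coef
        acc += scaled
        if acc < thresh:
            decisions.append(False)
        else:
            decisions.append(True)
            acc = 0.0
    return decisions


# Hypothetical identity rescaling (coefficients [1.0, 0.0] means p(x) = x).
print(teacache_schedule([0.0, 0.05, 0.04, 0.05, 0.2, 0.01], [1.0, 0.0], 0.1, 1, 6))
# [True, False, False, True, True, False]
```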
def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
r"""
Args:
x(Tensor): Shape [B, L_local, C], this rank's shard of the token sequence
grid_sizes(Tensor): Shape [3], containing (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
block_mask(Tensor or None): Block-causal attention mask, or None
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
x = x.to(self.q.weight.dtype)
q, k, v = qkv_fn(x)
if not self._flag_ar_attention:
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
else:
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
# x = torch.nn.functional.scaled_dot_product_attention(
# q.transpose(1, 2),
# k.transpose(1, 2),
# v.transpose(1, 2),
# ).transpose(1, 2).contiguous()
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
x = (
torch.nn.functional.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
)
.transpose(1, 2)
.contiguous()
)
x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
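When `flag_causal_attention` is set, `usp_dit_forward` builds a block-causal mask over latent frames. Frames are grouped into blocks of `num_frame_per_block`, and a query frame may attend to every key frame whose block index is less than or equal to its own. The real code then expands each frame entry over the `height × width` latent tokens, producing an `(F·H·W) × (F·H·W)` boolean mask. Its memory therefore grows quadratically with the token count. The pure-Python sketch below builds only the frame-level pattern. It assumes `frame_num` is divisible by `frames_per_block`, which the original `view` call also requires, and the helper name is hypothetical.

```python
def block_causal_mask(frame_num, frames_per_block):
    """mask[i][j] is True when query frame i may attend to key frame j."""
    block_of = [f // frames_per_block for f in range(frame_num)]  # like range_tensor
    return [[block_of[j] <= block_of[i] for j in range(frame_num)] for i in range(frame_num)]


for row in block_causal_mask(4, 2):
    print(row)
# [True, True, False, False]
# [True, True, False, False]
# [True, True, True, True]
# [True, True, True, True]
```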
================================================
FILE: skyreels_v2_infer/modules/__init__.py
================================================
import gc
import os
import torch
from safetensors.torch import load_file
from .clip import CLIPModel
from .t5 import T5EncoderModel
from .transformer import WanModel
from .vae import WanVAE
def download_model(model_id):
if not os.path.exists(model_id):
from huggingface_hub import snapshot_download
model_id = snapshot_download(repo_id=model_id)
return model_id
def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
vae = WanVAE(model_path).to(device).to(weight_dtype)
vae.vae.requires_grad_(False)
vae.vae.eval()
gc.collect()
torch.cuda.empty_cache()
return vae
def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
config_path = os.path.join(model_path, "config.json")
transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
for file in os.listdir(model_path):
if file.endswith(".safetensors"):
file_path = os.path.join(model_path, file)
state_dict = load_file(file_path)
transformer.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
torch.cuda.empty_cache()
transformer.requires_grad_(False)
transformer.eval()
gc.collect()
torch.cuda.empty_cache()
return transformer
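`get_transformer` loads every `*.safetensors` file in `model_path` with `strict=False`. Because each shard holds only part of the weights, every individual shard load is allowed to miss keys. The trade-off is that a key missing from all shards goes unreported. `nn.Module.load_state_dict` returns `missing_keys` and `unexpected_keys`, which could be collected across shards if that check is wanted. The sketch below shows only the shard discovery step, sorted for reproducibility since `os.listdir` order is arbitrary. The helper name and the shard file names are hypothetical.

```python
import os
import tempfile


def list_safetensors_shards(model_path):
    """Return the .safetensors files in model_path, sorted for a reproducible load order."""
    return sorted(
        os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith(".safetensors")
    )


with tempfile.TemporaryDirectory() as d:
    # Hypothetical shard names plus a non-weight file that must be skipped
    for name in ["model-00002-of-00002.safetensors", "model-00001-of-00002.safetensors", "config.json"]:
        open(os.path.join(d, name), "w").close()
    shards = [os.path.basename(p) for p in list_safetensors_shards(d)]
print(shards)  # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```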
def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
gc.collect()
torch.cuda.empty_cache()
return text_encoder
def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
image_enc.requires_grad_(False)
image_enc.eval()
gc.collect()
torch.cuda.empty_cache()
return image_enc
================================================
FILE: skyreels_v2_infer/modules/attention.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
"flash_attention",
"attention",
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.0,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left, right). If not (-1, -1), apply sliding-window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == "cuda" and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
q = half(q.flatten(0, 1))
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
# preprocess key, value
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.")
torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}")
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic,
)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
).unflatten(0, (b, lq))
torch.cuda.nvtx.range_pop()
# output
return x
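`flash_attention` flattens q/k/v from [B, L, N, C] to [B·L, N, C]. It then marks sample boundaries with cumulative offsets (`cu_seqlens_q` / `cu_seqlens_k`), built as `cat([0], lens).cumsum(0)`. `q_lens` and `k_lens` are rebuilt inside the function as `[lq] * b` and `[lk] * b`. Any lengths passed in are therefore ignored, and every sample is treated as full length. The pure-Python sketch below computes the offsets, with `itertools.accumulate` standing in for `cumsum`. The helper name is hypothetical.

```python
from itertools import accumulate


def cu_seqlens(lengths):
    """Cumulative offsets: sample b occupies rows cu[b]:cu[b + 1] of the flattened tensor."""
    return [0] + list(accumulate(lengths))


offsets = cu_seqlens([4] * 3)  # b = 3 samples, each of length lq = 4
print(offsets)  # [0, 4, 8, 12]
# Rows of the flattened [B * Lq, N, C] tensor that belong to sample 1:
print(list(range(offsets[1], offsets[2])))  # [4, 5, 6, 7]
```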
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.0,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
"Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
)
out = out.transpose(1, 2).contiguous()
return out
================================================
FILE: skyreels_v2_infer/modules/clip.py
================================================
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models import ModelMixin
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
__all__ = [
"XLMRobertaCLIP",
"clip_xlm_roberta_vit_h_14",
"CLIPModel",
]
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat(
[
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode="bicubic",
align_corners=False,
)
.flatten(2)
.transpose(1, 2),
],
dim=1,
)
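With the default `image_size=224` and `patch_size=14`, the vision tower sees a 16×16 grid of 256 patch tokens plus one class token, 257 tokens in total. This matches the `bs x 257 x dim` comment in `usp_dit_forward`. `pos_interpolate` keeps the leading non-grid tokens (`n = pos.size(1) - src_grid**2`, here the class token) and bicubically resizes the rest. `tar_grid` is computed as `int(sqrt(seq_len))` on the full length including the class token. That works because `sqrt(g*g + 1) < g + 1`, so the truncation drops the extra token. The sketch below mirrors only this bookkeeping, not the interpolation itself, and the helper name is hypothetical.

```python
import math


def pos_grid_split(num_pos, seq_len):
    """Return (prefix tokens kept as-is, source grid side, target grid side)."""
    src_grid = int(math.sqrt(num_pos))
    tar_grid = int(math.sqrt(seq_len))
    prefix = num_pos - src_grid * src_grid  # e.g. the class token
    return prefix, src_grid, tar_grid


print(pos_grid_split(257, 257))   # (1, 16, 16): no resize needed
print(pos_grid_split(257, 1025))  # (1, 16, 32): a 32x32 grid plus one class token
```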
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation="quick_gelu",
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5,
):
assert activation in ["quick_gelu", "gelu", "swi_glu"]
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == "swi_glu":
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(proj_dropout),
)
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(proj_dropout),
)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = flash_attention(q, k, v, version=2)
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(
self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type="token",
pre_norm=True,
post_norm=False,
activation="quick_gelu",
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5,
):
if image_size % patch_size != 0:
print("[WARNING] image_size is not divisible by patch_size", flush=True)
assert pool_type in ("token", "token_fc", "attn_pool")
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
if pool_type in ("token", "token_fc"):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(
gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
)
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(
*[
AttentionBlock(
dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps
)
for _ in range(num_layers)
]
)
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == "token":
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == "token_fc":
self.head = nn.Linear(dim, out_dim)
elif pool_type == "attn_pool":
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ("token", "token_fc"):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
x = self.transformer[:-1](x)
return x
else:
x = self.transformer(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop("out_dim")
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)
)
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool="token",
vision_pre_norm=True,
vision_post_norm=False,
activation="gelu",
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5,
):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps,
)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout,
)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [
{
"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
"weight_decay": 0.0,
},
{"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
]
return groups
def _clip(
pretrained=False,
pretrained_name=None,
model_cls=XLMRobertaCLIP,
return_transforms=False,
return_tokenizer=False,
tokenizer_padding="eos",
dtype=torch.float32,
device="cpu",
**kwargs,
):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
if "siglip" in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose(
[
T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std),
]
)
output += (transforms,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool="token",
activation="gelu",
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel(ModelMixin):
def __init__(self, checkpoint_path, tokenizer_path):
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
super().__init__()
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False, return_transforms=True, return_tokenizer=False
)
self.model = self.model.eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace"
)
def encode_video(self, video):
# preprocess
b, c, t, h, w = video.shape
video = video.transpose(1, 2)
video = video.reshape(b * t, c, h, w)
size = (self.model.image_size,) * 2
video = F.interpolate(
video,
size=size,
mode='bicubic',
align_corners=False)
# Map pixels from [-1, 1] to [0, 1], then apply the CLIP Normalize transform
video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
out = self.model.visual(video, use_31_block=True)
return out
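`encode_video` first reshapes [B, C, T, H, W] to [B·T, C, H, W] and resizes each frame to 224×224 with bicubic interpolation. It then maps pixels from [-1, 1] to [0, 1] and applies the last transform, the CLIP `Normalize` with the mean/std listed in `_clip`. `use_31_block=True` returns the hidden states after all but the last of the 32 vision blocks. The per-pixel normalization arithmetic is sketched below. The helper name is hypothetical.

```python
# CLIP mean/std as listed in _clip (non-SigLIP branch)
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def normalize_pixel(rgb):
    """Map one RGB pixel from [-1, 1] to [0, 1], then apply CLIP mean/std normalization."""
    return [((v * 0.5 + 0.5) - m) / s for v, m, s in zip(rgb, CLIP_MEAN, CLIP_STD)]


print(normalize_pixel([0.0, 0.0, 0.0]))  # mid-gray input, slightly above each channel mean
```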
================================================
FILE: skyreels_v2_infer/modules/t5.py
================================================
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from .tokenizers import HuggingfaceTokenizer
__all__ = [
"T5Model",
"T5Encoder",
"T5Decoder",
"T5EncoderModel",
]
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp, max=clamp)
return x
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
elif isinstance(m, T5Model):
nn.init.normal_(m.token_embedding.weight, std=1.0)
elif isinstance(m, T5FeedForward):
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
elif isinstance(m, T5Attention):
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum("bnij,bjnc->binc", attn, v)
# output
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
super(T5CrossAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def forward(self, lq, lk):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = (
max_exact
+ (
torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
).long()
)
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
super(T5Encoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
for _ in range(num_layers)
]
)
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
super(T5Decoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
for _ in range(num_layers)
]
)
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.size()
# causal mask
if mask is None:
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
elif mask.ndim == 2:
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(
self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.1,
):
super(T5Model, self).__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(
self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
)
self.decoder = T5Decoder(
self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
)
self.head = nn.Linear(dim, vocab_size, bias=False)
# initialize weights
self.apply(init_weights)
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def _t5(
name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
dtype=torch.float32,
device="cpu",
**kwargs,
):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs["vocab"] = kwargs.pop("vocab_size")
kwargs["num_layers"] = kwargs.pop("encoder_layers")
_ = kwargs.pop("decoder_layers")
elif decoder_only:
model_cls = T5Decoder
kwargs["vocab"] = kwargs.pop("vocab_size")
kwargs["num_layers"] = kwargs.pop("decoder_layers")
_ = kwargs.pop("encoder_layers")
else:
model_cls = T5Model
# init model
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.1,
)
cfg.update(**kwargs)
return _t5("umt5-xxl", **cfg)
class T5EncoderModel(ModelMixin):
def __init__(
self,
checkpoint_path=None,
tokenizer_path=None,
text_len=512,
shard_fn=None,
):
super().__init__()
self.text_len = text_len
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(encoder_only=True, return_tokenizer=False)
logging.info(f"loading {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
else:
self.model.eval().requires_grad_(False)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
def encode(self, texts):
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.to(self.device)
mask = mask.to(self.device)
context = self.model(ids, mask)
# zero out the embeddings at padding positions (mask is already on self.device)
context = context * mask.unsqueeze(-1)
return context
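`T5RelativeEmbedding._relative_position_bucket` above maps each offset `key_pos - query_pos` to one of `num_buckets` learned biases:

- Offsets smaller than `max_exact` get their own bucket.
- Larger offsets share logarithmically spaced buckets up to `max_dist`.
- The bidirectional (encoder) variant reserves the upper half of the buckets for positive offsets.

The sketch below is a scalar, torch-free restatement of the same arithmetic. It assumes the `umt5_xxl` settings (`num_buckets=32`, default `max_dist=128`), and `relative_position_bucket` is a hypothetical helper name.

```python
import math


def relative_position_bucket(rel_pos, num_buckets=32, bidirectional=True, max_dist=128):
    """Map one relative offset (key_pos - query_pos) to a bucket index."""
    if bidirectional:
        num_buckets //= 2
        # positive offsets (key after query) use the upper half of the buckets
        bucket = num_buckets if rel_pos > 0 else 0
        rel_pos = abs(rel_pos)
    else:
        bucket = 0
        # causal variant: future offsets all collapse to distance 0
        rel_pos = -min(rel_pos, 0)
    max_exact = num_buckets // 2
    if rel_pos < max_exact:
        # small offsets get one bucket each
        return bucket + rel_pos
    # larger offsets share log-spaced buckets, capped at num_buckets - 1
    large = max_exact + int(
        math.log(rel_pos / max_exact) / math.log(max_dist / max_exact) * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)
```

With these settings, the encoder's bias has the following resolution:

- It distinguishes offsets exactly up to 7 tokens.
- It distinguishes them only coarsely up to 90 tokens.
- Every offset of 91 or more, in either direction, falls into the last bucket of its half.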
================================================
FILE: skyreels_v2_infer/modules/tokenizers.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ["HuggingfaceTokenizer"]
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans("", "", string.punctuation))
for part in text.split(keep_punctuation_exact_string)
)
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
text = re.sub(r"\s+", " ", text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop("return_mask", False)
# arguments
_kwargs = {"return_tensors": "pt"}
if self.seq_len is not None:
_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == "whitespace":
text = whitespace_clean(basic_clean(text))
elif self.clean == "lower":
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == "canonicalize":
text = canonicalize(basic_clean(text))
return text
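`HuggingfaceTokenizer._clean` above picks one of three normalizations before tokenization. Both `T5EncoderModel` and the CLIP wrapper construct their tokenizers with `clean="whitespace"`.

The sketch below shows what each mode does, using only the standard library. It leaves out `ftfy.fix_text`, which `basic_clean` applies first, and it uses `re` instead of the third-party `regex` module. Treat it as an approximation that holds for plain ASCII input. `clean_text` is a hypothetical stand-in for `_clean`.

```python
import html
import re
import string

_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def basic_clean(text):
    # the repository runs ftfy.fix_text(text) first; omitted to stay stdlib-only
    # unescape twice so double-escaped entities such as "&amp;amp;" become "&"
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    # collapse every run of whitespace (spaces, tabs, newlines) into one space
    return re.sub(r"\s+", " ", text).strip()


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # drop punctuation everywhere except the protected substring itself
        parts = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(p.translate(_PUNCT_TABLE) for p in parts)
    else:
        text = text.translate(_PUNCT_TABLE)
    return re.sub(r"\s+", " ", text.lower()).strip()


def clean_text(text, mode):
    # hypothetical stand-in for HuggingfaceTokenizer._clean
    if mode == "whitespace":
        return whitespace_clean(basic_clean(text))
    if mode == "lower":
        return whitespace_clean(basic_clean(text)).lower()
    if mode == "canonicalize":
        return canonicalize(basic_clean(text))
    return text  # mode None: the tokenizer skips cleaning entirely
```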
================================================
FILE: skyreels_v2_infer/modules/transformer.py
================================================
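Two pieces of this file express the same attention rule: `WanModel._prepare_blockwise_causal_attn_mask`, and the dense mask that `WanModel.forward` builds when `flag_causal_attention` is set. Tokens are grouped into blocks of `num_frame_per_block` latent frames. A query token may attend to every key token up to the end of its own block.

The list-based sketch below implements both constructions. It makes three assumptions:

- Tokens are in frame-major order.
- The frame count is divisible by `num_frame_per_block`.
- The right-padding to a multiple of 128 is omitted. In the FlexAttention version, padding tokens attend only to themselves.

Both function names are hypothetical.

```python
def blockwise_causal_mask(num_frames, frame_seqlen, num_frame_per_block=1):
    """FlexAttention-style rule: mask[q][kv] is True if query q may attend to key kv."""
    total = num_frames * frame_seqlen
    block_len = frame_seqlen * num_frame_per_block
    # ends[i]: index one past the last token of the block that contains token i
    ends = [(i // block_len + 1) * block_len for i in range(total)]
    return [[kv < ends[q] or q == kv for kv in range(total)] for q in range(total)]


def dense_causal_mask(num_frames, frame_seqlen, num_frame_per_block=1):
    """Dense rule from WanModel.forward: compare the block indices of key and query frames."""
    # frame-major token order: frame f contributes frame_seqlen consecutive tokens
    block = [f // num_frame_per_block for f in range(num_frames) for _ in range(frame_seqlen)]
    total = len(block)
    return [[block[kv] <= block[q] for kv in range(total)] for q in range(total)]


if __name__ == "__main__":
    for row in blockwise_causal_mask(num_frames=3, frame_seqlen=2):
        print("".join("x" if allowed else "." for allowed in row))
```

The condition `kv < ends[q]` holds exactly when `block(kv) <= block(q)`, which is why the two constructions agree.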
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import numpy as np
import torch
import torch.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin
from diffusers.configuration_utils import register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from torch.backends.cuda import sdp_kernel
from torch.nn.attention.flex_attention import BlockMask
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention
from .attention import flash_attention
flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
DISABLE_COMPILE = False  # set to True to skip torch.compile on the helpers below (not yet read from an environment variable)
__all__ = ["WanModel"]
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
@amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
bs = x.size(0)
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
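# all samples in the batch share one (F, H, W) grid, so no per-sample loop is needed
f, h, w = grid_sizes.tolist()
seq_len = f * h * w
# view channel pairs as complex numbers and build per-token (f, h, w) rotary frequencies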
x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2))
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
x = torch.view_as_real(x * freqs_i).flatten(3)
return x
@torch.compile(dynamic=True, disable=DISABLE_COMPILE)
def fast_rms_norm(x, weight, eps):
dtype = x.dtype
# normalize in float32, then cast back to the input dtype before scaling
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
x = x.to(dtype) * weight
return x
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return fast_rms_norm(x, self.weight, self.eps)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x)
class WanSelfAttention(nn.Module):
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self._flag_ar_attention = False
def set_ar_attention(self):
self._flag_ar_attention = True
def forward(self, x, grid_sizes, freqs, block_mask):
r"""
Args:
x(Tensor): Shape [B, L, C]
grid_sizes(Tensor): Shape [3], the (F, H, W) patch grid shared by every sample in the batch
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
block_mask(Tensor, *optional*): Attention mask, used only when AR (block-causal) attention is enabled
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
x = x.to(self.q.weight.dtype)
q, k, v = qkv_fn(x)
# rotary position embedding is applied in both attention paths
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
if not self._flag_ar_attention:
x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
else:
# AR (block-causal) path: PyTorch SDPA with the explicit block mask, in bfloat16
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
x = (
torch.nn.functional.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
)
.transpose(1, 2)
.contiguous()
)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
# compute attention
x = flash_attention(q, k, v)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
super().__init__(dim, num_heads, window_size, qk_norm, eps)
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context):
r"""
Args:
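x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, 257 + L2, C]; the first 257 tokens are the projected CLIP image features, the remaining L2 are text embeddings
"""
# split the CLIP image tokens off from the text tokens
context_img = context[:, :257]
context = context[:, 257:]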
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img)
# compute attention
x = flash_attention(q, k, v)
# output
x = x.flatten(2)
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
WAN_CROSSATTENTION_CLASSES = {
"t2v_cross_attn": WanT2VCrossAttention,
"i2v_cross_attn": WanI2VCrossAttention,
}
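def mul_add(x, y, z):
# gated residual update: x + y * z, computed in float32
return x.float() + y.float() * z.float()
def mul_add_add(x, y, z):
# adaptive-norm modulation: x * (1 + scale) + shift
return x.float() * (1 + y) + z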
mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE)
mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE)
class WanAttentionBlock(nn.Module):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def set_ar_attention(self):
self.self_attn.set_ar_attention()
def forward(
self,
x,
e,
grid_sizes,
freqs,
context,
block_mask,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C], or [B, 6, L, C] for per-token timesteps (diffusion forcing)
grid_sizes(Tensor): Shape [3], the (F, H, W) patch grid
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
context(Tensor): Shape [B, L2, C]
block_mask(Tensor, *optional*): Attention mask for AR (block-causal) attention
"""
if e.dim() == 3:
modulation = self.modulation # 1, 6, dim
with amp.autocast("cuda", dtype=torch.float32):
e = (modulation + e).chunk(6, dim=1)
elif e.dim() == 4:
modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
with amp.autocast("cuda", dtype=torch.float32):
e = (modulation + e).chunk(6, dim=1)
e = [ei.squeeze(1) for ei in e]
# self-attention
out = mul_add_add_compile(self.norm1(x), e[1], e[0])
y = self.self_attn(out, grid_sizes, freqs, block_mask)
with amp.autocast("cuda", dtype=torch.float32):
x = mul_add_compile(x, y, e[2])
# cross-attention & ffn function
def cross_attn_ffn(x, context, e):
dtype = context.dtype
x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype))
with amp.autocast("cuda", dtype=torch.float32):
x = mul_add_compile(x, y, e[5])
return x
x = cross_attn_ffn(x, context, e)
return x.to(torch.bfloat16)
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C], or [B, L1, C] for per-token timesteps
"""
with amp.autocast("cuda", dtype=torch.float32):
if e.dim() == 2:
modulation = self.modulation # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3:
modulation = self.modulation.unsqueeze(2)  # 1, 2, 1, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = self.head(self.norm(x) * (1 + e[1]) + e[0])
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim),
torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(),
torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim),
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
_no_split_modules = ["WanAttentionBlock"]
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
model_type="t2v",
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
inject_sample_info=False,
eps=1e-6,
):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
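cross_attn_norm (`bool`, *optional*, defaults to True):
Enable cross-attention normalization
inject_sample_info (`bool`, *optional*, defaults to False):
Add a learned FPS embedding (two FPS classes) to the timestep modulation
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers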
"""
super().__init__()
assert model_type in ["t2v", "i2v"]
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.num_frame_per_block = 1
self.flag_causal_attention = False
self.block_mask = None
self.enable_teacache = False
# embeddings
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
if inject_sample_info:
self.fps_embedding = nn.Embedding(2, dim)
self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
self.blocks = nn.ModuleList(
[
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
]
)
# head
self.head = Head(dim, out_dim, patch_size, eps)
# RoPE table kept as a plain attribute, not a registered buffer, so that model.to(dtype) does not cast the complex frequencies
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat(
[rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))],
dim=1,
)
if model_type == "i2v":
self.img_emb = MLPProj(1280, dim)
self.gradient_checkpointing = False
self.cpu_offloading = False
self.inject_sample_info = inject_sample_info
# initialize weights
self.init_weights()
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def zero_init_i2v_cross_attn(self):
print("zero init i2v cross attn")
for i in range(self.num_layers):
self.blocks[i].cross_attn.v_img.weight.data.zero_()
self.blocks[i].cross_attn.v_img.bias.data.zero_()
@staticmethod
def _prepare_blockwise_causal_attn_mask(
device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1
) -> BlockMask:
"""
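Build a block-wise causal FlexAttention mask. The token sequence is laid out as
[1 latent frame] [1 latent frame] ... [1 latent frame]
with `frame_seqlen` tokens per frame; every token may attend to all tokens up to the end of its block of `num_frame_per_block` frames.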
"""
total_length = num_frames * frame_seqlen
# right-pad to a multiple of 128 tokens (the FlexAttention block size); padded_length is the number of padding tokens
padded_length = math.ceil(total_length / 128) * 128 - total_length
ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device)
for tmp in frame_indices:
ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block
def attention_mask(b, h, q_idx, kv_idx):
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
block_mask = create_block_mask(
attention_mask,
B=None,
H=None,
Q_LEN=total_length + padded_length,
KV_LEN=total_length + padded_length,
_compile=False,
device=device,
)
return block_mask
def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
self.enable_teacache = enable_teacache
print('using teacache')
self.cnt = 0
self.num_steps = num_steps
self.teacache_thresh = teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_even = None
self.previous_e0_odd = None
self.previous_residual_even = None
self.previous_residual_odd = None
self.use_ref_steps = use_ret_steps
if "I2V" in ckpt_dir:
if use_ret_steps:
if '540P' in ckpt_dir:
self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
if '720P' in ckpt_dir:
self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
self.ret_steps = 5*2
self.cutoff_steps = num_steps*2
else:
if '540P' in ckpt_dir:
self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
if '720P' in ckpt_dir:
self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
self.ret_steps = 1*2
self.cutoff_steps = num_steps*2 - 2
else:
if use_ret_steps:
if '1.3B' in ckpt_dir:
self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
if '14B' in ckpt_dir:
self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
self.ret_steps = 5*2
self.cutoff_steps = num_steps*2
else:
if '1.3B' in ckpt_dir:
self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
if '14B' in ckpt_dir:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1*2
self.cutoff_steps = num_steps*2 - 2
def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
r"""
Forward pass through the diffusion model
Args:
x (Tensor):
Input video latents of shape [B, C_in, F, H, W]
t (Tensor):
Diffusion timesteps of shape [B], or [B, F] (one timestep per latent frame) for diffusion forcing
context (Tensor):
Text embeddings of shape [B, L, text_dim]
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (Tensor, *optional*):
Conditional video latents for image-to-video mode, concatenated with x along the channel dimension
fps (*optional*):
FPS class indices (0 or 1, for the two-entry `fps_embedding`); used only when `inject_sample_info` is True
Returns:
Tensor:
Denoised latents of shape [B, C_out, F, H, W]
"""
        if self.model_type == "i2v":
            assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = torch.cat([x, y], dim=1)

        # embeddings
        x = self.patch_embedding(x)
        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
        x = x.flatten(2).transpose(1, 2)

        if self.flag_causal_attention:
            frame_num = grid_sizes[0]
            height = grid_sizes[1]
            width = grid_sizes[2]
            block_num = frame_num // self.num_frame_per_block
            range_tensor = torch.arange(block_num).view(-1, 1)
            range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
            # a token in frame i may attend to frame j only if j's block is not later than i's
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
            causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
            causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
            self.block_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        # time embeddings
        with amp.autocast("cuda", dtype=torch.float32):
            if t.dim() == 2:
                b, f = t.shape
                _flag_df = True
            else:
                _flag_df = False

            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
            )  # b, dim
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))  # b, 6, dim

            if self.inject_sample_info:
                fps = torch.tensor(fps, dtype=torch.long, device=device)

                fps_emb = self.fps_embedding(fps).float()
                if _flag_df:
                    # e0 is ordered (batch, frame) after t.flatten(), so each sample's fps term
                    # must be repeated once per frame of that same sample
                    e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat_interleave(f, dim=0)
                else:
                    e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))

            if _flag_df:
                # broadcast each frame's embedding to all height * width tokens of that frame
                e = e.view(b, f, 1, 1, self.dim)
                e0 = e0.view(b, f, 1, 1, 6, self.dim)
                e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
                e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
                e0 = e0.transpose(1, 2).contiguous()

            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context = self.text_embedding(context)

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
        if self.enable_teacache:
            modulated_inp = e0 if self.use_ref_steps else e
            # teacache
            if self.cnt % 2 == 0:  # even -> conditional branch
                self.is_even = True
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_even = True
                    self.accumulated_rel_l1_distance_even = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    rel_l1 = (
                        (modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()
                    ).cpu().item()
                    self.accumulated_rel_l1_distance_even += rescale_func(rel_l1)
                    if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
                        should_calc_even = False
                    else:
                        should_calc_even = True
                        self.accumulated_rel_l1_distance_even = 0
                self.previous_e0_even = modulated_inp.clone()
            else:  # odd -> unconditional branch
                self.is_even = False
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_odd = True
                    self.accumulated_rel_l1_distance_odd = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    rel_l1 = (
                        (modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()
                    ).cpu().item()
                    self.accumulated_rel_l1_distance_odd += rescale_func(rel_l1)
                    if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
                        should_calc_odd = False
                    else:
                        should_calc_odd = True
                        self.accumulated_rel_l1_distance_odd = 0
                self.previous_e0_odd = modulated_inp.clone()

        if self.enable_teacache:
            if self.is_even:
                if not should_calc_even:
                    # skip the blocks and reuse the residual from the last computed step
                    x += self.previous_residual_even
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_even = x - ori_x
            else:
                if not should_calc_odd:
                    x += self.previous_residual_odd
                else:
                    ori_x = x.clone()
                    for block in self.blocks:
                        x = block(x, **kwargs)
                    self.previous_residual_odd = x - ori_x

            self.cnt += 1
            if self.cnt >= self.num_steps:
                self.cnt = 0
        else:
            for block in self.blocks:
                x = block(x, **kwargs)

        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)

        return x.float()

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (Tensor):
                Patchified features of shape [B, L, C_out * prod(patch_size)],
                with L = F_patches * H_patches * W_patches
            grid_sizes (Tensor):
                Spatial-temporal grid dimensions before patching, shape [3]
                (F_patches, H_patches, W_patches)

        Returns:
            Tensor:
                Reconstructed video tensor of shape
                [B, C_out, F_patches * p_t, H_patches * p_h, W_patches * p_w]
        """

        c = self.out_dim
        bs = x.shape[0]
        x = x.view(bs, *grid_sizes, *self.patch_size, c)
        x = torch.einsum("bfhwpqrc->bcfphqwr", x)
        x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
        return x

    def set_ar_attention(self, causal_block_size):
        self.num_frame_per_block = causal_block_size
        self.flag_causal_attention = True
        for block in self.blocks:
            block.set_ar_attention()

    def init_weights(self):
        r"""
        Initialize model parameters: Xavier-uniform for linear layers and the patch embedding,
        N(0, 0.02) for the text, time, and fps embeddings, and zeros for the output head.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)

        if self.inject_sample_info:
            nn.init.normal_(self.fps_embedding.weight, std=0.02)

            for m in self.fps_projection.modules():
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, std=0.02)

            nn.init.zeros_(self.fps_projection[-1].weight)
            nn.init.zeros_(self.fps_projection[-1].bias)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
================================================
FILE: skyreels_v2_infer/modules/vae.py
================================================
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
__all__ = [
    "WanVAE",
]
CACHE_T = 2
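

class CausalConv3d(nn.Conv3d):
    """
    Causal 3D convolution.

    The temporal padding is applied only before the clip (2 * padding[0] frames, none after),
    so an output frame depends only on current and past input frames. When cache_x (frames
    from the previous chunk) is given, it replaces part of that front zero padding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad order: (W_left, W_right, H_top, H_bottom, T_front, T_back)
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)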


class RMS_norm(nn.Module):
    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):
    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):
    def __init__(self, dim, mode):
        assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
            )
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # chunk has fewer than CACHE_T frames: prepend the last cached frame of the previous chunk
                        cache_x = torch.cat(
                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                        )
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
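
Illustrative sketch, not part of the SkyReels-V2 repository: the TeaCache gate in WanModel.forward keeps one accumulator per branch (even calls are conditional, odd calls unconditional). Outside the warm-up (cnt < ret_steps) and cutoff (cnt >= cutoff_steps) windows, it adds the polynomial-rescaled relative L1 change of the modulated input to that accumulator and skips the transformer blocks while the sum stays below teacache_thresh. The class name TeaCacheGate is hypothetical, and NumPy arrays stand in for tensors.

```python
import numpy as np


class TeaCacheGate:
    """Hypothetical per-branch gate mirroring the TeaCache bookkeeping in WanModel.forward."""

    def __init__(self, coefficients, thresh, ret_steps, cutoff_steps):
        self.rescale = np.poly1d(coefficients)  # fitted polynomial, highest power first
        self.thresh = thresh
        self.ret_steps = ret_steps
        self.cutoff_steps = cutoff_steps
        self.accumulated = 0.0
        self.previous = None

    def should_calc(self, cnt, modulated_inp):
        """Return True if the transformer blocks must run at call index cnt."""
        if cnt < self.ret_steps or cnt >= self.cutoff_steps:
            # warm-up and final steps are always computed
            calc = True
            self.accumulated = 0.0
        else:
            rel_l1 = np.abs(modulated_inp - self.previous).mean() / np.abs(self.previous).mean()
            self.accumulated += self.rescale(rel_l1)
            calc = bool(self.accumulated >= self.thresh)
            if calc:
                self.accumulated = 0.0
        # the reference input is refreshed on every call, skipped or not
        self.previous = modulated_inp.copy()
        return calc
```

When the gate says no, the model adds the residual saved at the last computed step (previous_residual_even or previous_residual_odd) instead of running self.blocks, which is why cnt counts both branches and ret_steps is doubled.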
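
A minimal NumPy sketch (not repository code) of the block-causal attention mask that WanModel.forward builds once set_ar_attention(causal_block_size) has enabled flag_causal_attention: latent frames are grouped into blocks of num_frame_per_block, and a token may attend to every token whose frame lies in the same or an earlier block. The function name is hypothetical, and it assumes the frame count divides evenly into blocks, which the original code also relies on.

```python
import numpy as np


def block_causal_mask(frame_num, height, width, num_frame_per_block):
    """Hypothetical helper: boolean [N, N] mask, N = frame_num*height*width; True means 'may attend'."""
    assert frame_num % num_frame_per_block == 0, "frames must split evenly into blocks"
    block_num = frame_num // num_frame_per_block
    # block index of each latent frame, e.g. [0, 0, 1, 1] for 4 frames in blocks of 2
    frame_block = np.repeat(np.arange(block_num), num_frame_per_block)
    # query frame i may attend to key frame j iff block(j) <= block(i)
    frame_mask = frame_block[None, :] <= frame_block[:, None]
    # expand every frame-level entry to all height*width tokens of that frame
    tokens = height * width
    expanded = np.kron(frame_mask.astype(np.int8), np.ones((tokens, tokens), dtype=np.int8))
    return expanded.astype(bool)
```

Tokens are ordered frame-major, matching x.flatten(2).transpose(1, 2) after the patch embedding, so each frame's height * width tokens form one contiguous tile of the mask.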
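
Sketch under stated assumptions (not repository code): in diffusion-forcing mode t has shape [B, F], so WanModel.forward computes one time embedding per latent frame and then copies it to every spatial token of that frame. The helper name is hypothetical; np.tile plays the role of torch.Tensor.repeat, and a final reshape plays the role of flatten(1, 3).

```python
import numpy as np


def per_token_time_embedding(e, b, f, h, w):
    """Hypothetical helper: [b*f, dim] per-frame embeddings -> [b, f*h*w, dim] per-token embeddings."""
    dim = e.shape[-1]
    e = e.reshape(b, f, 1, 1, dim)
    e = np.tile(e, (1, 1, h, w, 1))  # same effect as torch's e.repeat(1, 1, h, w, 1)
    return e.reshape(b, f * h * w, dim)  # same effect as .flatten(1, 3)
```

For example, with two frames of 1 x 2 tokens each, the first two tokens receive frame 0's embedding and the last two receive frame 1's.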
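
To make the shape bookkeeping of WanModel.unpatchify concrete, here is a NumPy round trip (not repository code). Each token is viewed as a p_t x p_h x p_w patch with C_out channels, and the einsum bfhwpqrc->bcfphqwr moves the patch axes next to their grid axes before merging them. The patchify function is a hypothetical inverse written only for this check, and the (1, 2, 2) patch size is just an example value.

```python
import numpy as np


def patchify(video, patch_size):
    """Hypothetical inverse: [B, C, F, H, W] -> [B, F_p*H_p*W_p, p_t*p_h*p_w*C]."""
    b, c, f, h, w = video.shape
    pt, ph, pw = patch_size
    x = video.reshape(b, c, f // pt, pt, h // ph, ph, w // pw, pw)
    x = x.transpose(0, 2, 4, 6, 3, 5, 7, 1)  # -> b f h w p q r c
    return x.reshape(b, (f // pt) * (h // ph) * (w // pw), pt * ph * pw * c)


def unpatchify(x, grid_sizes, patch_size, out_dim):
    """Same steps as WanModel.unpatchify, written for NumPy arrays."""
    bs = x.shape[0]
    x = x.reshape(bs, *grid_sizes, *patch_size, out_dim)
    x = np.einsum("bfhwpqrc->bcfphqwr", x)
    return x.reshape(bs, out_dim, *[g * p for g, p in zip(grid_sizes, patch_size)])
```

A 2 x 4 x 6 latent with 3 channels becomes 12 tokens of width 12 and comes back unchanged.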
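
CausalConv3d moves the conv's own padding into an explicit F.pad call, whose tuple lists dimensions from last to first: (W_left, W_right, H_top, H_bottom, T_front, T_back). The NumPy sketch below (hypothetical name causal_pad, not repository code) reproduces that padding. np.pad instead takes (before, after) pairs from the first axis to the last. The assertion on front reflects an assumption: the cache holds at most 2 * p_t frames, as with CACHE_T = 2 and temporal padding 1 in vae.py.

```python
import numpy as np


def causal_pad(x, padding, cache_x=None):
    """Hypothetical NumPy analogue of CausalConv3d's padding. x: [B, C, T, H, W]; padding: (p_t, p_h, p_w)."""
    p_t, p_h, p_w = padding
    front = 2 * p_t  # all temporal padding goes before the clip, none after
    if cache_x is not None and front > 0:
        # real frames from the previous chunk replace part of the zero padding
        x = np.concatenate([cache_x, x], axis=2)
        front -= cache_x.shape[2]
    assert front >= 0, "cache longer than the temporal padding is not handled in this sketch"
    pad_width = [(0, 0), (0, 0), (front, 0), (p_h, p_h), (p_w, p_w)]
    return np.pad(x, pad_width)  # constant mode, zeros
```

With a 3 x 3 x 3 kernel, output frame t then covers input frames t - 2 to t, so no future frame leaks in. That property is what lets the VAE process a long video chunk by chunk.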
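
RMS_norm multiplies F.normalize (division by the L2 norm over the channel axis) by sqrt(dim). Because ||x||_2 / sqrt(dim) is the root mean square of the channels, the result is plain RMS normalization followed by gamma and the optional bias, except for the small eps floor F.normalize applies to the norm (1e-12 by default). A NumPy check of that identity follows; the function name is hypothetical.

```python
import numpy as np


def rms_norm_channel_first(x, gamma, bias=0.0, eps=1e-12):
    """Hypothetical NumPy analogue of RMS_norm(dim, channel_first=True): normalize over axis 1."""
    dim = x.shape[1]
    l2 = np.sqrt((x ** 2).sum(axis=1, keepdims=True))
    # F.normalize divides by max(||x||_2, eps); the sqrt(dim) factor turns the
    # unit-L2 vector into one whose channels have unit root mean square
    return x / np.maximum(l2, eps) * np.sqrt(dim) * gamma + bias
```

With images=True the learned gamma has shape (dim, 1, 1), which broadcasts over the spatial axes of a [B, C, H, W] input.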
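
In Resample with mode="upsample3d", time_conv doubles the channel count. The reshape/stack/reshape that follows splits those channels into two halves and interleaves them along time, turning T frames into 2 * T. Frame 2k comes from the first half at time k, and frame 2k + 1 from the second half. On the first chunk, where the cache entry is still empty and gets marked "Rep", this step is skipped, so that chunk is not upsampled in time. The NumPy helper below is a hypothetical illustration of the interleave only, not repository code.

```python
import numpy as np


def interleave_time(x):
    """Hypothetical helper: time_conv output [B, 2*C, T, H, W] -> [B, C, 2*T, H, W]."""
    b, c2, t, h, w = x.shape
    c = c2 // 2
    x = x.reshape(b, 2, c, t, h, w)
    # stacking on axis 3 gives [B, C, T, 2, H, W]; merging (T, 2) alternates the two halves
    x = np.stack((x[:, 0], x[:, 1]), axis=3)
    return x.reshape(b, c, t * 2, h, w)
```

For one channel pair holding times [0, 1, 2] and [10, 11, 12], the output sequence is 0, 10, 1, 11, 2, 12.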
SYMBOL INDEX (242 symbols across 20 files)
FILE: generate_video_df.py
function get_video_num_frames_moviepy (line 18) | def get_video_num_frames_moviepy(video_path):
FILE: skycaptioner_v1/scripts/gradio_fusion_caption.py
class FusionCaptioner (line 50) | class FusionCaptioner:
method __init__ (line 51) | def __init__(self, model_path, tensor_parallel_size):
method __call__ (line 63) | def __call__(self, structural_caption, task='t2v'):
function main (line 86) | def main():
FILE: skycaptioner_v1/scripts/gradio_struct_caption.py
class StructCaptioner (line 9) | class StructCaptioner:
method __init__ (line 10) | def __init__(self, model_path, tensor_parallel_size):
method __call__ (line 19) | def __call__(self, video_path):
function main (line 34) | def main():
FILE: skycaptioner_v1/scripts/utils.py
function result_writer (line 4) | def result_writer(indices_list: list, result_list: list, meta: pd.DataFr...
FILE: skycaptioner_v1/scripts/vllm_fusion_caption.py
class StructuralCaptionDataset (line 66) | class StructuralCaptionDataset(torch.utils.data.Dataset):
method __init__ (line 67) | def __init__(self, input_csv, model_path, task=None):
method __len__ (line 80) | def __len__(self):
method __getitem__ (line 83) | def __getitem__(self, index):
method clean_struct_caption (line 114) | def clean_struct_caption(self, struct_caption, task):
function custom_collate_fn (line 177) | def custom_collate_fn(batch):
FILE: skycaptioner_v1/scripts/vllm_struct_caption.py
class VideoTextDataset (line 18) | class VideoTextDataset(torch.utils.data.Dataset):
method __init__ (line 19) | def __init__(self, csv_path, model_path):
method __getitem__ (line 28) | def __getitem__(self, index):
method __len__ (line 70) | def __len__(self):
method get_index (line 73) | def get_index(self, video_size, num_frames, st=0):
function result_writer (line 85) | def result_writer(indices_list: list, result_list: list, meta: pd.DataFr...
function worker_init_fn (line 103) | def worker_init_fn(worker_id):
function main (line 110) | def main():
FILE: skyreels_v2_infer/distributed/xdit_context_parallel.py
function pad_freqs (line 13) | def pad_freqs(original_tensor, target_len):
function rope_apply (line 22) | def rope_apply(x, grid_sizes, freqs):
function broadcast_should_calc (line 63) | def broadcast_should_calc(should_calc: bool) -> bool:
function usp_dit_forward (line 74) | def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
function usp_attn_forward (line 233) | def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
FILE: skyreels_v2_infer/modules/__init__.py
function download_model (line 13) | def download_model(model_id):
function get_vae (line 21) | def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> Wa...
function get_transformer (line 30) | def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat...
function get_text_encoder (line 50) | def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloa...
function get_image_encoder (line 61) | def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bflo...
FILE: skyreels_v2_infer/modules/attention.py
function flash_attention (line 26) | def flash_attention(
function attention (line 132) | def attention(
FILE: skyreels_v2_infer/modules/clip.py
function pos_interpolate (line 23) | def pos_interpolate(pos, seq_len):
class QuickGELU (line 46) | class QuickGELU(nn.Module):
method forward (line 47) | def forward(self, x):
class LayerNorm (line 51) | class LayerNorm(nn.LayerNorm):
method forward (line 52) | def forward(self, x):
class SelfAttention (line 56) | class SelfAttention(nn.Module):
method __init__ (line 57) | def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, pro...
method forward (line 71) | def forward(self, x):
class SwiGLU (line 91) | class SwiGLU(nn.Module):
method __init__ (line 92) | def __init__(self, dim, mid_dim):
method forward (line 102) | def forward(self, x):
class AttentionBlock (line 108) | class AttentionBlock(nn.Module):
method __init__ (line 109) | def __init__(
method forward (line 144) | def forward(self, x):
class AttentionPool (line 154) | class AttentionPool(nn.Module):
method __init__ (line 155) | def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_...
method forward (line 179) | def forward(self, x):
class VisionTransformer (line 202) | class VisionTransformer(nn.Module):
method __init__ (line 203) | def __init__(
method forward (line 268) | def forward(self, x, interpolation=False, use_31_block=False):
class XLMRobertaWithHead (line 292) | class XLMRobertaWithHead(XLMRoberta):
method __init__ (line 293) | def __init__(self, **kwargs):
method forward (line 303) | def forward(self, ids):
class XLMRobertaCLIP (line 316) | class XLMRobertaCLIP(nn.Module):
method __init__ (line 317) | def __init__(
method forward (line 397) | def forward(self, imgs, txt_ids):
method param_groups (line 409) | def param_groups(self):
function _clip (line 420) | def _clip(
function clip_xlm_roberta_vit_h_14 (line 460) | def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-cl...
class CLIPModel (line 488) | class CLIPModel(ModelMixin):
method __init__ (line 489) | def __init__(self, checkpoint_path, tokenizer_path):
method encode_video (line 507) | def encode_video(self, video):
FILE: skyreels_v2_infer/modules/t5.py
function fp16_clamp (line 21) | def fp16_clamp(x):
function init_weights (line 28) | def init_weights(m):
class GELU (line 46) | class GELU(nn.Module):
method forward (line 47) | def forward(self, x):
class T5LayerNorm (line 51) | class T5LayerNorm(nn.Module):
method __init__ (line 52) | def __init__(self, dim, eps=1e-6):
method forward (line 58) | def forward(self, x):
class T5Attention (line 65) | class T5Attention(nn.Module):
method __init__ (line 66) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
method forward (line 81) | def forward(self, x, context=None, mask=None, pos_bias=None):
class T5FeedForward (line 117) | class T5FeedForward(nn.Module):
method __init__ (line 118) | def __init__(self, dim, dim_ffn, dropout=0.1):
method forward (line 129) | def forward(self, x):
class T5SelfAttention (line 137) | class T5SelfAttention(nn.Module):
method __init__ (line 138) | def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, sha...
method forward (line 154) | def forward(self, x, mask=None, pos_bias=None):
class T5CrossAttention (line 161) | class T5CrossAttention(nn.Module):
method __init__ (line 162) | def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, sha...
method forward (line 180) | def forward(self, x, mask=None, encoder_states=None, encoder_mask=None...
class T5RelativeEmbedding (line 188) | class T5RelativeEmbedding(nn.Module):
method __init__ (line 189) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
method forward (line 199) | def forward(self, lq, lk):
method _relative_position_bucket (line 209) | def _relative_position_bucket(self, rel_pos):
class T5Encoder (line 233) | class T5Encoder(nn.Module):
method __init__ (line 234) | def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layer...
method forward (line 259) | def forward(self, ids, mask=None):
class T5Decoder (line 270) | class T5Decoder(nn.Module):
method __init__ (line 271) | def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layer...
method forward (line 296) | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=No...
class T5Model (line 316) | class T5Model(nn.Module):
method __init__ (line 317) | def __init__(
method forward (line 353) | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
function _t5 (line 360) | def _t5(
function umt5_xxl (line 404) | def umt5_xxl(**kwargs):
class T5EncoderModel (line 421) | class T5EncoderModel(ModelMixin):
method __init__ (line 422) | def __init__(
method encode (line 446) | def encode(self, texts):
FILE: skyreels_v2_infer/modules/tokenizers.py
function basic_clean (line 12) | def basic_clean(text):
function whitespace_clean (line 18) | def whitespace_clean(text):
function canonicalize (line 24) | def canonicalize(text, keep_punctuation_exact_string=None):
class HuggingfaceTokenizer (line 38) | class HuggingfaceTokenizer:
method __init__ (line 39) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
method __call__ (line 49) | def __call__(self, sequence, **kwargs):
method _clean (line 71) | def _clean(self, text):
FILE: skyreels_v2_infer/modules/transformer.py
function sinusoidal_embedding_1d (line 26) | def sinusoidal_embedding_1d(dim, position):
function rope_params (line 39) | def rope_params(max_seq_len, dim, theta=10000):
function rope_apply (line 49) | def rope_apply(x, grid_sizes, freqs):
function fast_rms_norm (line 79) | def fast_rms_norm(x, weight, eps):
class WanRMSNorm (line 86) | class WanRMSNorm(nn.Module):
method __init__ (line 87) | def __init__(self, dim, eps=1e-5):
method forward (line 93) | def forward(self, x):
method _norm (line 100) | def _norm(self, x):
class WanLayerNorm (line 104) | class WanLayerNorm(nn.LayerNorm):
method __init__ (line 105) | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
method forward (line 108) | def forward(self, x):
class WanSelfAttention (line 116) | class WanSelfAttention(nn.Module):
method __init__ (line 117) | def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,...
method set_ar_attention (line 137) | def set_ar_attention(self):
method forward (line 140) | def forward(self, x, grid_sizes, freqs, block_mask):
class WanT2VCrossAttention (line 186) | class WanT2VCrossAttention(WanSelfAttention):
method forward (line 187) | def forward(self, x, context):
class WanI2VCrossAttention (line 210) | class WanI2VCrossAttention(WanSelfAttention):
method __init__ (line 211) | def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,...
method forward (line 219) | def forward(self, x, context):
function mul_add (line 254) | def mul_add(x, y, z):
function mul_add_add (line 258) | def mul_add_add(x, y, z):
class WanAttentionBlock (line 266) | class WanAttentionBlock(nn.Module):
method __init__ (line 267) | def __init__(
method set_ar_attention (line 298) | def set_ar_attention(self):
method forward (line 301) | def forward(
class Head (line 347) | class Head(nn.Module):
method __init__ (line 348) | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
method forward (line 363) | def forward(self, x, e):
class MLPProj (line 382) | class MLPProj(torch.nn.Module):
method __init__ (line 383) | def __init__(self, in_dim, out_dim):
method forward (line 394) | def forward(self, image_embeds):
class WanModel (line 399) | class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
method __init__ (line 410) | def __init__(
method _set_gradient_checkpointing (line 531) | def _set_gradient_checkpointing(self, module, value=False):
method zero_init_i2v_cross_attn (line 534) | def zero_init_i2v_cross_attn(self):
method _prepare_blockwise_causal_attn_mask (line 541) | def _prepare_blockwise_causal_attn_mask(
method initialize_teacache (line 578) | def initialize_teacache(self, enable_teacache=True, num_steps=25, teac...
method forward (line 622) | def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
method unpatchify (line 777) | def unpatchify(self, x, grid_sizes):
method set_ar_attention (line 801) | def set_ar_attention(self, causal_block_size):
method init_weights (line 807) | def init_weights(self):
FILE: skyreels_v2_infer/modules/vae.py
class CausalConv3d (line 17) | class CausalConv3d(nn.Conv3d):
method __init__ (line 22) | def __init__(self, *args, **kwargs):
method forward (line 27) | def forward(self, x, cache_x=None):
class RMS_norm (line 38) | class RMS_norm(nn.Module):
method __init__ (line 39) | def __init__(self, dim, channel_first=True, images=True, bias=False):
method forward (line 49) | def forward(self, x):
class Upsample (line 53) | class Upsample(nn.Upsample):
method forward (line 54) | def forward(self, x):
class Resample (line 61) | class Resample(nn.Module):
method __init__ (line 62) | def __init__(self, dim, mode):
method forward (line 88) | def forward(self, x, feat_cache=None, feat_idx=[0]):
method init_weight (line 139) | def init_weight(self, conv):
method init_weight2 (line 151) | def init_weight2(self, conv):
class ResidualBlock (line 163) | class ResidualBlock(nn.Module):
method __init__ (line 164) | def __init__(self, in_dim, out_dim, dropout=0.0):
method forward (line 181) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class AttentionBlock (line 200) | class AttentionBlock(nn.Module):
method __init__ (line 205) | def __init__(self, dim):
method forward (line 217) | def forward(self, x):
class Encoder3d (line 239) | class Encoder3d(nn.Module):
method __init__ (line 240) | def __init__(
method forward (line 292) | def forward(self, x, feat_cache=None, feat_idx=[0]):
class Decoder3d (line 337) | class Decoder3d(nn.Module):
method __init__ (line 338) | def __init__(
method forward (line 390) | def forward(self, x, feat_cache=None, feat_idx=[0]):
function count_conv3d (line 436) | def count_conv3d(model):
class WanVAE_ (line 444) | class WanVAE_(nn.Module):
method __init__ (line 445) | def __init__(
method forward (line 472) | def forward(self, x):
method encode (line 478) | def encode(self, x, scale):
method decode (line 503) | def decode(self, z, scale):
method reparameterize (line 522) | def reparameterize(self, mu, log_var):
method sample (line 527) | def sample(self, imgs, deterministic=False):
method clear_cache (line 534) | def clear_cache(self):
function _video_vae (line 544) | def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
class WanVAE (line 571) | class WanVAE:
method __init__ (line 572) | def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16):
method encode (line 625) | def encode(self, video):
method to (line 631) | def to(self, *args, **kwargs):
method decode (line 638) | def decode(self, z):
FILE: skyreels_v2_infer/modules/xlm_roberta.py
class SelfAttention (line 10) | class SelfAttention(nn.Module):
method __init__ (line 11) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
method forward (line 26) | def forward(self, x, mask):
class AttentionBlock (line 48) | class AttentionBlock(nn.Module):
method __init__ (line 49) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
method forward (line 62) | def forward(self, x, mask):
class XLMRoberta (line 72) | class XLMRoberta(nn.Module):
method __init__ (line 77) | def __init__(
method forward (line 115) | def forward(self, ids):
function xlm_roberta_large (line 143) | def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="...
FILE: skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
class DiffusionForcingPipeline (line 25) | class DiffusionForcingPipeline:
method __init__ (line 40) | def __init__(
method do_classifier_free_guidance (line 80) | def do_classifier_free_guidance(self) -> bool:
method encode_image (line 83) | def encode_image(
method prepare_latents (line 102) | def prepare_latents(
method generate_timestep_matrix (line 111) | def generate_timestep_matrix(
method get_video_as_tensor (line 187) | def get_video_as_tensor(self, video_path, width, height):
method extend_video (line 212) | def extend_video(
method __call__ (line 380) | def __call__(
FILE: skyreels_v2_infer/pipelines/image2video_pipeline.py
function resizecrop (line 20) | def resizecrop(image: Image.Image, th, tw):
class Image2VideoPipeline (line 38) | class Image2VideoPipeline:
method __init__ (line 39) | def __init__(
method __call__ (line 67) | def __call__(
FILE: skyreels_v2_infer/pipelines/prompt_enhancer.py
class PromptEnhancer (line 25) | class PromptEnhancer:
method __init__ (line 26) | def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
method __call__ (line 34) | def __call__(self, prompt):
FILE: skyreels_v2_infer/pipelines/text2video_pipeline.py
class Text2VideoPipeline (line 17) | class Text2VideoPipeline:
method __init__ (line 18) | def __init__(
method __call__ (line 45) | def __call__(
FILE: skyreels_v2_infer/scheduler/fm_solvers_unipc.py
class FlowUniPCMultistepScheduler (line 20) | class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 77) | def __init__(
method step_index (line 131) | def step_index(self):
method begin_index (line 138) | def begin_index(self):
method set_begin_index (line 145) | def set_begin_index(self, begin_index: int = 0):
method set_timesteps (line 156) | def set_timesteps(
method _threshold_sample (line 217) | def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
method _sigma_to_t (line 251) | def _sigma_to_t(self, sigma):
method _sigma_to_alpha_sigma_t (line 254) | def _sigma_to_alpha_sigma_t(self, sigma):
method time_shift (line 258) | def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
method convert_model_output (line 261) | def convert_model_output(
method multistep_uni_p_bh_update (line 331) | def multistep_uni_p_bh_update(
method multistep_uni_c_bh_update (line 460) | def multistep_uni_c_bh_update(
method index_for_timestep (line 597) | def index_for_timestep(self, timestep, schedule_timesteps=None):
method _init_step_index (line 612) | def _init_step_index(self, timestep):
method step (line 624) | def step(
method scale_model_input (line 708) | def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> ...
method add_noise (line 724) | def add_noise(
method __len__ (line 758) | def __len__(self):
Condensed preview: 36 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 194,
"preview": "__pycache__/\ncheckpoint/*\ncheckpoint\nresults/*\n.DS_Store\nresults/*\n*.png\n*.jpg\n*.mp4\n*.log*\n*.json\nscripts/transformer/*"
},
{
"path": ".pre-commit-config.yaml",
"chars": 684,
"preview": "repos:\n - repo: https://github.com/asottile/reorder-python-imports.git\n rev: v3.8.3\n hooks:\n - id: reorder-p"
},
{
"path": "LICENSE.txt",
"chars": 2389,
"preview": "---\nlanguage:\n - en\n - zh\nlicense: other\ntasks:\n - text-generation\n\n---\n\n<!-- markdownlint-disable first-line-h1 -->\n"
},
{
"path": "README.md",
"chars": 44765,
"preview": "<p align=\"center\">\n <img src=\"assets/logo2.png\" alt=\"SkyReels Logo\" width=\"50%\">\n</p>\n\n<h1 align=\"center\">SkyReels V2: "
},
{
"path": "generate_video.py",
"chars": 6784,
"preview": "import argparse\nimport gc\nimport os\nimport random\nimport time\n\nimport imageio\nimport torch\nfrom diffusers.utils import l"
},
{
"path": "generate_video_df.py",
"chars": 9378,
"preview": "import argparse\nimport gc\nimport os\nimport random\nimport time\n\nimport imageio\nimport torch\nfrom diffusers.utils import l"
},
{
"path": "requirements.txt",
"chars": 221,
"preview": "torch==2.5.1\ntorchvision==0.20.1\nopencv-python==4.10.0.84\ndiffusers>=0.31.0\ntransformers==4.49.0\ntokenizers==0.21.1\nacce"
},
{
"path": "skycaptioner_v1/README.md",
"chars": 9983,
"preview": "# SkyCaptioner-V1: A Structural Video Captioning Model\n\n<p align=\"center\">\n📑 <a href=\"https://arxiv.org/pdf/2504.13074\">"
},
{
"path": "skycaptioner_v1/examples/test.csv",
"chars": 92,
"preview": "path\n./examples/data/1.mp4\n./examples/data/2.mp4\n./examples/data/3.mp4\n./examples/data/4.mp4"
},
{
"path": "skycaptioner_v1/examples/test_result.csv",
"chars": 6036,
"preview": "path,structural_caption\n./examples/data/1.mp4,\"{\"\"subjects\"\": [{\"\"TYPES\"\": {\"\"type\"\": \"\"Sport\"\", \"\"sub_type\"\": \"\"Other\"\""
},
{
"path": "skycaptioner_v1/infer_fusion_caption.sh",
"chars": 274,
"preview": "expor LLM_MODEL_PATH=\"/path/to/your_local_model_path2\"\n\npython scripts/vllm_fusion_caption.py \\\n --model_path ${LLM_M"
},
{
"path": "skycaptioner_v1/infer_struct_caption.sh",
"chars": 265,
"preview": "expor SkyCaptioner_V1_Model_PATH=\"/path/to/your_local_model_path\"\n\npython scripts/vllm_struct_caption.py \\\n --model_p"
},
{
"path": "skycaptioner_v1/requirements.txt",
"chars": 46,
"preview": "decord==0.6.0\ntransformers>=4.49.0\nvllm==0.8.4"
},
{
"path": "skycaptioner_v1/scripts/gradio_fusion_caption.py",
"chars": 5376,
"preview": "import json\nimport argparse\nimport pandas as pd\nimport gradio as gr\n\nfrom vllm import LLM, SamplingParams\n\nfrom vllm_fus"
},
{
"path": "skycaptioner_v1/scripts/gradio_struct_caption.py",
"chars": 2918,
"preview": "import json\nimport argparse\nimport pandas as pd\nimport gradio as gr\nfrom vllm import LLM, SamplingParams\nfrom vllm_struc"
},
{
"path": "skycaptioner_v1/scripts/utils.py",
"chars": 596,
"preview": "import numpy as np\nimport pandas as pd\n\ndef result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, col"
},
{
"path": "skycaptioner_v1/scripts/vllm_fusion_caption.py",
"chars": 10411,
"preview": "import os\nfrom pathlib import Path\nimport argparse\nimport glob\nimport time\nimport gc\nfrom tqdm import tqdm\nimport torch\n"
},
{
"path": "skycaptioner_v1/scripts/vllm_struct_caption.py",
"chars": 6828,
"preview": "\nimport torch\nimport decord\nimport argparse\n\nimport pandas as pd\nimport numpy as np\n\nfrom tqdm import tqdm\nfrom vllm imp"
},
{
"path": "skyreels_v2_infer/__init__.py",
"chars": 48,
"preview": "from .pipelines import DiffusionForcingPipeline\n"
},
{
"path": "skyreels_v2_infer/distributed/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "skyreels_v2_infer/distributed/xdit_context_parallel.py",
"chars": 10869,
"preview": "import numpy as np\nimport torch\nimport torch.amp as amp\nfrom torch.backends.cuda import sdp_kernel\nfrom xfuser.core.dist"
},
{
"path": "skyreels_v2_infer/modules/__init__.py",
"chars": 2312,
"preview": "import gc\nimport os\n\nimport torch\nfrom safetensors.torch import load_file\n\nfrom .clip import CLIPModel\nfrom .t5 import T"
},
{
"path": "skyreels_v2_infer/modules/attention.py",
"chars": 5275,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport torch\n\ntry:\n import flash_attn_interf"
},
{
"path": "skyreels_v2_infer/modules/clip.py",
"chars": 16183,
"preview": "# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''\n# Copyright 2024-2"
},
{
"path": "skyreels_v2_infer/modules/t5.py",
"chars": 15683,
"preview": "# Modified from transformers.models.t5.modeling_t5\n# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserv"
},
{
"path": "skyreels_v2_infer/modules/tokenizers.py",
"chars": 2377,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport html\nimport string\n\nimport ftfy\nimport r"
},
{
"path": "skyreels_v2_infer/modules/transformer.py",
"chars": 30924,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport math\nimport numpy as np\nimport torch\nimp"
},
{
"path": "skyreels_v2_infer/modules/vae.py",
"chars": 21792,
"preview": "# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.\nimport logging\n\nimport torch\nimport torch.nn as"
},
{
"path": "skyreels_v2_infer/modules/xlm_roberta.py",
"chars": 4686,
"preview": "# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta\n# Copyright 2024-2025 The Alibaba Wan Team Authors."
},
{
"path": "skyreels_v2_infer/pipelines/__init__.py",
"chars": 260,
"preview": "from .diffusion_forcing_pipeline import DiffusionForcingPipeline\nfrom .image2video_pipeline import Image2VideoPipeline\nf"
},
{
"path": "skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py",
"chars": 33021,
"preview": "import math\nimport os\nfrom typing import List\nfrom typing import Optional\nfrom typing import Tuple\nfrom typing import Un"
},
{
"path": "skyreels_v2_infer/pipelines/image2video_pipeline.py",
"chars": 5918,
"preview": "import os\nfrom typing import List\nfrom typing import Optional\nfrom typing import Union\n\nimport numpy as np\nimport torch\n"
},
{
"path": "skyreels_v2_infer/pipelines/prompt_enhancer.py",
"chars": 3076,
"preview": "import argparse\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nsys_prompt = \"\"\"\nTransform the short promp"
},
{
"path": "skyreels_v2_infer/pipelines/text2video_pipeline.py",
"chars": 4386,
"preview": "import os\nfrom typing import List\nfrom typing import Optional\nfrom typing import Union\n\nimport numpy as np\nimport torch\n"
},
{
"path": "skyreels_v2_infer/scheduler/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "skyreels_v2_infer/scheduler/fm_solvers_unipc.py",
"chars": 31667,
"preview": "# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep."
}
]
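The `skycaptioner_v1/examples/test_result.csv` entry above shows the output layout of the structural captioner: a `path` column and a `structural_caption` column whose value is a JSON object stored in a quoted CSV field, with the inner quotes doubled (`""`). The sketch below shows one way to read such a file back into Python objects using only the standard library. The sample row is a simplified, hypothetical stand-in modeled on the truncated preview (the real captions contain more fields), and `load_structural_captions` is an illustrative helper, not part of the repository.

```python
import csv
import io
import json

# Hypothetical, simplified row modeled on the test_result.csv preview.
# Inner JSON quotes are escaped by doubling them, as in standard CSV quoting.
SAMPLE_CSV = (
    'path,structural_caption\n'
    './examples/data/1.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}}]}"\n'
)


def load_structural_captions(text):
    """Illustrative helper: return (path, parsed caption dict) pairs from CSV text."""
    reader = csv.DictReader(io.StringIO(text))
    # The default dialect (doublequote=True) turns "" back into " before json.loads.
    return [(row["path"], json.loads(row["structural_caption"])) for row in reader]


rows = load_structural_captions(SAMPLE_CSV)
path, caption = rows[0]
print(path, caption["subjects"][0]["TYPES"]["type"])
# ./examples/data/1.mp4 Sport
```

Because the quoting follows the usual CSV convention, `pandas.read_csv` with its default settings should decode the column the same way before `json.loads` is applied to each value.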
About this extraction
This page contains the full source code of the SkyworkAI/SkyReels-V2 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 36 files (288.8 KB), approximately 76.3k tokens, and a symbol index with 242 extracted functions, classes, methods, constants, and types.
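The file index above is a JSON array in which each entry records a `path`, a `chars` count, and a truncated `preview`. As a minimal sketch, assuming the index has been loaded as JSON text, the per-directory sizes can be tallied as shown here. The embedded excerpt copies the `path` and `chars` values of the `skyreels_v2_infer/pipelines/` and `skyreels_v2_infer/scheduler/` entries with the previews blanked out; `chars_by_directory` is a hypothetical helper name.

```python
import json
import posixpath
from collections import defaultdict

# Excerpt of the file index: "path" and "chars" are copied from the entries above,
# "preview" is blanked out to keep the example short.
INDEX_JSON = """
[
  {"path": "skyreels_v2_infer/pipelines/__init__.py", "chars": 260, "preview": ""},
  {"path": "skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py", "chars": 33021, "preview": ""},
  {"path": "skyreels_v2_infer/pipelines/image2video_pipeline.py", "chars": 5918, "preview": ""},
  {"path": "skyreels_v2_infer/pipelines/prompt_enhancer.py", "chars": 3076, "preview": ""},
  {"path": "skyreels_v2_infer/pipelines/text2video_pipeline.py", "chars": 4386, "preview": ""},
  {"path": "skyreels_v2_infer/scheduler/__init__.py", "chars": 0, "preview": ""},
  {"path": "skyreels_v2_infer/scheduler/fm_solvers_unipc.py", "chars": 31667, "preview": ""}
]
"""


def chars_by_directory(entries):
    """Hypothetical helper: sum each entry's "chars" under its parent directory."""
    totals = defaultdict(int)
    for entry in entries:
        # Index paths use forward slashes, so posixpath behaves the same on every OS.
        directory = posixpath.dirname(entry["path"]) or "."
        totals[directory] += entry["chars"]
    return dict(totals)


entries = json.loads(INDEX_JSON)
totals = chars_by_directory(entries)
for directory, count in sorted(totals.items()):
    print(f"{directory}: {count} chars")
# skyreels_v2_infer/pipelines: 46661 chars
# skyreels_v2_infer/scheduler: 31667 chars
```

Run over the complete index, the same function would give a directory-level breakdown of the repository's size; grouping by parent directory keeps files such as the two `__init__.py` modules from being conflated.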