Repository: Yuanshi9815/Video-Infinity
Branch: main
Commit: acfcd17c25cd
Files: 11
Total size: 41.8 KB
Directory structure:
gitextract_0rdk_80z/
├── README.md
├── examples/
│ ├── config.json
│ ├── multi_promts.json
│ └── single_gpu.json
├── inference.py
├── requirements.txt
└── src/
├── tools.py
├── video_crafter.py
└── video_infinity/
├── __init__.py
├── plugins.py
└── wrapper.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
<div align="center">
# Video-Infinity
<img src='./assets/VideoGen-Main.png' width='80%' />
<br>
<a href="https://arxiv.org/abs/2406.16260"><img src="https://img.shields.io/badge/ariXv-2406.16260-A42C25.svg" alt="arXiv"></a>
<a href="https://video-infinity.tanzhenxiong.com"><img src="https://img.shields.io/badge/ProjectPage-Video Infinity-376ED2#376ED2.svg" alt="arXiv"></a>
</div>
> **Video-Infinity: Distributed Long Video Generation**
> <br>
> Zhenxiong Tan,
> [Xingyi Yang](https://adamdad.github.io/),
> [Songhua Liu](http://121.37.94.87/),
> and
> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> <br>
> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
> <br>
## TL;DR (Too Long; Didn't Read)
Video-Infinity generates long videos quickly using multiple GPUs without extra training. Feel free to visit our
[project page](https://video-infinity.tanzhenxiong.com)
for more information and generated videos.
## Features
* **Distributed 🌐**: Utilizes multiple GPUs to generate long-form videos.
* **High-Speed 🚀**: Produces 2,300 frames in just 5 minutes.
* **Training-Free 🎓**: Generates long videos without requiring additional training for existing models.
## Setup
### Installation Environment
```bash
conda create -n video_infinity_vc2 python=3.10
conda activate video_infinity_vc2
pip install -r requirements.txt
```
<!-- ### Download Pretrained Models
We provide a diffusers pipeline for [VideoCrafter2](TODO) to generate long videos.
```bash
huggingface-cli download adamdad/videocrafterv2_diffusers
``` -->
## Usage
### Quick Start
- **Basic Usage**
```bash
python inference.py --config examples/config.json
```
- **Multi-Prompts**
```bash
python inference.py --config examples/multi_promts.json
```
- **Single GPU**
```bash
python inference.py --config examples/single_gpu.json
```
### Config
#### Basic Config
| Parameter | Description |
| ----------- | -------------------------------------- |
| `devices` | The list of GPU devices to use. |
| `base_path` | The path to save the generated videos. |
#### Pipeline Config
| Parameter | Description |
| ------------ | ---------------------------------------------------------------------------------------------------- |
| `prompts`   | The list of text prompts. **Note**: the number of prompts should not exceed the number of GPUs; each prompt is assigned to a contiguous block of ranks. |
| `file_name` | The name of the generated video. |
| `num_frames` | The number of frames to generate on **each GPU**. |
#### Video-Infinity Config
| Parameter | Description |
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `*.padding` | The number of local context frames. |
| `attn.top_k`           | The number of global context frames shared across GPUs for the temporal attention modules. |
| `attn.local_phase`     | When the denoising timestep is less than `t`, the attention is biased: `local_biase` is added to the local context frames and `global_biase` to the global context frames. |
| `attn.global_phase`    | Same idea as `local_phase`, but the bias is applied when the denoising timestep is greater than or equal to `t`. |
| `attn.token_num_scale` | If `True`, the attention scale factor is rescaled by the number of attended tokens. Default is `False`. See this [paper](https://arxiv.org/abs/2306.08645) for details. |
#### How to Set Config
- To avoid losing high-frequency information, we recommend keeping the sum of `padding` and `attn.top_k` below 24 (roughly the default number of frames used by the `VideoCrafter2` model).
- If you want a larger `padding` or `attn.top_k`, set `attn.token_num_scale` to `True`.
- Higher `local_phase.t` and `global_phase.t` values yield more stable videos but may reduce their diversity.
- A larger `padding` provides more local context.
- A higher `attn.top_k` improves the overall temporal stability of the videos; see the sketch below for how these values fit into a config file.
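For a concrete starting point, here is a minimal sketch (not part of the repository) that loads `examples/config.json`, adjusts the attention parameters following the guidelines above, and launches inference programmatically. It assumes you run it from the repository root and reuses `main()` from `inference.py`.

```python
# Hedged sketch: tune the attention config before launching (run from the repo root).
import json
import torch.multiprocessing as mp

from inference import main

if __name__ == "__main__":
    mp.set_start_method("spawn")  # same start method inference.py uses

    with open("examples/config.json") as f:
        config = json.load(f)

    attn = config["plugin_configs"]["attn"]
    attn["padding"] = 8               # local context frames shared with neighbours
    attn["top_k"] = 16                # global context frames; keep padding + top_k < 24
    attn["token_num_scale"] = False   # set True if you enlarge padding / top_k

    config["devices"] = [0, 1]        # GPUs to use; one process is spawned per device
    main(config)
```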
## Citation
```
@article{
tan2024videoinf,
title={Video-Infinity: Distributed Long Video Generation},
  author={Zhenxiong Tan and Xingyi Yang and Songhua Liu and Xinchao Wang},
journal={arXiv preprint arXiv:2406.16260},
year={2024}
}
```
## Acknowledgements
Our project is based on the [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2) model. We would like to thank the authors for their excellent work! ❤️
================================================
FILE: examples/config.json
================================================
{
"dtype": "torch.float16",
"devices": [0,1,2,3],
"seed": 123,
"master_port": 29516,
"base_path": "./exp",
"pipe_configs":{
"prompts": [
"A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background."
],
"steps": 30,
"guidance_scale": 12,
"fps": 24,
"num_frames": 24,
"height": 320,
"width": 512,
"export_fps": 8,
"file_name": null
},
"plugin_configs":{
"attn":{
"padding": 8,
"top_k": 16,
"top_k_chunk_size": 24,
"attn_scale": 1.0,
"token_num_scale": false,
"dynamic_scale": true,
"local_phase": {
"t": 800,
"local_biase": 10,
"global_biase": 0
},
"global_phase": {
"t": 800,
"local_biase": 0,
"global_biase": 10
}
},
"conv_3d": {
"padding": 1
},
"conv_layer": {}
}
}
================================================
FILE: examples/multi_promts.json
================================================
{
"dtype": "torch.float16",
"devices": [
0,
1,
2,
3,
4,
5,
6,
7
],
"seed": 123,
"master_port": 29516,
"base_path": "./exp",
"pipe_configs": {
"prompts": [
"[Ukiyo-e style] A black Akita puppy stands alone under the eaves of a traditional Japanese house, shivering in the rain and looking scared.",
"[Ukiyo-e style] A girl in a blue kimono comforts a shivering black Akita puppy during a rainy walk through a stone-paved village.",
"[Ukiyo-e style] A girl in a blue kimono brings an black Akita puppy into her warm home, a traditional wooden Japanese house with sliding doors.",
"[Ukiyo-e style] A girl in a blue kimono plays with her energetic black Akita puppy in the garden of a traditional Japanese house, throwing a woven ball.",
"[Ukiyo-e style] A teenager in a blue kimono jogs with her black Akita through a park filled with cherry blossom trees and ancient stone lanterns.",
"[Ukiyo-e style] A teenage girl in a blue kimono relaxes in a field of wildflowers, reading a scroll while her black Akita rests beside him under a cherry tree.",
"[Ukiyo-e style] A young girl in a blue kimono celebrates her coming-of-age ceremony with her black Akita, surrounded by festive lanterns and banners.",
"[Ukiyo-e style] A girl in a blue kimono and her loyal black Akita enjoy a serene sunset walk along the beach, with the silhouette of Mount Fuji in the distance."
],
"steps": 30,
"guidance_scale": 12,
"fps": 24,
"num_frames": 24,
"height": 320,
"width": 512,
"export_fps": 8,
"file_name": null
},
"plugin_configs": {
"attn": {
"padding": 8,
"top_k": 16,
"top_k_chunk_size": 24,
"attn_scale": 1.0,
"token_num_scale": false,
"dynamic_scale": true,
"local_phase": {
"t": 850,
"local_biase": 10,
"global_biase": 0
},
"global_phase": {
"t": 850,
"local_biase": 0,
"global_biase": 10
}
},
"conv_3d": {
"padding": 1
},
"conv_layer": {}
}
}
================================================
FILE: examples/single_gpu.json
================================================
{
"dtype": "torch.float16",
"devices": [0],
"seed": 123,
"master_port": 29516,
"base_path": "./exp",
"pipe_configs":{
"prompts": [
"A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background."
],
"steps": 30,
"guidance_scale": 12,
"fps": 24,
"num_frames": 24,
"height": 320,
"width": 512,
"export_fps": 8,
"file_name": null
},
"plugin_configs":{
"attn":{
"padding": 8,
"top_k": 16,
"top_k_chunk_size": 24,
"attn_scale": 1.0,
"token_num_scale": false,
"dynamic_scale": true,
"local_phase": {
"t": 800,
"local_biase": 10,
"global_biase": 0
},
"global_phase": {
"t": 800,
"local_biase": 0,
"global_biase": 10
}
},
"conv_3d": {
"padding": 1
},
"conv_layer": {}
}
}
================================================
FILE: inference.py
================================================
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import json
import os
from src.video_crafter import VideoCrafterPipeline, UNetVideoCrafter
from diffusers.schedulers import DPMSolverMultistepScheduler
from src.tools import DistController
from src.video_infinity.wrapper import DistWrapper
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Video Infinity Inference")
parser.add_argument("--config", type=str)
args = parser.parse_args()
return args
def init_pipeline(config):
pipe = VideoCrafterPipeline.from_pretrained(
'adamdad/videocrafterv2_diffusers',
torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload(
gpu_id=config["devices"][dist.get_rank() % len(config["devices"])],
)
pipe.enable_vae_slicing()
return pipe
def run_inference(rank, world_size, config):
dist_controller = DistController(rank, world_size, config)
pipe = init_pipeline(config)
dist_pipe = DistWrapper(pipe, dist_controller, config)
pipe_configs=config['pipe_configs']
plugin_configs=config['plugin_configs']
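    # Map this rank to a prompt: with P prompts and W ranks, rank r renders
    # prompt floor(r / W * P), so consecutive ranks share a prompt when P < W
    # (e.g. 1 prompt on 4 GPUs) and each GPU gets its own prompt when P == W.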
prompt_id = int(rank / world_size * len(pipe_configs["prompts"]))
prompt = pipe_configs["prompts"][prompt_id]
start = time.time()
dist_pipe.inference(
prompt,
config,
pipe_configs,
plugin_configs,
additional_info={
"full_config": config,
}
)
print(f"Rank {rank} finished. Time: {time.time() - start}")
def main(config):
size = len(config["devices"])
processes = []
if not os.path.exists(config["base_path"]):
os.makedirs(config["base_path"])
for rank, _ in enumerate(config["devices"]):
p = mp.Process(target=run_inference, args=(rank, size, config))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
mp.set_start_method("spawn")
with open(parse_args().config, "r") as f:
config = json.load(f)
main(config)
================================================
FILE: requirements.txt
================================================
torch
diffusers
transformers
imageio
accelerate
ffmpeg
pyav
imageio-ffmpeg
================================================
FILE: src/tools.py
================================================
import json
import numpy as np
import imageio
import os
import torch
import torch.distributed as dist
def export_to_video(video_frames, output_video_path, fps = 12):
# Ensure all frames are NumPy arrays and determine video dimensions from the first frame
assert all(isinstance(frame, np.ndarray) for frame in video_frames), "All video frames must be NumPy arrays."
# Ensure output_video_path is ending with .mp4
if not output_video_path.endswith('.mp4'):
output_video_path += '.mp4'
# Create a video file at the specified path and write frames to it
with imageio.get_writer(output_video_path, fps=fps, format='mp4') as writer:
for frame in video_frames:
writer.append_data(
(frame * 255).astype(np.uint8)
)
def save_generation(video_frames, configs, base_path, file_name=None):
if not os.path.exists(base_path):
os.makedirs(base_path)
p_config = configs["pipe_configs"]
frames, steps, fps = p_config["num_frames"], p_config["steps"], p_config["fps"]
if not file_name:
index = [int(each.split('_')[0]) for each in os.listdir(base_path)]
        max_index = max(index) if index else 0
        idx_str = str(max_index + 1).zfill(6)
key_info = '_'.join([str(frames), str(steps), str(fps)])
file_name = f'{idx_str}_{key_info}'
with open(f'{base_path}/{file_name}.json', 'w') as f:
json.dump(configs, f, indent=4)
export_to_video(video_frames, f'{base_path}/{file_name}.mp4', fps=p_config["export_fps"])
return file_name
class GlobalState:
    def __init__(self, state=None) -> None:
        self.init_state(state)
    def init_state(self, state=None):
        # Use a fresh dict per instance instead of a shared mutable default.
        self.state = state if state is not None else {}
def set(self, key, value):
self.state[key] = value
def get(self, key, default=None):
return self.state.get(key, default)
class DistController(object):
def __init__(self, rank, world_size, config) -> None:
super().__init__()
self.rank = rank
self.world_size = world_size
self.config = config
self.is_master = rank == 0
self.init_dist()
self.init_group()
self.device = torch.device(f"cuda:{config['devices'][dist.get_rank()]}")
torch.cuda.set_device(self.device)
def init_dist(self):
print(f"Rank {self.rank} is running.")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(self.config.get("master_port") or "29500")
dist.init_process_group("nccl", rank=self.rank, world_size=self.world_size)
def init_group(self):
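        # One two-rank group per adjacent (i, i+1) pair; the plugins use these
        # groups to exchange boundary frames with neighbouring devices.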
self.adj_groups = [dist.new_group([i, i+1]) for i in range(self.world_size-1)]
================================================
FILE: src/video_crafter.py
================================================
import torch
from diffusers.models import AutoencoderKL, UNet3DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import TextToVideoSDPipeline
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
class VideoCrafterPipeline(TextToVideoSDPipeline):
@register_to_config
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet3DConditionModel,
scheduler: KarrasDiffusionSchedulers,
fps_cond: bool = True,
):
self.fps_cond = fps_cond
super().__init__(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
)
@torch.no_grad()
def __call__(
self,
*args,
**kwargs,
):
fixed_fps = kwargs.pop("fps", 24)
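        # VideoCrafter2 is fps-conditioned: hook the UNet's time embedding so an
        # fps embedding is added to every timestep embedding during sampling.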
def post_function(sample):
fps = fixed_fps
unet = self.unet
if self.fps_cond:
fps = torch.tensor([fps], dtype=torch.float64 , device=sample.device)
fps_emb = unet.fps_proj(fps)
fps_emb = fps_emb.to(sample.device, dtype=unet.dtype)
fps_emb = unet.fps_embedding(fps_emb).repeat_interleave(repeats=sample.shape[0], dim=0)
sample += fps_emb
return sample
self.unet.time_embedding.post_act = post_function
# kwargs.pop("fps", None)
return super().__call__(*args, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
**kwargs,
):
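        # Load the ZeroScope text-to-video pipeline as a skeleton, then swap in the
        # VideoCrafter2 UNet weights and a DPM-Solver++ (SDE, Karras) scheduler.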
pipe = TextToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_576w", **kwargs)
pipe.__class__ = cls
pipe.fps_cond = True
pipe.unet = UNetVideoCrafter.from_pretrained(pretrained_model_name_or_path, **kwargs)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras=True, algorithm_type="sde-dpmsolver++")
return pipe
class UNetVideoCrafter(UNet3DConditionModel):
@register_to_config
def __init__(
self,
sample_size,
in_channels,
out_channels,
down_block_types,
up_block_types,
block_out_channels,
layers_per_block,
downsample_padding,
mid_block_scale_factor,
act_fn,
norm_num_groups,
norm_eps,
cross_attention_dim,
attention_head_dim,
num_attention_heads,
fps_cond: bool = True,
**kwargs
):
self.fps_cond = fps_cond
super().__init__(
sample_size=sample_size,
in_channels=in_channels,
out_channels=out_channels,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
**kwargs
)
if self.fps_cond:
self.fps_proj = Timesteps(block_out_channels[0], True, 0)
self.fps_embedding = TimestepEmbedding(
block_out_channels[0],
block_out_channels[0] * 4,
act_fn=act_fn,
)
================================================
FILE: src/video_infinity/__init__.py
================================================
================================================
FILE: src/video_infinity/plugins.py
================================================
import torch
import torch.distributed as dist
import math
def my_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, token_num_scale=False) -> torch.Tensor:
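    # A hand-rolled scaled-dot-product attention that supports an additive bias
    # mask and an optional per-query rescaling of the logits by the number of
    # unmasked tokens (token_num_scale).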
L, S = query.size(-2), key.size(-2)
base_scale_factor = 1 / math.sqrt(query.size(-1)) * (scale if scale is not None else 1.)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask.to(query.dtype).to(query.device)
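    # Optional rescaling following https://arxiv.org/abs/2306.08645: count the
    # unmasked keys for each query and scale its logits by log_16(count).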
no_mask_count = torch.where(attn_bias < -100, 0, 1).sum(1)
biased_scale_factor = torch.log(no_mask_count) / torch.log(torch.tensor(16)) if token_num_scale else 1.
scale_factor = (base_scale_factor * biased_scale_factor).unsqueeze(-1) if token_num_scale else base_scale_factor
attn_weight = query @ key.transpose(-2, -1)
attn_weight *= scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
class ModulePlugin:
def __init__(self, module, module_id, global_state=None):
self.module = module
self.module_id = module_id
self.global_state = global_state
self.enable = True
self.implement_forward()
@property
def is_log_node(self):
return self.global_state.get('dist_controller').rank == 0 and self.module_id[1] == 0
@property
def t(self):
return self.global_state.get('timestep')
@property
def p(self):
return self.t / 1000
def implement_forward(self):
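        # Monkey-patch the wrapped module: keep its original forward as
        # old_forward and route calls through new_forward while the plugin is enabled.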
module = self.module
if not hasattr(module, "old_forward"):
module.old_forward = module.forward
self.new_forward = self.get_new_forward()
def forward(*args, **kwargs):
self.update_config() # update config
return self.new_forward(*args, **kwargs) if self.enable else self.old_forward(*args, **kwargs)
module.forward = forward
def set_enable(self, enable=True):
self.enable = enable
def get_new_forward(self):
raise NotImplementedError
def update_config(self, config:dict=None):
if config is None:
config = self.global_state.get('plugin_configs', {}).get(self.module_id[0], {})
for key, value in config.items():
setattr(self, key, value)
class GroupNormPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
def get_new_forward(self):
module = self.module
def new_forward(x):
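            # Distributed GroupNorm: all-reduce the per-group mean and variance so
            # every rank normalizes its clip with the same global statistics.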
shape = x.shape
N, C, G = shape[0], shape[1], module.num_groups
assert C % G == 0
x = x.reshape(N, G, -1)
mean = x.mean(-1, keepdim=True)
dist.all_reduce(mean)
mean = mean / dist.get_world_size()
var = ((x - mean) ** 2).mean(-1, keepdim=True)
dist.all_reduce(var)
var = var / dist.get_world_size()
x = (x - mean) / (var + module.eps).sqrt()
x = x.view(shape)
new_shape = [1 for _ in shape]
new_shape[1] = -1
return x * module.weight.view(new_shape) + module.bias.view(new_shape)
return new_forward
class ConvLayerPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.padding = 4
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
def pad_context(self, h, padding=None):
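        # Exchange `padding` boundary frames with the neighbouring ranks over the
        # pairwise adjacent groups (odd ranks serve the left pair first, even ranks
        # the right pair first, so both members of a pair enter the same collective)
        # and concatenate the received context along the frame axis.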
padding = self.padding if padding is None else padding
share_to_left = h[:, :, :padding].contiguous()
share_to_right = h[:, :, -padding:].contiguous()
if self.rank % 2:
# 1. the rank is odd, pad the left first
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
# 1. the rank is even, pad the right first
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
torch.cuda.synchronize()
h_with_context = torch.cat([left_context, h, right_context], dim=2)
return h_with_context
def get_new_forward(self):
module = self.module
def new_forward(hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
hidden_states = (
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
)
identity = hidden_states
hidden_states = self.pad_context(hidden_states)
hidden_states = module.conv1(hidden_states)
hidden_states = module.conv2(hidden_states)
hidden_states = module.conv3(hidden_states)
hidden_states = module.conv4(hidden_states)
hidden_states = hidden_states[:, :, self.padding:-self.padding]
hidden_states = identity + hidden_states
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
return new_forward
class AttentionPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.padding = 24
self.top_k = 16
self.top_k_chunk_size = 24
self.attn_scale = 1.
self.token_num_scale = False
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
self.world_size = self.global_state.get('dist_controller').world_size
self.dynamic_scale = False
def pad_context(self, h, padding=None):
padding = self.padding if padding is None else padding
share_to_left = h[:, :padding].contiguous()
share_to_right = h[:, -padding:].contiguous()
if self.rank % 2:
# 1. the rank is odd, pad the left first
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
# 1. the rank is even, pad the right first
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
torch.cuda.synchronize()
h_with_context = torch.cat([left_context, h, right_context], dim=1)
return h_with_context, padding
def get_topk(self, q, k, v, top_k=None):
# h = [N, F, C]
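        # Global context: each rank samples ~top_k / world_size evenly spaced frames
        # from its own K/V, all ranks all-gather the samples, and the first top_k
        # gathered frames are used as shared global keys/values.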
top_k = self.top_k if top_k is None else top_k
share_num = int(max(top_k // self.world_size, 0))
stride = max(q.shape[1] // share_num, 1) if share_num else 1000000
topk_indices = torch.arange(0, q.shape[1], stride, device=q.device)
k_to_share, v_to_share = k[:, topk_indices], v[:, topk_indices]
gather_k = [torch.zeros_like(k_to_share) for _ in range(self.world_size)]
gather_v = [torch.zeros_like(v_to_share) for _ in range(self.world_size)]
dist.all_gather(gather_k, k_to_share)
dist.all_gather(gather_v, v_to_share)
gather_k = torch.cat(gather_k, dim=1)[:, :top_k]
gather_v = torch.cat(gather_v, dim=1)[:, :top_k]
return gather_k, gather_v
def gather_context(self, h):
self.temporal_n = h.shape[1]
stack_list = [torch.zeros_like(h) for _ in range(self.world_size)]
dist.all_gather(stack_list, h)
return torch.cat(stack_list, dim=1)
def get_new_forward(self):
module = self.module
def new_forward(x, encoder_hidden_states=None, attention_mask=None):
context=encoder_hidden_states
temporal_n = x.shape[1]
q = module.to_q(x)
context = x if context is None else context
k, v = module.to_k(context), module.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[1], module.heads, -1).permute(0, 2, 1, 3).reshape(b*module.heads, t.shape[1], -1),
(q, k, v),
)
global_k, global_v = self.get_topk(q, k, v)
num_global = global_k.shape[1]
padded_k, _ = self.pad_context(k)
padded_v, padding = self.pad_context(v)
padded_k = torch.cat([padded_k, global_k], dim=1)
padded_v = torch.cat([padded_v, global_v], dim=1)
# if self.is_log_node:
# print("Total KV num:", padding*2 + global_k.shape[1], "Global KV num:", num_global, "Padding:", padding)
attn_mask = torch.ones(temporal_n, temporal_n + 2*padding + num_global, dtype=q.dtype).to(q.device)
for i in range(temporal_n):
attn_mask[i, 0: max(0, i)] = float('-inf')
attn_mask[i, min(temporal_n+2*padding, i+1+2*padding): temporal_n+2*padding] = float('-inf')
if self.dynamic_scale and self.local_phase is not None and self.global_phase is not None:
if self.t < self.local_phase['t']:
attn_mask[:, temporal_n+2*padding:] += self.local_phase['global_biase']
attn_mask[:, :temporal_n+2*padding] += self.local_phase['local_biase']
if self.t >= self.global_phase['t']:
attn_mask[:, temporal_n+2*padding:] += self.global_phase['global_biase']
attn_mask[:, :temporal_n+2*padding] += self.global_phase['local_biase']
out = my_attention(
q, padded_k, padded_v,
attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
scale=self.attn_scale,
token_num_scale=self.token_num_scale
)
out = (
out.unsqueeze(0).reshape(b, module.heads, out.shape[1], -1).permute(0, 2, 1, 3)
.reshape(b, out.shape[1], -1)
)
# linear proj
hidden_states = module.to_out[0](out)
hidden_states = module.to_out[1](hidden_states)
return hidden_states
return new_forward
class Conv3DPligin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.padding = 1
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
def pad_context(self, h):
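        # Same neighbour-exchange scheme as ConvLayerPlugin.pad_context: share
        # `padding` boundary frames with the adjacent ranks along dim 2 before the 3D convolution.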
padding = self.padding
share_to_left = h[:, :, :padding].contiguous()
share_to_right = h[:, :, -padding:].contiguous()
if self.rank % 2:
# 1. the rank is odd, pad the left first
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
# 1. the rank is even, pad the right first
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
torch.cuda.synchronize()
h_with_context = torch.cat([left_context, h, right_context], dim=2)
return h_with_context
def get_new_forward(self):
module = self.module
def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.pad_context(hidden_states)
hidden_states = module.old_forward(hidden_states)[:,:,self.padding:-self.padding]
return hidden_states
return new_forward
class UNetPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
def get_new_forward(self):
module = self.module
def new_forward(*args, **kwargs):
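            # Record the current denoising timestep in the shared global state so the
            # other plugins (e.g. attention phase biasing) can read it on this step.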
self.global_state.set('timestep', args[1].item())
return module.old_forward(*args, **kwargs)
return new_forward
================================================
FILE: src/video_infinity/wrapper.py
================================================
from ..tools import save_generation, GlobalState, DistController
from .plugins import torch, ModulePlugin, UNetPlugin, GroupNormPlugin, ConvLayerPlugin, AttentionPlugin, Conv3DPligin, dist
class DistWrapper(object):
def __init__(self, pipe, dist_controller: DistController, config) -> None:
super().__init__()
self.pipe = pipe
self.dist_controller = dist_controller
self.config = config
self.global_state = GlobalState({
"dist_controller": dist_controller
})
self.plugin_mount()
def switch_plugin(self, plugin_name, enable):
if plugin_name not in self.plugins: return
        for module_id in self.plugins[plugin_name]:
            module: ModulePlugin = self.plugins[plugin_name][module_id]
            module.set_enable(enable)
def config_plugin(self, plugin_name, config):
if plugin_name not in self.plugins: return
        for module_id in self.plugins[plugin_name]:
            module: ModulePlugin = self.plugins[plugin_name][module_id]
            module.update_config(config)
def plugin_mount(self):
self.plugins = {}
self.unet_plugin_mount()
self.attn_plugin_mount()
# self.group_norm_plugin_mount()
# self.conv_3d_plugin_mount()
        # Only one of the Conv3D and ConvLayer plugins can be mounted at a time.
self.conv_plugin_mount()
def group_norm_plugin_mount(self):
self.plugins['group_norm'] = {}
group_norms = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'GroupNorm':
group_norms.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(group_norms)} group norms')
for i, group_norm in enumerate(group_norms):
plugin_id = 'group_norm', i
self.plugins['group_norm'][plugin_id] = GroupNormPlugin(group_norm, plugin_id, self.global_state)
def conv_plugin_mount(self):
self.plugins['conv_layer'] = {}
convs = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'TemporalConvLayer':
convs.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(convs)} convs')
for i, conv in enumerate(convs):
plugin_id = 'conv_layer', i
self.plugins['conv_layer'][plugin_id] = ConvLayerPlugin(conv, plugin_id, self.global_state)
def conv_3d_plugin_mount(self):
self.plugins['conv_3d'] = {}
conv3d_s = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Conv3d':
conv3d_s.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(conv3d_s)} conv3d_s')
for i, conv in enumerate(conv3d_s):
plugin_id = 'conv_3d', i
self.plugins['conv_3d'][plugin_id] = Conv3DPligin(conv, plugin_id, self.global_state)
def attn_plugin_mount(self):
self.plugins['attn'] = {}
attns = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Attention':
attns.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(attns)} attns')
for i, attn in enumerate(attns):
plugin_id = 'attn', i
self.plugins['attn'][plugin_id] = AttentionPlugin(attn, plugin_id, self.global_state)
def unet_plugin_mount(self):
self.plugins['unet'] = UNetPlugin(
self.pipe.unet,
('unet', 0),
self.global_state
)
def inference(
self,
prompts="A beagle wearning diving goggles swimming in the ocean while the camera is moving, coral reefs in the background",
config={},
pipe_configs={
"steps": 50,
"guidance_scale": 12,
"fps": 60,
"num_frames": 24 * 1,
"height": 320,
"width": 512,
"export_fps": 12,
"base_path": "./work/output",
"file_name": None
},
plugin_configs={
"attn":{
"padding": 24,
"top_k": 24,
"top_k_chunk_size": 24,
"attn_scale": 1.,
"token_num_scale": True,
"dynamic_scale": True,
},
"conv_3d": {
"padding": 1,
},
"conv_layer": {},
},
additional_info={},
):
self.plugin_mount()
generator = torch.Generator("cuda").manual_seed(self.config["seed"] + self.dist_controller.rank)
# generator = torch.Generator("cuda").manual_seed(self.config["seed"])
self.global_state.set("plugin_configs", plugin_configs)
video_frames = self.pipe(
prompts,
num_inference_steps=pipe_configs["steps"],
guidance_scale=pipe_configs["guidance_scale"],
height=pipe_configs['height'],
width=pipe_configs['width'],
num_frames=pipe_configs['num_frames'],
fps=pipe_configs['fps'],
generator=generator
).frames[0]
video_frames = torch.tensor(video_frames, dtype=torch.float16, device=self.dist_controller.device)
print(f"Rank {self.dist_controller.rank} finished inference. Result: {video_frames.shape}")
all_frames = [
torch.zeros_like(video_frames, dtype=torch.float16) for _ in range(self.dist_controller.world_size)
] if self.dist_controller.is_master else None
dist.gather(video_frames, all_frames, dst=0)
if self.dist_controller.is_master:
all_frames = torch.cat(all_frames, dim=0).cpu().numpy()
save_generation(
all_frames,
{
"prompt": prompts,
"pipe_configs": pipe_configs,
"plugin_configs": plugin_configs,
"additional_info": additional_info
},
config["base_path"],
pipe_configs["file_name"]
)