Repository: Yuanshi9815/Video-Infinity
Branch: main
Commit: acfcd17c25cd
Files: 11
Total size: 41.8 KB
Directory structure:
gitextract_0rdk_80z/
├── README.md
├── examples/
│ ├── config.json
│ ├── multi_prompts.json
│ └── single_gpu.json
├── inference.py
├── requirements.txt
└── src/
├── tools.py
├── video_crafter.py
└── video_infinity/
├── __init__.py
├── plugins.py
└── wrapper.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
> **Video-Infinity: Distributed Long Video Generation**
>
> Zhenxiong Tan,
> [Xingyi Yang](https://adamdad.github.io/),
> [Songhua Liu](http://121.37.94.87/),
> and
> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
>
> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
>
## TL;DR (Too Long; Didn't Read)
Video-Infinity generates long videos quickly using multiple GPUs without extra training. Feel free to visit our
[project page](https://video-infinity.tanzhenxiong.com)
for more information and generated videos.
## Features
* **Distributed 🌐**: Utilizes multiple GPUs to generate long-form videos.
* **High-Speed 🚀**: Produces 2,300 frames in just 5 minutes.
* **Training-Free 🎓**: Generates long videos without requiring additional training for existing models.
## Setup
### Environment Setup
```bash
conda create -n video_infinity_vc2 python=3.10
conda activate video_infinity_vc2
pip install -r requirements.txt
```
## Usage
### Quick Start
- **Basic Usage**
```bash
python inference.py --config examples/config.json
```
- **Multi-Prompts**
```bash
python inference.py --config examples/multi_prompts.json
```
- **Single GPU**
```bash
python inference.py --config examples/single_gpu.json
```
### Config
#### Basic Config
| Parameter | Description |
| ----------- | -------------------------------------- |
| `devices` | The list of GPU devices to use. |
| `base_path` | The path to save the generated videos. |
#### Pipeline Config
| Parameter | Description |
| ------------ | ---------------------------------------------------------------------------------------------------- |
| `prompts` | The list of text prompts. **Note**: the number of prompts should not exceed the number of GPUs; each prompt is rendered by a contiguous block of GPUs. |
| `file_name` | The name of the generated video. |
| `num_frames` | The number of frames to generate on **each GPU**. |
#### Video-Infinity Config
| Parameter | Description |
| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `*.padding` | The number of local context frames exchanged with each neighboring GPU. |
| `attn.top_k` | The number of global context frames used by the temporal attention modules. |
| `attn.local_phase` | When the denoising timestep is below `t`, the attention is biased: `local_biase` is added to the scores of the local context frames and `global_biase` to those of the global context frames (takes effect only when `dynamic_scale` is `true`). |
| `attn.global_phase` | Similar to `local_phase`, but applied when the denoising timestep is at least `t`. |
| `attn.token_num_scale` | If `true`, the attention scale factor is rescaled by the number of attended tokens. Defaults to `false`. See this [paper](https://arxiv.org/abs/2306.08645) for details. |
#### How to Set Config
- To avoid losing high-frequency information, we recommend keeping the sum of `padding` and `attn.top_k` no larger than 24, which is close to the default frame count of the `VideoCrafter2` model (see the sketch below this list).
- If you want a larger `padding` or `attn.top_k`, set `attn.token_num_scale` to `true`.
- Higher `local_phase.t` and `global_phase.t` values produce more stable videos but may reduce their diversity.
- A larger `padding` provides more local context.
- A higher `attn.top_k` improves the overall stability of the videos.
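
For reference, here is a minimal `plugin_configs` sketch consistent with the guidelines above. The values mirror `examples/config.json` (`padding + top_k = 8 + 16 = 24`, right at the recommended limit); note that the bias keys are spelled `local_biase` / `global_biase` in the code:
```json
"plugin_configs": {
    "attn": {
        "padding": 8,
        "top_k": 16,
        "top_k_chunk_size": 24,
        "attn_scale": 1.0,
        "token_num_scale": false,
        "dynamic_scale": true,
        "local_phase": { "t": 800, "local_biase": 10, "global_biase": 0 },
        "global_phase": { "t": 800, "local_biase": 0, "global_biase": 10 }
    },
    "conv_3d": { "padding": 1 },
    "conv_layer": {}
}
```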
## Citation
```
@article{tan2024videoinf,
  title={Video-Infinity: Distributed Long Video Generation},
  author={Tan, Zhenxiong and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},
  journal={arXiv preprint arXiv:2406.16260},
  year={2024}
}
```
## Acknowledgements
Our project is based on the [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2) model. We would like to thank the authors for their excellent work! ❤️
================================================
FILE: examples/config.json
================================================
{
"dtype": "torch.float16",
"devices": [0,1,2,3],
"seed": 123,
"master_port": 29516,
"base_path": "./exp",
"pipe_configs":{
"prompts": [
"A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background."
],
"steps": 30,
"guidance_scale": 12,
"fps": 24,
"num_frames": 24,
"height": 320,
"width": 512,
"export_fps": 8,
"file_name": null
},
"plugin_configs":{
"attn":{
"padding": 8,
"top_k": 16,
"top_k_chunk_size": 24,
"attn_scale": 1.0,
"token_num_scale": false,
"dynamic_scale": true,
"local_phase": {
"t": 800,
"local_biase": 10,
"global_biase": 0
},
"global_phase": {
"t": 800,
"local_biase": 0,
"global_biase": 10
}
},
"conv_3d": {
"padding": 1
},
"conv_layer": {}
}
}
================================================
FILE: examples/multi_prompts.json
================================================
{
"dtype": "torch.float16",
"devices": [
0,
1,
2,
3,
4,
5,
6,
7
],
"seed": 123,
"master_port": 29516,
"base_path": "./exp",
"pipe_configs": {
"prompts": [
"[Ukiyo-e style] A black Akita puppy stands alone under the eaves of a traditional Japanese house, shivering in the rain and looking scared.",
"[Ukiyo-e style] A girl in a blue kimono comforts a shivering black Akita puppy during a rainy walk through a stone-paved village.",
"[Ukiyo-e style] A girl in a blue kimono brings an black Akita puppy into her warm home, a traditional wooden Japanese house with sliding doors.",
"[Ukiyo-e style] A girl in a blue kimono plays with her energetic black Akita puppy in the garden of a traditional Japanese house, throwing a woven ball.",
"[Ukiyo-e style] A teenager in a blue kimono jogs with her black Akita through a park filled with cherry blossom trees and ancient stone lanterns.",
"[Ukiyo-e style] A teenage girl in a blue kimono relaxes in a field of wildflowers, reading a scroll while her black Akita rests beside him under a cherry tree.",
"[Ukiyo-e style] A young girl in a blue kimono celebrates her coming-of-age ceremony with her black Akita, surrounded by festive lanterns and banners.",
"[Ukiyo-e style] A girl in a blue kimono and her loyal black Akita enjoy a serene sunset walk along the beach, with the silhouette of Mount Fuji in the distance."
],
"steps": 30,
"guidance_scale": 12,
"fps": 24,
"num_frames": 24,
"height": 320,
"width": 512,
"export_fps": 8,
"file_name": null
},
"plugin_configs": {
"attn": {
"padding": 8,
"top_k": 16,
"top_k_chunk_size": 24,
"attn_scale": 1.0,
"token_num_scale": false,
"dynamic_scale": true,
"local_phase": {
"t": 850,
"local_biase": 10,
"global_biase": 0
},
"global_phase": {
"t": 850,
"local_biase": 0,
"global_biase": 10
}
},
"conv_3d": {
"padding": 1
},
"conv_layer": {}
}
}
================================================
FILE: examples/single_gpu.json
================================================
{
"dtype": "torch.float16",
"devices": [0],
"seed": 123,
"master_port": 29516,
"base_path": "./exp",
"pipe_configs":{
"prompts": [
"A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background."
],
"steps": 30,
"guidance_scale": 12,
"fps": 24,
"num_frames": 24,
"height": 320,
"width": 512,
"export_fps": 8,
"file_name": null
},
"plugin_configs":{
"attn":{
"padding": 8,
"top_k": 16,
"top_k_chunk_size": 24,
"attn_scale": 1.0,
"token_num_scale": false,
"dynamic_scale": true,
"local_phase": {
"t": 800,
"local_biase": 10,
"global_biase": 0
},
"global_phase": {
"t": 800,
"local_biase": 0,
"global_biase": 10
}
},
"conv_3d": {
"padding": 1
},
"conv_layer": {}
}
}
================================================
FILE: inference.py
================================================
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import json
import os
from src.video_crafter import VideoCrafterPipeline
from src.tools import DistController
from src.video_infinity.wrapper import DistWrapper
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Video Infinity Inference")
parser.add_argument("--config", type=str)
args = parser.parse_args()
return args
def init_pipeline(config):
pipe = VideoCrafterPipeline.from_pretrained(
'adamdad/videocrafterv2_diffusers',
torch_dtype=torch.float16
)
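    # Keep idle submodules on the CPU and run compute on this rank's assigned GPU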
pipe.enable_model_cpu_offload(
gpu_id=config["devices"][dist.get_rank() % len(config["devices"])],
)
pipe.enable_vae_slicing()
return pipe
def run_inference(rank, world_size, config):
dist_controller = DistController(rank, world_size, config)
pipe = init_pipeline(config)
dist_pipe = DistWrapper(pipe, dist_controller, config)
    pipe_configs = config['pipe_configs']
    plugin_configs = config['plugin_configs']
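    # Map rank -> prompt: with P prompts over W ranks, each prompt is rendered by a
    # contiguous block of ranks (which assumes P <= W)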
prompt_id = int(rank / world_size * len(pipe_configs["prompts"]))
prompt = pipe_configs["prompts"][prompt_id]
start = time.time()
dist_pipe.inference(
prompt,
config,
pipe_configs,
plugin_configs,
additional_info={
"full_config": config,
}
)
print(f"Rank {rank} finished. Time: {time.time() - start}")
def main(config):
size = len(config["devices"])
processes = []
if not os.path.exists(config["base_path"]):
os.makedirs(config["base_path"])
for rank, _ in enumerate(config["devices"]):
p = mp.Process(target=run_inference, args=(rank, size, config))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
mp.set_start_method("spawn")
with open(parse_args().config, "r") as f:
config = json.load(f)
main(config)
================================================
FILE: requirements.txt
================================================
torch
diffusers
transformers
imageio
accelerate
ffmpeg
pyav
imageio-ffmpeg
================================================
FILE: src/tools.py
================================================
import json
import numpy as np
import imageio
import os
import torch
import torch.distributed as dist
def export_to_video(video_frames, output_video_path, fps=12):
    # All frames must be NumPy arrays with float values in [0, 1]
    assert all(isinstance(frame, np.ndarray) for frame in video_frames), "All video frames must be NumPy arrays."
# Ensure output_video_path is ending with .mp4
if not output_video_path.endswith('.mp4'):
output_video_path += '.mp4'
# Create a video file at the specified path and write frames to it
with imageio.get_writer(output_video_path, fps=fps, format='mp4') as writer:
for frame in video_frames:
writer.append_data(
(frame * 255).astype(np.uint8)
)
def save_generation(video_frames, configs, base_path, file_name=None):
if not os.path.exists(base_path):
os.makedirs(base_path)
p_config = configs["pipe_configs"]
frames, steps, fps = p_config["num_frames"], p_config["steps"], p_config["fps"]
    if not file_name:
        indices = [int(each.split('_')[0]) for each in os.listdir(base_path)]
        max_index = max(indices) if indices else 0
        idx_str = str(max_index + 1).zfill(6)
        key_info = '_'.join([str(frames), str(steps), str(fps)])
        file_name = f'{idx_str}_{key_info}'
with open(f'{base_path}/{file_name}.json', 'w') as f:
json.dump(configs, f, indent=4)
export_to_video(video_frames, f'{base_path}/{file_name}.mp4', fps=p_config["export_fps"])
return file_name
class GlobalState:
    def __init__(self, state=None) -> None:
        self.init_state(state)
    def init_state(self, state=None):
        # Use a None default to avoid sharing one mutable dict across instances
        self.state = state if state is not None else {}
def set(self, key, value):
self.state[key] = value
def get(self, key, default=None):
return self.state.get(key, default)
class DistController(object):
def __init__(self, rank, world_size, config) -> None:
super().__init__()
self.rank = rank
self.world_size = world_size
self.config = config
self.is_master = rank == 0
self.init_dist()
self.init_group()
self.device = torch.device(f"cuda:{config['devices'][dist.get_rank()]}")
torch.cuda.set_device(self.device)
def init_dist(self):
print(f"Rank {self.rank} is running.")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(self.config.get("master_port") or "29500")
dist.init_process_group("nccl", rank=self.rank, world_size=self.world_size)
def init_group(self):
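        # One two-rank group per adjacent pair (i, i+1), used to exchange boundary frames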
self.adj_groups = [dist.new_group([i, i+1]) for i in range(self.world_size-1)]
================================================
FILE: src/video_crafter.py
================================================
import torch
from diffusers.models import AutoencoderKL, UNet3DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import TextToVideoSDPipeline
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
class VideoCrafterPipeline(TextToVideoSDPipeline):
@register_to_config
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet3DConditionModel,
scheduler: KarrasDiffusionSchedulers,
fps_cond: bool = True,
):
self.fps_cond = fps_cond
super().__init__(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
)
@torch.no_grad()
def __call__(
self,
*args,
**kwargs,
):
fixed_fps = kwargs.pop("fps", 24)
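        # VideoCrafter2 is FPS-conditioned: hook the time embedding's post-activation so
        # an fps embedding is added to the timestep embedding on every forward pass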
def post_function(sample):
fps = fixed_fps
unet = self.unet
if self.fps_cond:
                fps = torch.tensor([fps], dtype=torch.float64, device=sample.device)
fps_emb = unet.fps_proj(fps)
fps_emb = fps_emb.to(sample.device, dtype=unet.dtype)
fps_emb = unet.fps_embedding(fps_emb).repeat_interleave(repeats=sample.shape[0], dim=0)
sample += fps_emb
return sample
self.unet.time_embedding.post_act = post_function
return super().__call__(*args, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
**kwargs,
):
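        # Build a TextToVideoSDPipeline scaffold from zeroscope_v2_576w, then swap in the
        # VideoCrafter2 UNet weights and a Karras SDE DPM-Solver++ scheduler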
pipe = TextToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_576w", **kwargs)
pipe.__class__ = cls
pipe.fps_cond = True
pipe.unet = UNetVideoCrafter.from_pretrained(pretrained_model_name_or_path, **kwargs)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras=True, algorithm_type="sde-dpmsolver++")
return pipe
class UNetVideoCrafter(UNet3DConditionModel):
@register_to_config
def __init__(
self,
sample_size,
in_channels,
out_channels,
down_block_types,
up_block_types,
block_out_channels,
layers_per_block,
downsample_padding,
mid_block_scale_factor,
act_fn,
norm_num_groups,
norm_eps,
cross_attention_dim,
attention_head_dim,
num_attention_heads,
fps_cond: bool = True,
**kwargs
):
self.fps_cond = fps_cond
super().__init__(
sample_size=sample_size,
in_channels=in_channels,
out_channels=out_channels,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
**kwargs
)
if self.fps_cond:
self.fps_proj = Timesteps(block_out_channels[0], True, 0)
self.fps_embedding = TimestepEmbedding(
block_out_channels[0],
block_out_channels[0] * 4,
act_fn=act_fn,
)
================================================
FILE: src/video_infinity/__init__.py
================================================
================================================
FILE: src/video_infinity/plugins.py
================================================
import torch
import torch.distributed as dist
import math
def my_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, token_num_scale=False) -> torch.Tensor:
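    # Variant of PyTorch's scaled_dot_product_attention reference implementation, extended
    # with a scale override and an optional per-row rescale by the number of unmasked tokens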
L, S = query.size(-2), key.size(-2)
base_scale_factor = 1 / math.sqrt(query.size(-1)) * (scale if scale is not None else 1.)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask.to(query.dtype).to(query.device)
no_mask_count = torch.where(attn_bias < -100, 0, 1).sum(1)
biased_scale_factor = torch.log(no_mask_count) / torch.log(torch.tensor(16)) if token_num_scale else 1.
scale_factor = (base_scale_factor * biased_scale_factor).unsqueeze(-1) if token_num_scale else base_scale_factor
attn_weight = query @ key.transpose(-2, -1)
attn_weight *= scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
class ModulePlugin:
def __init__(self, module, module_id, global_state=None):
self.module = module
self.module_id = module_id
self.global_state = global_state
self.enable = True
self.implement_forward()
@property
def is_log_node(self):
return self.global_state.get('dist_controller').rank == 0 and self.module_id[1] == 0
@property
def t(self):
return self.global_state.get('timestep')
@property
def p(self):
return self.t / 1000
def implement_forward(self):
module = self.module
if not hasattr(module, "old_forward"):
module.old_forward = module.forward
self.new_forward = self.get_new_forward()
def forward(*args, **kwargs):
self.update_config() # update config
return self.new_forward(*args, **kwargs) if self.enable else self.old_forward(*args, **kwargs)
module.forward = forward
def set_enable(self, enable=True):
self.enable = enable
def get_new_forward(self):
raise NotImplementedError
def update_config(self, config:dict=None):
if config is None:
config = self.global_state.get('plugin_configs', {}).get(self.module_id[0], {})
for key, value in config.items():
setattr(self, key, value)
class GroupNormPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
def get_new_forward(self):
module = self.module
def new_forward(x):
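            # Distributed GroupNorm: all_reduce the per-group mean/var so normalization
            # statistics span all ranks' frames instead of only the local clip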
shape = x.shape
N, C, G = shape[0], shape[1], module.num_groups
assert C % G == 0
x = x.reshape(N, G, -1)
mean = x.mean(-1, keepdim=True)
dist.all_reduce(mean)
mean = mean / dist.get_world_size()
var = ((x - mean) ** 2).mean(-1, keepdim=True)
dist.all_reduce(var)
var = var / dist.get_world_size()
x = (x - mean) / (var + module.eps).sqrt()
x = x.view(shape)
new_shape = [1 for _ in shape]
new_shape[1] = -1
return x * module.weight.view(new_shape) + module.bias.view(new_shape)
return new_forward
class ConvLayerPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.padding = 4
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
def pad_context(self, h, padding=None):
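        # Exchange `padding` boundary frames with adjacent ranks. Odd ranks serve the
        # left-neighbor group first and even ranks the right-neighbor group, so both
        # members of every two-rank adjacent group enter the same all_gather together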
padding = self.padding if padding is None else padding
share_to_left = h[:, :, :padding].contiguous()
share_to_right = h[:, :, -padding:].contiguous()
if self.rank % 2:
# 1. the rank is odd, pad the left first
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
# 1. the rank is even, pad the right first
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
torch.cuda.synchronize()
h_with_context = torch.cat([left_context, h, right_context], dim=2)
return h_with_context
def get_new_forward(self):
module = self.module
def new_forward(hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
hidden_states = (
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
)
identity = hidden_states
hidden_states = self.pad_context(hidden_states)
hidden_states = module.conv1(hidden_states)
hidden_states = module.conv2(hidden_states)
hidden_states = module.conv3(hidden_states)
hidden_states = module.conv4(hidden_states)
hidden_states = hidden_states[:, :, self.padding:-self.padding]
hidden_states = identity + hidden_states
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
return new_forward
class AttentionPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.padding = 24
self.top_k = 16
self.top_k_chunk_size = 24
self.attn_scale = 1.
self.token_num_scale = False
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
self.world_size = self.global_state.get('dist_controller').world_size
self.dynamic_scale = False
def pad_context(self, h, padding=None):
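        # Same neighbor exchange as ConvLayerPlugin.pad_context, but frames lie on dim=1
        # here; also returns the effective padding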
padding = self.padding if padding is None else padding
share_to_left = h[:, :padding].contiguous()
share_to_right = h[:, -padding:].contiguous()
if self.rank % 2:
# 1. the rank is odd, pad the left first
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
# 1. the rank is even, pad the right first
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
torch.cuda.synchronize()
h_with_context = torch.cat([left_context, h, right_context], dim=1)
return h_with_context, padding
def get_topk(self, q, k, v, top_k=None):
# h = [N, F, C]
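        # Each rank shares an evenly strided sample of its frames; the gathered pool is
        # truncated to the first `top_k` frames to form the shared global context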
top_k = self.top_k if top_k is None else top_k
share_num = int(max(top_k // self.world_size, 0))
stride = max(q.shape[1] // share_num, 1) if share_num else 1000000
topk_indices = torch.arange(0, q.shape[1], stride, device=q.device)
k_to_share, v_to_share = k[:, topk_indices], v[:, topk_indices]
gather_k = [torch.zeros_like(k_to_share) for _ in range(self.world_size)]
gather_v = [torch.zeros_like(v_to_share) for _ in range(self.world_size)]
dist.all_gather(gather_k, k_to_share)
dist.all_gather(gather_v, v_to_share)
gather_k = torch.cat(gather_k, dim=1)[:, :top_k]
gather_v = torch.cat(gather_v, dim=1)[:, :top_k]
return gather_k, gather_v
def gather_context(self, h):
self.temporal_n = h.shape[1]
stack_list = [torch.zeros_like(h) for _ in range(self.world_size)]
dist.all_gather(stack_list, h)
return torch.cat(stack_list, dim=1)
def get_new_forward(self):
module = self.module
def new_forward(x, encoder_hidden_states=None, attention_mask=None):
context=encoder_hidden_states
temporal_n = x.shape[1]
q = module.to_q(x)
context = x if context is None else context
k, v = module.to_k(context), module.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[1], module.heads, -1).permute(0, 2, 1, 3).reshape(b*module.heads, t.shape[1], -1),
(q, k, v),
)
global_k, global_v = self.get_topk(q, k, v)
num_global = global_k.shape[1]
padded_k, _ = self.pad_context(k)
padded_v, padding = self.pad_context(v)
padded_k = torch.cat([padded_k, global_k], dim=1)
padded_v = torch.cat([padded_v, global_v], dim=1)
# if self.is_log_node:
# print("Total KV num:", padding*2 + global_k.shape[1], "Global KV num:", num_global, "Padding:", padding)
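            # Key/value layout: [left pad | local frames | right pad | global frames].
            # Each query frame i attends to the sliding window [i, i + 2*padding] of the
            # padded local sequence plus all global frames; everything else is masked out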
attn_mask = torch.ones(temporal_n, temporal_n + 2*padding + num_global, dtype=q.dtype).to(q.device)
for i in range(temporal_n):
attn_mask[i, 0: max(0, i)] = float('-inf')
attn_mask[i, min(temporal_n+2*padding, i+1+2*padding): temporal_n+2*padding] = float('-inf')
if self.dynamic_scale and self.local_phase is not None and self.global_phase is not None:
if self.t < self.local_phase['t']:
attn_mask[:, temporal_n+2*padding:] += self.local_phase['global_biase']
attn_mask[:, :temporal_n+2*padding] += self.local_phase['local_biase']
if self.t >= self.global_phase['t']:
attn_mask[:, temporal_n+2*padding:] += self.global_phase['global_biase']
attn_mask[:, :temporal_n+2*padding] += self.global_phase['local_biase']
out = my_attention(
q, padded_k, padded_v,
attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
scale=self.attn_scale,
token_num_scale=self.token_num_scale
)
out = (
out.unsqueeze(0).reshape(b, module.heads, out.shape[1], -1).permute(0, 2, 1, 3)
.reshape(b, out.shape[1], -1)
)
# linear proj
hidden_states = module.to_out[0](out)
hidden_states = module.to_out[1](hidden_states)
return hidden_states
return new_forward
class Conv3DPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.padding = 1
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
def pad_context(self, h):
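        # Same neighbor exchange as ConvLayerPlugin.pad_context (frames on dim=2)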
padding = self.padding
share_to_left = h[:, :, :padding].contiguous()
share_to_right = h[:, :, -padding:].contiguous()
if self.rank % 2:
# 1. the rank is odd, pad the left first
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
# 1. the rank is even, pad the right first
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
torch.cuda.synchronize()
h_with_context = torch.cat([left_context, h, right_context], dim=2)
return h_with_context
def get_new_forward(self):
module = self.module
def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.pad_context(hidden_states)
hidden_states = module.old_forward(hidden_states)[:,:,self.padding:-self.padding]
return hidden_states
return new_forward
class UNetPlugin(ModulePlugin):
def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
def get_new_forward(self):
module = self.module
def new_forward(*args, **kwargs):
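            # Record the current denoising timestep so the plugins can phase their behavior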
self.global_state.set('timestep', args[1].item())
return module.old_forward(*args, **kwargs)
return new_forward
================================================
FILE: src/video_infinity/wrapper.py
================================================
from ..tools import save_generation, GlobalState, DistController
from .plugins import torch, ModulePlugin, UNetPlugin, GroupNormPlugin, ConvLayerPlugin, AttentionPlugin, Conv3DPlugin, dist
class DistWrapper(object):
def __init__(self, pipe, dist_controller: DistController, config) -> None:
super().__init__()
self.pipe = pipe
self.dist_controller = dist_controller
self.config = config
self.global_state = GlobalState({
"dist_controller": dist_controller
})
self.plugin_mount()
    def switch_plugin(self, plugin_name, enable):
        if plugin_name not in self.plugins: return
        for module_id in self.plugins[plugin_name]:
            module: ModulePlugin = self.plugins[plugin_name][module_id]
            module.set_enable(enable)
    def config_plugin(self, plugin_name, config):
        if plugin_name not in self.plugins: return
        for module_id in self.plugins[plugin_name]:
            module: ModulePlugin = self.plugins[plugin_name][module_id]
            module.update_config(config)
def plugin_mount(self):
self.plugins = {}
self.unet_plugin_mount()
self.attn_plugin_mount()
# self.group_norm_plugin_mount()
# self.conv_3d_plugin_mount()
# Conv3d and Conv layer can only be used one at a time
self.conv_plugin_mount()
def group_norm_plugin_mount(self):
self.plugins['group_norm'] = {}
group_norms = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'GroupNorm':
group_norms.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(group_norms)} group norms')
for i, group_norm in enumerate(group_norms):
plugin_id = 'group_norm', i
self.plugins['group_norm'][plugin_id] = GroupNormPlugin(group_norm, plugin_id, self.global_state)
def conv_plugin_mount(self):
self.plugins['conv_layer'] = {}
convs = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'TemporalConvLayer':
convs.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(convs)} convs')
for i, conv in enumerate(convs):
plugin_id = 'conv_layer', i
self.plugins['conv_layer'][plugin_id] = ConvLayerPlugin(conv, plugin_id, self.global_state)
def conv_3d_plugin_mount(self):
self.plugins['conv_3d'] = {}
conv3d_s = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Conv3d':
conv3d_s.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(conv3d_s)} conv3d_s')
for i, conv in enumerate(conv3d_s):
plugin_id = 'conv_3d', i
            self.plugins['conv_3d'][plugin_id] = Conv3DPlugin(conv, plugin_id, self.global_state)
def attn_plugin_mount(self):
self.plugins['attn'] = {}
attns = []
for module in self.pipe.unet.named_modules():
if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Attention':
attns.append(module[1])
if self.dist_controller.is_master:
print(f'Found {len(attns)} attns')
for i, attn in enumerate(attns):
plugin_id = 'attn', i
self.plugins['attn'][plugin_id] = AttentionPlugin(attn, plugin_id, self.global_state)
def unet_plugin_mount(self):
self.plugins['unet'] = UNetPlugin(
self.pipe.unet,
('unet', 0),
self.global_state
)
def inference(
self,
prompts="A beagle wearning diving goggles swimming in the ocean while the camera is moving, coral reefs in the background",
config={},
pipe_configs={
"steps": 50,
"guidance_scale": 12,
"fps": 60,
"num_frames": 24 * 1,
"height": 320,
"width": 512,
"export_fps": 12,
"base_path": "./work/output",
"file_name": None
},
plugin_configs={
"attn":{
"padding": 24,
"top_k": 24,
"top_k_chunk_size": 24,
"attn_scale": 1.,
"token_num_scale": True,
"dynamic_scale": True,
},
"conv_3d": {
"padding": 1,
},
"conv_layer": {},
},
additional_info={},
):
self.plugin_mount()
generator = torch.Generator("cuda").manual_seed(self.config["seed"] + self.dist_controller.rank)
# generator = torch.Generator("cuda").manual_seed(self.config["seed"])
self.global_state.set("plugin_configs", plugin_configs)
video_frames = self.pipe(
prompts,
num_inference_steps=pipe_configs["steps"],
guidance_scale=pipe_configs["guidance_scale"],
height=pipe_configs['height'],
width=pipe_configs['width'],
num_frames=pipe_configs['num_frames'],
fps=pipe_configs['fps'],
generator=generator
).frames[0]
video_frames = torch.tensor(video_frames, dtype=torch.float16, device=self.dist_controller.device)
print(f"Rank {self.dist_controller.rank} finished inference. Result: {video_frames.shape}")
all_frames = [
torch.zeros_like(video_frames, dtype=torch.float16) for _ in range(self.dist_controller.world_size)
] if self.dist_controller.is_master else None
dist.gather(video_frames, all_frames, dst=0)
if self.dist_controller.is_master:
all_frames = torch.cat(all_frames, dim=0).cpu().numpy()
save_generation(
all_frames,
{
"prompt": prompts,
"pipe_configs": pipe_configs,
"plugin_configs": plugin_configs,
"additional_info": additional_info
},
config["base_path"],
pipe_configs["file_name"]
)