Repository: Yuanshi9815/Video-Infinity
Branch: main
Commit: acfcd17c25cd
Files: 11
Total size: 41.8 KB

Directory structure:
gitextract_0rdk_80z/
├── README.md
├── examples/
│   ├── config.json
│   ├── multi_promts.json
│   └── single_gpu.json
├── inference.py
├── requirements.txt
└── src/
    ├── tools.py
    ├── video_crafter.py
    └── video_infinity/
        ├── __init__.py
        ├── plugins.py
        └── wrapper.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Video-Infinity
[arXiv](https://arxiv.org/abs/2406.16260)
> **Video-Infinity: Distributed Long Video Generation**
> Zhenxiong Tan, [Xingyi Yang](https://adamdad.github.io/), [Songhua Liu](http://121.37.94.87/), and [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
## TL;DR (Too Long; Didn't Read)

Video-Infinity generates long videos quickly using multiple GPUs, without extra training. Feel free to visit our [project page](https://video-infinity.tanzhenxiong.com) for more information and generated videos.

## Features

* **Distributed 🌐**: Utilizes multiple GPUs to generate long-form videos.
* **High-Speed 🚀**: Produces 2,300 frames in just 5 minutes.
* **Training-Free 🎓**: Generates long videos without requiring additional training of existing models.

## Setup

### Installation

Set up the environment:

```bash
conda create -n video_infinity_vc2 python=3.10
conda activate video_infinity_vc2
pip install -r requirements.txt
```

## Usage

### Quick Start

- **Basic Usage**

  ```bash
  python inference.py --config examples/config.json
  ```

- **Multi-Prompts**

  ```bash
  python inference.py --config examples/multi_prompts.json
  ```

- **Single GPU**

  ```bash
  python inference.py --config examples/single_gpu.json
  ```

### Config

#### Basic Config

| Parameter   | Description                            |
| ----------- | -------------------------------------- |
| `devices`   | The list of GPU devices to use.        |
| `base_path` | The path to save the generated videos. |

#### Pipeline Config

| Parameter    | Description                                                                                      |
| ------------ | ------------------------------------------------------------------------------------------------ |
| `prompts`    | The list of text prompts. **Note**: the number of prompts should not exceed the number of GPUs.   |
| `file_name`  | The name of the generated video.                                                                   |
| `num_frames` | The number of frames to generate on **each GPU**. The exported video concatenates one chunk per device; see the short sketch at the end of this README. |

#### Video-Infinity Config

| Parameter              | Description                                                                                                                                              |
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `*.padding`            | The number of local context frames.                                                                                                                        |
| `attn.top_k`           | The number of global context frames for the `Attention` modules.                                                                                            |
| `attn.local_phase`     | When the denoising timestep is less than `t`, the attention is biased: `local_biase` is added to the local context frames and `global_biase` to the global context frames. |
| `attn.global_phase`    | Similar to `local_phase`, but the bias is applied when the denoising timestep is greater than `t`.                                                          |
| `attn.token_num_scale` | If `True`, the attention scale factor is rescaled by the number of tokens. Default is `False`. See this [paper](https://arxiv.org/abs/2306.08645) for details. |

#### How to Set Config

- To avoid losing high-frequency information, we recommend keeping the sum of `padding` and `attn.top_k` below 24 (similar to the default number of frames in the `VideoCrafter2` model).
- If you want a larger `padding` or `attn.top_k`, set `attn.token_num_scale` to `True`.
- Higher `local_phase.t` and `global_phase.t` values produce more stable videos but may reduce their diversity.
- More `padding` provides more local context.
- A higher `attn.top_k` improves the overall stability of the videos.

## Citation

```
@article{tan2024videoinf,
  title={Video-Infinity: Distributed Long Video Generation},
  author={Zhenxiong Tan and Xingyi Yang and Songhua Liu and Xinchao Wang},
  journal={arXiv preprint arXiv:2406.16260},
  year={2024}
}
```

## Acknowledgements

Our project is based on the [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2) model. We would like to thank the authors for their excellent work!
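For reference, here is a small sketch (not part of the repository) showing how the per-GPU `num_frames` setting translates into the length of the exported clip. It assumes the script is run from the repository root and that, as in `src/video_infinity/wrapper.py`, the final video simply concatenates one chunk of `num_frames` frames per device.

```python
# Hypothetical helper (not repository code): estimate the exported video length
# from a Video-Infinity config. `num_frames` is the per-GPU frame count, and
# the final clip concatenates one chunk per device (see wrapper.py).
import json

with open("examples/config.json") as f:
    config = json.load(f)

num_gpus = len(config["devices"])                      # e.g. 4
frames_per_gpu = config["pipe_configs"]["num_frames"]  # e.g. 24
export_fps = config["pipe_configs"]["export_fps"]      # e.g. 8

total_frames = num_gpus * frames_per_gpu               # 4 * 24 = 96 frames
duration_s = total_frames / export_fps                 # 96 / 8 = 12 seconds
print(f"{total_frames} frames ≈ {duration_s:.1f} s at {export_fps} fps")
```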
❤️ ================================================ FILE: examples/config.json ================================================ { "dtype": "torch.float16", "devices": [0,1,2,3], "seed": 123, "master_port": 29516, "base_path": "./exp", "pipe_configs":{ "prompts": [ "A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background." ], "steps": 30, "guidance_scale": 12, "fps": 24, "num_frames": 24, "height": 320, "width": 512, "export_fps": 8, "file_name": null }, "plugin_configs":{ "attn":{ "padding": 8, "top_k": 16, "top_k_chunk_size": 24, "attn_scale": 1.0, "token_num_scale": false, "dynamic_scale": true, "local_phase": { "t": 800, "local_biase": 10, "global_biase": 0 }, "global_phase": { "t": 800, "local_biase": 0, "global_biase": 10 } }, "conv_3d": { "padding": 1 }, "conv_layer": {} } } ================================================ FILE: examples/multi_promts.json ================================================ { "dtype": "torch.float16", "devices": [ 0, 1, 2, 3, 4, 5, 6, 7 ], "seed": 123, "master_port": 29516, "base_path": "./exp", "pipe_configs": { "prompts": [ "[Ukiyo-e style] A black Akita puppy stands alone under the eaves of a traditional Japanese house, shivering in the rain and looking scared.", "[Ukiyo-e style] A girl in a blue kimono comforts a shivering black Akita puppy during a rainy walk through a stone-paved village.", "[Ukiyo-e style] A girl in a blue kimono brings an black Akita puppy into her warm home, a traditional wooden Japanese house with sliding doors.", "[Ukiyo-e style] A girl in a blue kimono plays with her energetic black Akita puppy in the garden of a traditional Japanese house, throwing a woven ball.", "[Ukiyo-e style] A teenager in a blue kimono jogs with her black Akita through a park filled with cherry blossom trees and ancient stone lanterns.", "[Ukiyo-e style] A teenage girl in a blue kimono relaxes in a field of wildflowers, reading a scroll while her black Akita rests beside him under a cherry tree.", "[Ukiyo-e style] A young girl in a blue kimono celebrates her coming-of-age ceremony with her black Akita, surrounded by festive lanterns and banners.", "[Ukiyo-e style] A girl in a blue kimono and her loyal black Akita enjoy a serene sunset walk along the beach, with the silhouette of Mount Fuji in the distance." ], "steps": 30, "guidance_scale": 12, "fps": 24, "num_frames": 24, "height": 320, "width": 512, "export_fps": 8, "file_name": null }, "plugin_configs": { "attn": { "padding": 8, "top_k": 16, "top_k_chunk_size": 24, "attn_scale": 1.0, "token_num_scale": false, "dynamic_scale": true, "local_phase": { "t": 850, "local_biase": 10, "global_biase": 0 }, "global_phase": { "t": 850, "local_biase": 0, "global_biase": 10 } }, "conv_3d": { "padding": 1 }, "conv_layer": {} } } ================================================ FILE: examples/single_gpu.json ================================================ { "dtype": "torch.float16", "devices": [0], "seed": 123, "master_port": 29516, "base_path": "./exp", "pipe_configs":{ "prompts": [ "A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background." 
], "steps": 30, "guidance_scale": 12, "fps": 24, "num_frames": 24, "height": 320, "width": 512, "export_fps": 8, "file_name": null }, "plugin_configs":{ "attn":{ "padding": 8, "top_k": 16, "top_k_chunk_size": 24, "attn_scale": 1.0, "token_num_scale": false, "dynamic_scale": true, "local_phase": { "t": 800, "local_biase": 10, "global_biase": 0 }, "global_phase": { "t": 800, "local_biase": 0, "global_biase": 10 } }, "conv_3d": { "padding": 1 }, "conv_layer": {} } } ================================================ FILE: inference.py ================================================ import torch import torch import torch.distributed as dist import torch.multiprocessing as mp import time import json import os from src.video_crafter import VideoCrafterPipeline, UNetVideoCrafter from diffusers.schedulers import DPMSolverMultistepScheduler from src.tools import DistController from src.video_infinity.wrapper import DistWrapper def parse_args(): import argparse parser = argparse.ArgumentParser(description="Video Infinity Inference") parser.add_argument("--config", type=str) args = parser.parse_args() return args def init_pipeline(config): pipe = VideoCrafterPipeline.from_pretrained( 'adamdad/videocrafterv2_diffusers', torch_dtype=torch.float16 ) pipe.enable_model_cpu_offload( gpu_id=config["devices"][dist.get_rank() % len(config["devices"])], ) pipe.enable_vae_slicing() return pipe def run_inference(rank, world_size, config): dist_controller = DistController(rank, world_size, config) pipe = init_pipeline(config) dist_pipe = DistWrapper(pipe, dist_controller, config) start = time.time() pipe_configs=config['pipe_configs'] plugin_configs=config['plugin_configs'] prompt_id = int(rank / world_size * len(pipe_configs["prompts"])) prompt = pipe_configs["prompts"][prompt_id] start = time.time() dist_pipe.inference( prompt, config, pipe_configs, plugin_configs, additional_info={ "full_config": config, } ) print(f"Rank {rank} finished. Time: {time.time() - start}") def main(config): size = len(config["devices"]) processes = [] if not os.path.exists(config["base_path"]): os.makedirs(config["base_path"]) for rank, _ in enumerate(config["devices"]): p = mp.Process(target=run_inference, args=(rank, size, config)) p.start() processes.append(p) for p in processes: p.join() if __name__ == "__main__": mp.set_start_method("spawn") with open(parse_args().config, "r") as f: config = json.load(f) main(config) ================================================ FILE: requirements.txt ================================================ torch diffusers transformers imageio accelerate ffmpeg pyav imageio-ffmpeg ================================================ FILE: src/tools.py ================================================ import json import numpy as np import imageio import os import torch import torch.distributed as dist def export_to_video(video_frames, output_video_path, fps = 12): # Ensure all frames are NumPy arrays and determine video dimensions from the first frame assert all(isinstance(frame, np.ndarray) for frame in video_frames), "All video frames must be NumPy arrays." 
# Ensure output_video_path is ending with .mp4 if not output_video_path.endswith('.mp4'): output_video_path += '.mp4' # Create a video file at the specified path and write frames to it with imageio.get_writer(output_video_path, fps=fps, format='mp4') as writer: for frame in video_frames: writer.append_data( (frame * 255).astype(np.uint8) ) def save_generation(video_frames, configs, base_path, file_name=None): if not os.path.exists(base_path): os.makedirs(base_path) p_config = configs["pipe_configs"] frames, steps, fps = p_config["num_frames"], p_config["steps"], p_config["fps"] if not file_name: index = [int(each.split('_')[0]) for each in os.listdir(base_path)] max_idex = max(index) if index else 0 idx_str = str(max_idex + 1).zfill(6) key_info = '_'.join([str(frames), str(steps), str(fps)]) file_name = f'{idx_str}_{key_info}' with open(f'{base_path}/{file_name}.json', 'w') as f: json.dump(configs, f, indent=4) export_to_video(video_frames, f'{base_path}/{file_name}.mp4', fps=p_config["export_fps"]) return file_name class GlobalState: def __init__(self, state={}) -> None: self.init_state(state) def init_state(self, state={}): self.state = state def set(self, key, value): self.state[key] = value def get(self, key, default=None): return self.state.get(key, default) class DistController(object): def __init__(self, rank, world_size, config) -> None: super().__init__() self.rank = rank self.world_size = world_size self.config = config self.is_master = rank == 0 self.init_dist() self.init_group() self.device = torch.device(f"cuda:{config['devices'][dist.get_rank()]}") torch.cuda.set_device(self.device) def init_dist(self): print(f"Rank {self.rank} is running.") os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = str(self.config.get("master_port") or "29500") dist.init_process_group("nccl", rank=self.rank, world_size=self.world_size) def init_group(self): self.adj_groups = [dist.new_group([i, i+1]) for i in range(self.world_size-1)] ================================================ FILE: src/video_crafter.py ================================================ import torch from diffusers.models import AutoencoderKL, UNet3DConditionModel from transformers import CLIPTextModel, CLIPTokenizer from diffusers.schedulers import DPMSolverMultistepScheduler from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import TextToVideoSDPipeline from diffusers.configuration_utils import register_to_config from diffusers.models.embeddings import TimestepEmbedding, Timesteps class VideoCrafterPipeline(TextToVideoSDPipeline): @register_to_config def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, scheduler: KarrasDiffusionSchedulers, fps_cond: bool = True, ): self.fps_cond = fps_cond super().__init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) @torch.no_grad() def __call__( self, *args, **kwargs, ): fixed_fps = kwargs.pop("fps", 24) def post_function(sample): fps = fixed_fps unet = self.unet if self.fps_cond: fps = torch.tensor([fps], dtype=torch.float64 , device=sample.device) fps_emb = unet.fps_proj(fps) fps_emb = fps_emb.to(sample.device, dtype=unet.dtype) fps_emb = unet.fps_embedding(fps_emb).repeat_interleave(repeats=sample.shape[0], dim=0) sample += fps_emb return sample self.unet.time_embedding.post_act = post_function # kwargs.pop("fps", None) return super().__call__(*args, **kwargs) @classmethod 
def from_pretrained( cls, pretrained_model_name_or_path: str, **kwargs, ): pipe = TextToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_576w", **kwargs) pipe.__class__ = cls pipe.fps_cond = True pipe.unet = UNetVideoCrafter.from_pretrained(pretrained_model_name_or_path, **kwargs) pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras=True, algorithm_type="sde-dpmsolver++") return pipe class UNetVideoCrafter(UNet3DConditionModel): @register_to_config def __init__( self, sample_size, in_channels, out_channels, down_block_types, up_block_types, block_out_channels, layers_per_block, downsample_padding, mid_block_scale_factor, act_fn, norm_num_groups, norm_eps, cross_attention_dim, attention_head_dim, num_attention_heads, fps_cond: bool = True, **kwargs ): self.fps_cond = fps_cond super().__init__( sample_size=sample_size, in_channels=in_channels, out_channels=out_channels, down_block_types=down_block_types, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, downsample_padding=downsample_padding, mid_block_scale_factor=mid_block_scale_factor, act_fn=act_fn, norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, num_attention_heads=num_attention_heads, **kwargs ) if self.fps_cond: self.fps_proj = Timesteps(block_out_channels[0], True, 0) self.fps_embedding = TimestepEmbedding( block_out_channels[0], block_out_channels[0] * 4, act_fn=act_fn, ) ================================================ FILE: src/video_infinity/__init__.py ================================================ ================================================ FILE: src/video_infinity/plugins.py ================================================ import torch import torch.distributed as dist import math def my_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, token_num_scale=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) base_scale_factor = 1 / math.sqrt(query.size(-1)) * (scale if scale is not None else 1.) attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.dtype).to(query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype).to(query.device) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask.to(query.dtype).to(query.device) no_mask_count = torch.where(attn_bias < -100, 0, 1).sum(1) biased_scale_factor = torch.log(no_mask_count) / torch.log(torch.tensor(16)) if token_num_scale else 1. 
scale_factor = (base_scale_factor * biased_scale_factor).unsqueeze(-1) if token_num_scale else base_scale_factor attn_weight = query @ key.transpose(-2, -1) attn_weight *= scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value class ModulePlugin: def __init__(self, module, module_id, global_state=None): self.module = module self.module_id = module_id self.global_state = global_state self.enable = True self.implement_forward() @property def is_log_node(self): return self.global_state.get('dist_controller').rank == 0 and self.module_id[1] == 0 @property def t(self): return self.global_state.get('timestep') @property def p(self): return self.t / 1000 def implement_forward(self): module = self.module if not hasattr(module, "old_forward"): module.old_forward = module.forward self.new_forward = self.get_new_forward() def forward(*args, **kwargs): self.update_config() # update config return self.new_forward(*args, **kwargs) if self.enable else self.old_forward(*args, **kwargs) module.forward = forward def set_enable(self, enable=True): self.enable = enable def get_new_forward(self): raise NotImplementedError def update_config(self, config:dict=None): if config is None: config = self.global_state.get('plugin_configs', {}).get(self.module_id[0], {}) for key, value in config.items(): setattr(self, key, value) class GroupNormPlugin(ModulePlugin): def __init__(self, module, module_id, global_state=None): super().__init__(module, module_id, global_state) def get_new_forward(self): module = self.module def new_forward(x): shape = x.shape N, C, G = shape[0], shape[1], module.num_groups assert C % G == 0 x = x.reshape(N, G, -1) mean = x.mean(-1, keepdim=True) dist.all_reduce(mean) mean = mean / dist.get_world_size() var = ((x - mean) ** 2).mean(-1, keepdim=True) dist.all_reduce(var) var = var / dist.get_world_size() x = (x - mean) / (var + module.eps).sqrt() x = x.view(shape) new_shape = [1 for _ in shape] new_shape[1] = -1 return x * module.weight.view(new_shape) + module.bias.view(new_shape) return new_forward class ConvLayerPlugin(ModulePlugin): def __init__(self, module, module_id, global_state=None): super().__init__(module, module_id, global_state) self.padding = 4 self.rank = dist.get_rank() self.adj_groups = self.global_state.get('dist_controller').adj_groups def pad_context(self, h, padding=None): padding = self.padding if padding is None else padding share_to_left = h[:, :, :padding].contiguous() share_to_right = h[:, :, -padding:].contiguous() if self.rank % 2: # 1. the rank is odd, pad the left first if self.rank: # not the first rank, have left context padding_list = [torch.zeros_like(share_to_left) for _ in range(2)] dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1]) left_context = padding_list[0].to(h.device, non_blocking=True) else: left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True) # 2. then pad the right if self.rank != dist.get_world_size() - 1: # not the last rank, have right context padding_list = [torch.zeros_like(share_to_right) for _ in range(2)] dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank]) right_context = padding_list[1].to(h.device, non_blocking=True) else: right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True) else: # 1. 
the rank is even, pad the right first if self.rank != dist.get_world_size() - 1: # not the last rank, have right context padding_list = [torch.zeros_like(share_to_right) for _ in range(2)] dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank]) right_context = padding_list[1].to(h.device, non_blocking=True) else: right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True) # 2. then pad the left if self.rank: # not the first rank, have left context padding_list = [torch.zeros_like(share_to_left) for _ in range(2)] dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1]) left_context = padding_list[0].to(h.device, non_blocking=True) else: left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True) torch.cuda.synchronize() h_with_context = torch.cat([left_context, h, right_context], dim=2) return h_with_context def get_new_forward(self): module = self.module def new_forward(hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor: hidden_states = ( hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) ) identity = hidden_states hidden_states = self.pad_context(hidden_states) hidden_states = module.conv1(hidden_states) hidden_states = module.conv2(hidden_states) hidden_states = module.conv3(hidden_states) hidden_states = module.conv4(hidden_states) hidden_states = hidden_states[:, :, self.padding:-self.padding] hidden_states = identity + hidden_states hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] ) return hidden_states return new_forward class AttentionPlugin(ModulePlugin): def __init__(self, module, module_id, global_state=None): super().__init__(module, module_id, global_state) self.padding = 24 self.top_k = 16 self.top_k_chunk_size = 24 self.attn_scale = 1. self.token_num_scale = False self.rank = dist.get_rank() self.adj_groups = self.global_state.get('dist_controller').adj_groups self.world_size = self.global_state.get('dist_controller').world_size self.dynamic_scale = False def pad_context(self, h, padding=None): padding = self.padding if padding is None else padding share_to_left = h[:, :padding].contiguous() share_to_right = h[:, -padding:].contiguous() if self.rank % 2: # 1. the rank is odd, pad the left first if self.rank: # not the first rank, have left context padding_list = [torch.zeros_like(share_to_left) for _ in range(2)] dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1]) left_context = padding_list[0].to(h.device, non_blocking=True) else: left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True) # 2. then pad the right if self.rank != dist.get_world_size() - 1: # not the last rank, have right context padding_list = [torch.zeros_like(share_to_right) for _ in range(2)] dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank]) right_context = padding_list[1].to(h.device, non_blocking=True) else: right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True) else: # 1. 
the rank is even, pad the right first if self.rank != dist.get_world_size() - 1: # not the last rank, have right context padding_list = [torch.zeros_like(share_to_right) for _ in range(2)] dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank]) right_context = padding_list[1].to(h.device, non_blocking=True) else: right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True) # 2. then pad the left if self.rank: # not the first rank, have left context padding_list = [torch.zeros_like(share_to_left) for _ in range(2)] dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1]) left_context = padding_list[0].to(h.device, non_blocking=True) else: left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True) torch.cuda.synchronize() h_with_context = torch.cat([left_context, h, right_context], dim=1) return h_with_context, padding def get_topk(self, q, k, v, top_k=None): # h = [N, F, C] top_k = self.top_k if top_k is None else top_k share_num = int(max(top_k // self.world_size, 0)) stride = max(q.shape[1] // share_num, 1) if share_num else 1000000 topk_indices = torch.arange(0, q.shape[1], stride, device=q.device) k_to_share, v_to_share = k[:, topk_indices], v[:, topk_indices] gather_k = [torch.zeros_like(k_to_share) for _ in range(self.world_size)] gather_v = [torch.zeros_like(v_to_share) for _ in range(self.world_size)] dist.all_gather(gather_k, k_to_share) dist.all_gather(gather_v, v_to_share) gather_k = torch.cat(gather_k, dim=1)[:, :top_k] gather_v = torch.cat(gather_v, dim=1)[:, :top_k] return gather_k, gather_v def gather_context(self, h): self.temporal_n = h.shape[1] stack_list = [torch.zeros_like(h) for _ in range(self.world_size)] dist.all_gather(stack_list, h) return torch.cat(stack_list, dim=1) def get_new_forward(self): module = self.module def new_forward(x, encoder_hidden_states=None, attention_mask=None): context=encoder_hidden_states temporal_n = x.shape[1] q = module.to_q(x) context = x if context is None else context k, v = module.to_k(context), module.to_v(context) b, _, _ = q.shape q, k, v = map( lambda t: t.unsqueeze(3).reshape(b, t.shape[1], module.heads, -1).permute(0, 2, 1, 3).reshape(b*module.heads, t.shape[1], -1), (q, k, v), ) global_k, global_v = self.get_topk(q, k, v) num_global = global_k.shape[1] padded_k, _ = self.pad_context(k) padded_v, padding = self.pad_context(v) padded_k = torch.cat([padded_k, global_k], dim=1) padded_v = torch.cat([padded_v, global_v], dim=1) # if self.is_log_node: # print("Total KV num:", padding*2 + global_k.shape[1], "Global KV num:", num_global, "Padding:", padding) attn_mask = torch.ones(temporal_n, temporal_n + 2*padding + num_global, dtype=q.dtype).to(q.device) for i in range(temporal_n): attn_mask[i, 0: max(0, i)] = float('-inf') attn_mask[i, min(temporal_n+2*padding, i+1+2*padding): temporal_n+2*padding] = float('-inf') if self.dynamic_scale and self.local_phase is not None and self.global_phase is not None: if self.t < self.local_phase['t']: attn_mask[:, temporal_n+2*padding:] += self.local_phase['global_biase'] attn_mask[:, :temporal_n+2*padding] += self.local_phase['local_biase'] if self.t >= self.global_phase['t']: attn_mask[:, temporal_n+2*padding:] += self.global_phase['global_biase'] attn_mask[:, :temporal_n+2*padding] += self.global_phase['local_biase'] out = my_attention( q, padded_k, padded_v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.attn_scale, token_num_scale=self.token_num_scale ) out = ( out.unsqueeze(0).reshape(b, 
module.heads, out.shape[1], -1).permute(0, 2, 1, 3) .reshape(b, out.shape[1], -1) ) # linear proj hidden_states = module.to_out[0](out) hidden_states = module.to_out[1](hidden_states) return hidden_states return new_forward class Conv3DPligin(ModulePlugin): def __init__(self, module, module_id, global_state=None): super().__init__(module, module_id, global_state) self.padding = 1 self.rank = dist.get_rank() self.adj_groups = self.global_state.get('dist_controller').adj_groups def pad_context(self, h): padding = self.padding share_to_left = h[:, :, :padding].contiguous() share_to_right = h[:, :, -padding:].contiguous() if self.rank % 2: # 1. the rank is odd, pad the left first if self.rank: # not the first rank, have left context padding_list = [torch.zeros_like(share_to_left) for _ in range(2)] dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1]) left_context = padding_list[0].to(h.device, non_blocking=True) else: left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True) # 2. then pad the right if self.rank != dist.get_world_size() - 1: # not the last rank, have right context padding_list = [torch.zeros_like(share_to_right) for _ in range(2)] dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank]) right_context = padding_list[1].to(h.device, non_blocking=True) else: right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True) else: # 1. the rank is even, pad the right first if self.rank != dist.get_world_size() - 1: # not the last rank, have right context padding_list = [torch.zeros_like(share_to_right) for _ in range(2)] dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank]) right_context = padding_list[1].to(h.device, non_blocking=True) else: right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True) # 2. 
then pad the left if self.rank: # not the first rank, have left context padding_list = [torch.zeros_like(share_to_left) for _ in range(2)] dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1]) left_context = padding_list[0].to(h.device, non_blocking=True) else: left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True) torch.cuda.synchronize() h_with_context = torch.cat([left_context, h, right_context], dim=2) return h_with_context def get_new_forward(self): module = self.module def new_forward(hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.pad_context(hidden_states) hidden_states = module.old_forward(hidden_states)[:,:,self.padding:-self.padding] return hidden_states return new_forward class UNetPlugin(ModulePlugin): def __init__(self, module, module_id, global_state=None): super().__init__(module, module_id, global_state) def get_new_forward(self): module = self.module def new_forward(*args, **kwargs): self.global_state.set('timestep', args[1].item()) return module.old_forward(*args, **kwargs) return new_forward ================================================ FILE: src/video_infinity/wrapper.py ================================================ from ..tools import save_generation, GlobalState, DistController from .plugins import torch, ModulePlugin, UNetPlugin, GroupNormPlugin, ConvLayerPlugin, AttentionPlugin, Conv3DPligin, dist class DistWrapper(object): def __init__(self, pipe, dist_controller: DistController, config) -> None: super().__init__() self.pipe = pipe self.dist_controller = dist_controller self.config = config self.global_state = GlobalState({ "dist_controller": dist_controller }) self.plugin_mount() def switch_plugin(self, plugin_name, enable): if plugin_name not in self.plugins: return for moudule_id in self.plugins[plugin_name]: moudle: ModulePlugin = self.plugins[plugin_name][moudule_id] moudle.set_enable(enable) def config_plugin(self, plugin_name, config): if plugin_name not in self.plugins: return for moudule_id in self.plugins[plugin_name]: moudle: ModulePlugin = self.plugins[plugin_name][moudule_id] moudle.update_config(config) def plugin_mount(self): self.plugins = {} self.unet_plugin_mount() self.attn_plugin_mount() # self.group_norm_plugin_mount() # self.conv_3d_plugin_mount() # Conv3d and Conv layer can only be used one at a time self.conv_plugin_mount() def group_norm_plugin_mount(self): self.plugins['group_norm'] = {} group_norms = [] for module in self.pipe.unet.named_modules(): if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'GroupNorm': group_norms.append(module[1]) if self.dist_controller.is_master: print(f'Found {len(group_norms)} group norms') for i, group_norm in enumerate(group_norms): plugin_id = 'group_norm', i self.plugins['group_norm'][plugin_id] = GroupNormPlugin(group_norm, plugin_id, self.global_state) def conv_plugin_mount(self): self.plugins['conv_layer'] = {} convs = [] for module in self.pipe.unet.named_modules(): if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'TemporalConvLayer': convs.append(module[1]) if self.dist_controller.is_master: print(f'Found {len(convs)} convs') for i, conv in enumerate(convs): plugin_id = 'conv_layer', i self.plugins['conv_layer'][plugin_id] = ConvLayerPlugin(conv, plugin_id, self.global_state) def conv_3d_plugin_mount(self): self.plugins['conv_3d'] = {} conv3d_s = [] for module in self.pipe.unet.named_modules(): if ('temp_' in module[0] or 
'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Conv3d': conv3d_s.append(module[1]) if self.dist_controller.is_master: print(f'Found {len(conv3d_s)} conv3d_s') for i, conv in enumerate(conv3d_s): plugin_id = 'conv_3d', i self.plugins['conv_3d'][plugin_id] = Conv3DPligin(conv, plugin_id, self.global_state) def attn_plugin_mount(self): self.plugins['attn'] = {} attns = [] for module in self.pipe.unet.named_modules(): if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Attention': attns.append(module[1]) if self.dist_controller.is_master: print(f'Found {len(attns)} attns') for i, attn in enumerate(attns): plugin_id = 'attn', i self.plugins['attn'][plugin_id] = AttentionPlugin(attn, plugin_id, self.global_state) def unet_plugin_mount(self): self.plugins['unet'] = UNetPlugin( self.pipe.unet, ('unet', 0), self.global_state ) def inference( self, prompts="A beagle wearning diving goggles swimming in the ocean while the camera is moving, coral reefs in the background", config={}, pipe_configs={ "steps": 50, "guidance_scale": 12, "fps": 60, "num_frames": 24 * 1, "height": 320, "width": 512, "export_fps": 12, "base_path": "./work/output", "file_name": None }, plugin_configs={ "attn":{ "padding": 24, "top_k": 24, "top_k_chunk_size": 24, "attn_scale": 1., "token_num_scale": True, "dynamic_scale": True, }, "conv_3d": { "padding": 1, }, "conv_layer": {}, }, additional_info={}, ): self.plugin_mount() generator = torch.Generator("cuda").manual_seed(self.config["seed"] + self.dist_controller.rank) # generator = torch.Generator("cuda").manual_seed(self.config["seed"]) self.global_state.set("plugin_configs", plugin_configs) video_frames = self.pipe( prompts, num_inference_steps=pipe_configs["steps"], guidance_scale=pipe_configs["guidance_scale"], height=pipe_configs['height'], width=pipe_configs['width'], num_frames=pipe_configs['num_frames'], fps=pipe_configs['fps'], generator=generator ).frames[0] video_frames = torch.tensor(video_frames, dtype=torch.float16, device=self.dist_controller.device) print(f"Rank {self.dist_controller.rank} finished inference. Result: {video_frames.shape}") all_frames = [ torch.zeros_like(video_frames, dtype=torch.float16) for _ in range(self.dist_controller.world_size) ] if self.dist_controller.is_master else None dist.gather(video_frames, all_frames, dst=0) if self.dist_controller.is_master: all_frames = torch.cat(all_frames, dim=0).cpu().numpy() save_generation( all_frames, { "prompt": prompts, "pipe_configs": pipe_configs, "plugin_configs": plugin_configs, "additional_info": additional_info }, config["base_path"], pipe_configs["file_name"] )
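The temporal-attention plugin in `src/video_infinity/plugins.py` above gives each GPU two kinds of extra context: `padding` boundary frames exchanged with its immediate neighbours, and `top_k` frames sampled at a fixed stride from every GPU's clip. The single-process sketch below is purely illustrative (it is not repository code, works on frame indices instead of tensors, and ignores the NCCL exchange, masking, and phase biases); it shows which frames end up in one rank's key/value set.

```python
# Illustrative sketch: which frames one rank's temporal attention can see,
# assuming clips are laid out back to back across GPUs.

def context_for_rank(rank, world_size, frames_per_gpu, padding, top_k):
    # Global frame indices owned by each rank.
    clips = [list(range(r * frames_per_gpu, (r + 1) * frames_per_gpu))
             for r in range(world_size)]

    # Local context: the last `padding` frames of the left neighbour and the
    # first `padding` frames of the right neighbour (the real implementation
    # uses zero tensors at the two ends of the video).
    left = clips[rank - 1][-padding:] if rank > 0 else []
    right = clips[rank + 1][:padding] if rank < world_size - 1 else []
    local = left + clips[rank] + right

    # Global context: each rank contributes roughly top_k / world_size frames
    # sampled at a fixed stride from its own clip; the contributions are
    # gathered from all ranks and truncated to `top_k` frames.
    share = max(top_k // world_size, 1)
    stride = max(frames_per_gpu // share, 1)
    global_ctx = [f for clip in clips for f in clip[::stride]][:top_k]
    return local, global_ctx

local, global_ctx = context_for_rank(rank=1, world_size=4,
                                     frames_per_gpu=24, padding=8, top_k=16)
print(len(local), local[:3], "...")   # 8 + 24 + 8 = 40 local context frames
print(global_ctx)                     # 16 frames spread over the whole video
```

Under the default `examples/config.json` values (4 devices, 24 frames per GPU, `padding` 8, `top_k` 16), rank 1 therefore works with a 40-frame local window plus 16 frames spread evenly over the whole 96-frame video.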
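`my_attention` in `plugins.py` optionally rescales the attention logits when `token_num_scale` is enabled: the usual `1/sqrt(d)` factor is multiplied by `log(n_visible) / log(16)`, where `n_visible` is the number of key frames a query can actually see (mask entries that are not `-inf`). A minimal, hypothetical rendition of that factor, detached from the attention machinery:

```python
# Hypothetical sketch (not repository code) of the `token_num_scale` rescaling
# used by `my_attention` in src/video_infinity/plugins.py.
import math

def attention_scale(head_dim, n_visible, attn_scale=1.0, token_num_scale=True):
    base = attn_scale / math.sqrt(head_dim)              # the usual 1/sqrt(d) scale
    if not token_num_scale:
        return base
    return base * (math.log(n_visible) / math.log(16))   # grows with context size

# With padding=8 each query sees a 2*8+1 = 17-frame local window; adding
# top_k=16 global frames gives 33 visible keys.
print(attention_scale(head_dim=64, n_visible=17, token_num_scale=False))
print(attention_scale(head_dim=64, n_visible=17))
print(attention_scale(head_dim=64, n_visible=33))
```

The factor grows with the number of visible keys, compensating for the softmax being spread over more tokens; this is why the README recommends enabling `token_num_scale` when `padding` or `top_k` is increased.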
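Finally, the optionally mounted `GroupNormPlugin` in the same file synchronises normalisation statistics across GPUs by all-reducing the per-clip mean and variance. The check below is a hypothetical single-process stand-in (not repository code): for equally sized clips, averaging the per-clip statistics is equivalent to normalising the whole concatenated video at once.

```python
# Hypothetical check (not repository code) of the distributed GroupNorm idea:
# averaging per-clip means/variances matches the statistics of the full video.
import torch

torch.manual_seed(0)
world_size, frames = 4, 24
clips = [torch.randn(frames, 8) for _ in range(world_size)]   # per-GPU activations

# Simulated all_reduce: average the per-clip means, then the per-clip
# variances computed around that shared mean.
mean = sum(c.mean() for c in clips) / world_size
var = sum(((c - mean) ** 2).mean() for c in clips) / world_size

full = torch.cat(clips)                                        # the whole video
print(torch.allclose(mean, full.mean()),
      torch.allclose(var, full.var(unbiased=False)))           # -> True True
```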