[
  {
    "path": "README.md",
    "content": "\n<div align=\"center\">\n\n# Video-Infinity\n\n<img src='./assets/VideoGen-Main.png' width='80%' />\n<br>\n<a href=\"https://arxiv.org/abs/2406.16260\"><img src=\"https://img.shields.io/badge/ariXv-2406.16260-A42C25.svg\" alt=\"arXiv\"></a>\n<a  href=\"https://video-infinity.tanzhenxiong.com\"><img src=\"https://img.shields.io/badge/ProjectPage-Video Infinity-376ED2#376ED2.svg\" alt=\"arXiv\"></a>\n</div>\n\n\n> **Video-Infinity: Distributed Long Video Generation**\n> <br>\n> Zhenxiong Tan, \n> [Xingyi Yang](https://adamdad.github.io/), \n> [Songhua Liu](http://121.37.94.87/), \n> and \n> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)\n> <br>\n> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore\n> <br>\n\n\n## TL;DR (Too Long; Didn't Read)\nVideo-Infinity generates long videos quickly using multiple GPUs without extra training. Feel free to visit our \n[project page](https://video-infinity.tanzhenxiong.com)\n for more information and generated videos.\n\n\n## Features\n* **Distributed 🌐**: Utilizes multiple GPUs to generate long-form videos.\n* **High-Speed  🚀**: Produces 2,300 frames in just 5 minutes.\n* **Training-Free 🎓**: Generates long videos without requiring additional training for existing models.\n\n## Setup\n### Installation Environment\n```bash\nconda create -n video_infinity_vc2 python=3.10\nconda activate video_infinity_vc2\npip install -r requirements.txt\n```\n<!-- ### Download Pretrained Models\nWe provide a diffusers pipeline for [VideoCrafter2](TODO) to generate long videos.\n```bash\nhuggingface-cli download adamdad/videocrafterv2_diffusers\n``` -->\n\n## Usage\n### Quick Start\n- **Basic Usage**\n```bash\npython inference.py --config examples/config.json\n```\n- **Multi-Prompts**\n```bash\npython inference.py --config examples/multi_prompts.json\n```\n- **Single GPU**\n```bash\npython inference.py --config examples/single_gpu.json\n```\n\n### Config\n#### Basic Config\n| Parameter   | Description                            |\n| ----------- | -------------------------------------- |\n| `devices`   | The list of GPU devices to use.        |\n| `base_path` | The path to save the generated videos. |\n\n#### Pipeline Config\n| Parameter    | Description                                                                                          |\n| ------------ | ---------------------------------------------------------------------------------------------------- |\n| `prompts`    | The list of text prompts. **Note**: The number of prompts should be greater than the number of GPUs. |\n| `file_name`  | The name of the generated video.                                                                     |\n| `num_frames` | The number of frames to generate on **each GPU**.                                                    |\n\n#### Video-Infinity Config\n| Parameter              | Description                                                                                                                                                               |\n| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| `*.padding`            | The number of local context frames.                                                                                                                                       
\n## Citation\n```\n@article{tan2024videoinf,\n  title={Video-Infinity: Distributed Long Video Generation},\n  author={Tan, Zhenxiong and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},\n  journal={arXiv preprint arXiv:2406.16260},\n  year={2024}\n}\n```\n\n## Acknowledgements\nOur project is based on the [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2) model. We would like to thank the authors for their excellent work! ❤️\n"
  },
  {
    "path": "examples/config.json",
    "content": "{\n    \"dtype\": \"torch.float16\",\n    \"devices\": [0,1,2,3],\n    \"seed\": 123,\n    \"master_port\": 29516,\n    \"base_path\": \"./exp\",\n    \"pipe_configs\":{\n        \"prompts\": [\n            \"A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background.\"\n        ],\n        \"steps\": 30,\n        \"guidance_scale\": 12,\n        \"fps\": 24,\n        \"num_frames\": 24,\n        \"height\": 320,\n        \"width\": 512,  \n        \"export_fps\": 8,\n        \"file_name\": null\n    },\n    \"plugin_configs\":{\n        \"attn\":{\n            \"padding\": 8,\n            \"top_k\": 16,\n            \"top_k_chunk_size\": 24,\n            \"attn_scale\": 1.0,\n            \"token_num_scale\": false,\n            \"dynamic_scale\": true,\n            \"local_phase\": {\n                \"t\": 800, \n                \"local_biase\": 10,\n                \"global_biase\": 0\n            },\n            \"global_phase\": {\n                \"t\": 800,\n                \"local_biase\": 0,\n                \"global_biase\": 10\n            }\n        },\n        \"conv_3d\": {\n            \"padding\": 1\n        }, \n        \"conv_layer\": {}\n    }\n}"
  },
  {
    "path": "examples/multi_promts.json",
    "content": "{\n    \"dtype\": \"torch.float16\",\n    \"devices\": [\n        0,\n        1,\n        2,\n        3,\n        4,\n        5,\n        6,\n        7\n    ],\n    \"seed\": 123,\n    \"master_port\": 29516,\n    \"base_path\": \"./exp\",\n    \"pipe_configs\": {\n        \"prompts\": [\n            \"[Ukiyo-e style] A black Akita puppy stands alone under the eaves of a traditional Japanese house, shivering in the rain and looking scared.\",\n            \"[Ukiyo-e style] A girl in a blue kimono comforts a shivering black Akita puppy during a rainy walk through a stone-paved village.\",\n            \"[Ukiyo-e style] A girl in a blue kimono brings an black Akita puppy into her warm home, a traditional wooden Japanese house with sliding doors.\",\n            \"[Ukiyo-e style] A girl in a blue kimono plays with her energetic black Akita puppy in the garden of a traditional Japanese house, throwing a woven ball.\",\n            \"[Ukiyo-e style] A teenager in a blue kimono jogs with her black Akita through a park filled with cherry blossom trees and ancient stone lanterns.\",\n            \"[Ukiyo-e style] A teenage girl in a blue kimono relaxes in a field of wildflowers, reading a scroll while her black Akita rests beside him under a cherry tree.\",\n            \"[Ukiyo-e style] A young girl in a blue kimono celebrates her coming-of-age ceremony with her black Akita, surrounded by festive lanterns and banners.\",\n            \"[Ukiyo-e style] A girl in a blue kimono and her loyal black Akita enjoy a serene sunset walk along the beach, with the silhouette of Mount Fuji in the distance.\"\n        ],\n        \"steps\": 30,\n        \"guidance_scale\": 12,\n        \"fps\": 24,\n        \"num_frames\": 24,\n        \"height\": 320,\n        \"width\": 512,\n        \"export_fps\": 8,\n        \"file_name\": null\n    },\n    \"plugin_configs\": {\n        \"attn\": {\n            \"padding\": 8,\n            \"top_k\": 16,\n            \"top_k_chunk_size\": 24,\n            \"attn_scale\": 1.0,\n            \"token_num_scale\": false,\n            \"dynamic_scale\": true,\n            \"local_phase\": {\n                \"t\": 850,\n                \"local_biase\": 10,\n                \"global_biase\": 0\n            },\n            \"global_phase\": {\n                \"t\": 850,\n                \"local_biase\": 0,\n                \"global_biase\": 10\n            }\n        },\n        \"conv_3d\": {\n            \"padding\": 1\n        },\n        \"conv_layer\": {}\n    }\n}"
  },
  {
    "path": "examples/single_gpu.json",
    "content": "{\n    \"dtype\": \"torch.float16\",\n    \"devices\": [0],\n    \"seed\": 123,\n    \"master_port\": 29516,\n    \"base_path\": \"./exp\",\n    \"pipe_configs\":{\n        \"prompts\": [\n            \"A beagle wearing diving goggles swimming in the ocean while the camera is moving, coral reefs in the background.\"\n        ],\n        \"steps\": 30,\n        \"guidance_scale\": 12,\n        \"fps\": 24,\n        \"num_frames\": 24,\n        \"height\": 320,\n        \"width\": 512,  \n        \"export_fps\": 8,\n        \"file_name\": null\n    },\n    \"plugin_configs\":{\n        \"attn\":{\n            \"padding\": 8,\n            \"top_k\": 16,\n            \"top_k_chunk_size\": 24,\n            \"attn_scale\": 1.0,\n            \"token_num_scale\": false,\n            \"dynamic_scale\": true,\n            \"local_phase\": {\n                \"t\": 800, \n                \"local_biase\": 10,\n                \"global_biase\": 0\n            },\n            \"global_phase\": {\n                \"t\": 800,\n                \"local_biase\": 0,\n                \"global_biase\": 10\n            }\n        },\n        \"conv_3d\": {\n            \"padding\": 1\n        }, \n        \"conv_layer\": {}\n    }\n}"
  },
  {
    "path": "inference.py",
    "content": "import torch\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport time\nimport json\nimport os\n\nfrom src.video_crafter import VideoCrafterPipeline, UNetVideoCrafter\nfrom diffusers.schedulers import DPMSolverMultistepScheduler\n\nfrom src.tools import DistController\nfrom src.video_infinity.wrapper import DistWrapper\n\ndef parse_args():\n    import argparse\n    parser = argparse.ArgumentParser(description=\"Video Infinity Inference\")\n    parser.add_argument(\"--config\", type=str)\n    args = parser.parse_args()\n    return args\n\ndef init_pipeline(config):\n    pipe = VideoCrafterPipeline.from_pretrained(\n        'adamdad/videocrafterv2_diffusers',\n        torch_dtype=torch.float16\n    )\n    pipe.enable_model_cpu_offload(\n        gpu_id=config[\"devices\"][dist.get_rank() % len(config[\"devices\"])],\n    )\n    pipe.enable_vae_slicing()\n    return pipe\n\ndef run_inference(rank, world_size, config):\n    dist_controller = DistController(rank, world_size, config)\n    pipe = init_pipeline(config)\n    dist_pipe = DistWrapper(pipe, dist_controller, config)\n    start = time.time()\n\n    pipe_configs=config['pipe_configs']\n    plugin_configs=config['plugin_configs']\n\n    prompt_id = int(rank / world_size * len(pipe_configs[\"prompts\"]))\n    prompt = pipe_configs[\"prompts\"][prompt_id]\n\n    start = time.time()\n    dist_pipe.inference(\n        prompt,\n        config,\n        pipe_configs,\n        plugin_configs,\n        additional_info={\n            \"full_config\": config,\n        }\n    )\n    print(f\"Rank {rank} finished. Time: {time.time() - start}\")\n\ndef main(config):\n    size = len(config[\"devices\"])\n    processes = []\n\n    if not os.path.exists(config[\"base_path\"]):\n        os.makedirs(config[\"base_path\"])\n\n    for rank, _ in enumerate(config[\"devices\"]):\n        p = mp.Process(target=run_inference, args=(rank, size, config))\n        p.start()\n        processes.append(p)\n\n    for p in processes:\n        p.join()\n\nif __name__ == \"__main__\":\n    mp.set_start_method(\"spawn\")\n\n    with open(parse_args().config, \"r\") as f:\n        config = json.load(f)\n    \n    main(config)"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ndiffusers\ntransformers\nimageio\naccelerate\nffmpeg\npyav\nimageio-ffmpeg"
  },
  {
    "path": "src/tools.py",
    "content": "import json\nimport numpy as np\nimport imageio\nimport os\n\nimport torch\nimport torch.distributed as dist\n\ndef export_to_video(video_frames, output_video_path, fps = 12):\n    # Ensure all frames are NumPy arrays and determine video dimensions from the first frame\n    assert all(isinstance(frame, np.ndarray) for frame in video_frames), \"All video frames must be NumPy arrays.\"\n    # Ensure output_video_path is ending with .mp4\n    if not output_video_path.endswith('.mp4'):\n        output_video_path += '.mp4'\n    # Create a video file at the specified path and write frames to it\n    with imageio.get_writer(output_video_path, fps=fps, format='mp4') as writer:\n        for frame in video_frames:\n            writer.append_data(\n                (frame * 255).astype(np.uint8)\n            )\n\ndef save_generation(video_frames, configs, base_path, file_name=None):\n    if not os.path.exists(base_path):\n        os.makedirs(base_path)\n    p_config = configs[\"pipe_configs\"]\n    frames, steps, fps = p_config[\"num_frames\"], p_config[\"steps\"], p_config[\"fps\"]\n    if not file_name:\n        index = [int(each.split('_')[0]) for each in os.listdir(base_path)]\n        max_idex = max(index) if index else 0\n        idx_str = str(max_idex + 1).zfill(6)\n\n\n        key_info = '_'.join([str(frames), str(steps), str(fps)])\n        file_name = f'{idx_str}_{key_info}'\n\n    with open(f'{base_path}/{file_name}.json', 'w') as f:\n        json.dump(configs, f, indent=4)\n\n    export_to_video(video_frames, f'{base_path}/{file_name}.mp4', fps=p_config[\"export_fps\"])\n\n    return file_name\n\n\nclass GlobalState:\n    def __init__(self, state={}) -> None:\n        self.init_state(state)\n    \n    def init_state(self, state={}):\n        self.state = state\n\n    def set(self, key, value):\n        self.state[key] = value\n\n    def get(self, key, default=None):\n        return self.state.get(key, default)\n    \n\nclass DistController(object):\n    def __init__(self, rank, world_size, config) -> None:\n        super().__init__()\n        self.rank = rank\n        self.world_size = world_size\n        self.config = config\n        self.is_master = rank == 0\n        self.init_dist()\n        self.init_group()\n        self.device = torch.device(f\"cuda:{config['devices'][dist.get_rank()]}\")\n        torch.cuda.set_device(self.device)\n\n    def init_dist(self):\n        print(f\"Rank {self.rank} is running.\")\n        os.environ['MASTER_ADDR'] = '127.0.0.1'\n        os.environ['MASTER_PORT'] = str(self.config.get(\"master_port\") or \"29500\")\n        dist.init_process_group(\"nccl\", rank=self.rank, world_size=self.world_size)\n\n    def init_group(self):\n        self.adj_groups = [dist.new_group([i, i+1]) for i in range(self.world_size-1)]"
  },
  {
    "path": "src/video_crafter.py",
    "content": "import torch\n\nfrom diffusers.models import AutoencoderKL, UNet3DConditionModel\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom diffusers.schedulers import DPMSolverMultistepScheduler\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import TextToVideoSDPipeline\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\n\nclass VideoCrafterPipeline(TextToVideoSDPipeline):\n    @register_to_config\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet3DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        fps_cond: bool = True,\n    ):\n        self.fps_cond = fps_cond\n        super().__init__(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler, \n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        *args,\n        **kwargs,\n    ):\n        fixed_fps = kwargs.pop(\"fps\", 24)\n        def post_function(sample):\n            fps = fixed_fps\n            unet = self.unet\n            if self.fps_cond:\n                fps = torch.tensor([fps], dtype=torch.float64 , device=sample.device)\n                fps_emb = unet.fps_proj(fps)\n                fps_emb = fps_emb.to(sample.device, dtype=unet.dtype)\n                fps_emb = unet.fps_embedding(fps_emb).repeat_interleave(repeats=sample.shape[0], dim=0)\n                sample += fps_emb\n            return sample\n        self.unet.time_embedding.post_act = post_function\n        # kwargs.pop(\"fps\", None)\n        return super().__call__(*args, **kwargs)\n        \n    @classmethod\n    def from_pretrained(\n        cls,\n        pretrained_model_name_or_path: str,\n        **kwargs,\n    ):\n        pipe = TextToVideoSDPipeline.from_pretrained(\"cerspense/zeroscope_v2_576w\", **kwargs)\n        pipe.__class__ = cls\n        pipe.fps_cond = True\n        pipe.unet = UNetVideoCrafter.from_pretrained(pretrained_model_name_or_path, **kwargs)\n        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras=True, algorithm_type=\"sde-dpmsolver++\")\n        return pipe\n\nclass UNetVideoCrafter(UNet3DConditionModel):\n    @register_to_config\n    def __init__(\n        self,\n        sample_size,\n        in_channels,\n        out_channels,\n        down_block_types,\n        up_block_types,\n        block_out_channels,\n        layers_per_block,\n        downsample_padding,\n        mid_block_scale_factor,\n        act_fn,\n        norm_num_groups,\n        norm_eps,\n        cross_attention_dim,\n        attention_head_dim,\n        num_attention_heads,\n        fps_cond: bool = True,\n        **kwargs\n    ):\n        self.fps_cond = fps_cond\n\n        super().__init__(\n            sample_size=sample_size,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            down_block_types=down_block_types,\n            up_block_types=up_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            downsample_padding=downsample_padding,\n            mid_block_scale_factor=mid_block_scale_factor,\n            act_fn=act_fn,\n            norm_num_groups=norm_num_groups,\n            
norm_eps=norm_eps,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            num_attention_heads=num_attention_heads,\n            **kwargs\n        )\n\n        if self.fps_cond:\n            self.fps_proj = Timesteps(block_out_channels[0], True, 0)\n            self.fps_embedding = TimestepEmbedding(\n                    block_out_channels[0],\n                    block_out_channels[0] * 4,\n                    act_fn=act_fn,\n                )\n"
  },
  {
    "path": "src/video_infinity/__init__.py",
    "content": ""
  },
  {
    "path": "src/video_infinity/plugins.py",
    "content": "import torch\nimport torch.distributed as dist\nimport math\n\ndef my_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, token_num_scale=False) -> torch.Tensor:\n    L, S = query.size(-2), key.size(-2)\n    base_scale_factor = 1 / math.sqrt(query.size(-1)) * (scale if scale is not None else 1.)\n    attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.dtype).to(query.device)\n    if is_causal:\n        assert attn_mask is None\n        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)\n        attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n        attn_bias.to(query.dtype).to(query.device)\n\n    if attn_mask is not None:\n        if attn_mask.dtype == torch.bool:\n            attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n        else:\n            attn_bias += attn_mask.to(query.dtype).to(query.device)\n    \n    no_mask_count = torch.where(attn_bias < -100, 0, 1).sum(1)\n    biased_scale_factor = torch.log(no_mask_count) / torch.log(torch.tensor(16)) if token_num_scale else 1.\n    scale_factor = (base_scale_factor * biased_scale_factor).unsqueeze(-1) if token_num_scale else base_scale_factor\n    attn_weight = query @ key.transpose(-2, -1) \n    attn_weight *= scale_factor\n    attn_weight += attn_bias\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)\n    return attn_weight @ value\n\nclass ModulePlugin:\n    def __init__(self, module, module_id, global_state=None):\n        self.module = module\n        self.module_id = module_id\n        self.global_state = global_state\n        self.enable = True\n        self.implement_forward()\n\n    @property\n    def is_log_node(self):\n        return self.global_state.get('dist_controller').rank == 0 and self.module_id[1] == 0\n\n    @property\n    def t(self):\n        return self.global_state.get('timestep')\n    \n    @property\n    def p(self):\n        return self.t / 1000\n\n    def implement_forward(self):\n        module = self.module\n        if not hasattr(module, \"old_forward\"):\n            module.old_forward = module.forward\n        self.new_forward = self.get_new_forward()\n        def forward(*args, **kwargs):\n            self.update_config() # update config\n            return self.new_forward(*args, **kwargs) if self.enable else self.old_forward(*args, **kwargs)\n        module.forward = forward\n\n    def set_enable(self, enable=True):\n        self.enable = enable\n        \n    def get_new_forward(self):\n        raise NotImplementedError\n    \n    def update_config(self, config:dict=None):\n        if config is None:\n            config = self.global_state.get('plugin_configs', {}).get(self.module_id[0], {})\n        for key, value in config.items():\n            setattr(self, key, value)\n\n\nclass GroupNormPlugin(ModulePlugin):\n    def __init__(self, module, module_id, global_state=None):\n        super().__init__(module, module_id, global_state)\n\n    def get_new_forward(self):\n        module = self.module\n    \n        def new_forward(x):\n            shape = x.shape\n            N, C, G = shape[0], shape[1], module.num_groups\n            assert C % G == 0\n\n            x = x.reshape(N, G, -1)\n            \n            mean = x.mean(-1, keepdim=True)\n            dist.all_reduce(mean)\n            mean = mean / dist.get_world_size()\n            var = ((x - mean) ** 2).mean(-1, keepdim=True) \n            
dist.all_reduce(var)\n            var = var / dist.get_world_size()\n\n            x = (x - mean) / (var + module.eps).sqrt()\n            x = x.view(shape)\n\n\n            new_shape = [1 for _ in shape]\n            new_shape[1] = -1\n\n            return x * module.weight.view(new_shape) + module.bias.view(new_shape)\n\n        return new_forward\n\nclass ConvLayerPlugin(ModulePlugin):\n    def __init__(self, module, module_id, global_state=None):\n        super().__init__(module, module_id, global_state)\n        self.padding = 4\n        self.rank = dist.get_rank()\n        self.adj_groups = self.global_state.get('dist_controller').adj_groups\n\n    def pad_context(self, h, padding=None):\n        padding = self.padding if padding is None else padding\n        share_to_left = h[:, :, :padding].contiguous()\n        share_to_right = h[:, :, -padding:].contiguous()\n        if self.rank % 2:\n            # 1. the rank is odd, pad the left first \n            if self.rank:\n                # not the first rank, have left context\n                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])\n                left_context = padding_list[0].to(h.device, non_blocking=True)\n            else:\n                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)\n            # 2. then pad the right\n            if self.rank != dist.get_world_size() - 1:\n                # not the last rank, have right context\n                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])\n                right_context = padding_list[1].to(h.device, non_blocking=True)\n            else:\n                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)\n        else:\n            # 1. the rank is even, pad the right first\n            if self.rank != dist.get_world_size() - 1:\n                # not the last rank, have right context\n                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])\n                right_context = padding_list[1].to(h.device, non_blocking=True)\n            else:\n                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)\n            # 2. 
then pad the left\n            if self.rank:\n                # not the first rank, have left context\n                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])\n                left_context = padding_list[0].to(h.device, non_blocking=True)\n            else:\n                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)\n        torch.cuda.synchronize()\n        h_with_context = torch.cat([left_context, h, right_context], dim=2)\n        return h_with_context\n\n    def get_new_forward(self):\n        module = self.module\n        def new_forward(hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:\n            hidden_states = (\n                hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)\n            )\n\n            identity = hidden_states\n\n            hidden_states = self.pad_context(hidden_states)\n            hidden_states = module.conv1(hidden_states)\n            hidden_states = module.conv2(hidden_states)\n            hidden_states = module.conv3(hidden_states)\n            hidden_states = module.conv4(hidden_states)\n            hidden_states = hidden_states[:, :, self.padding:-self.padding]\n\n\n            hidden_states = identity + hidden_states\n\n            hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(\n                (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]\n            )\n            return hidden_states\n\n        return new_forward\n    \n\nclass AttentionPlugin(ModulePlugin):\n    def __init__(self, module, module_id, global_state=None):\n        super().__init__(module, module_id, global_state)\n        self.padding = 24\n        self.top_k = 16\n        self.top_k_chunk_size = 24\n        self.attn_scale = 1.\n        self.token_num_scale = False\n        self.rank = dist.get_rank()\n        self.adj_groups = self.global_state.get('dist_controller').adj_groups\n        self.world_size = self.global_state.get('dist_controller').world_size\n        self.dynamic_scale = False\n\n    def pad_context(self, h, padding=None):\n        padding = self.padding if padding is None else padding\n\n        share_to_left = h[:, :padding].contiguous()\n        share_to_right = h[:, -padding:].contiguous()\n        if self.rank % 2:\n            # 1. the rank is odd, pad the left first \n            if self.rank:\n                # not the first rank, have left context\n                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])\n                left_context = padding_list[0].to(h.device, non_blocking=True)\n            else:\n                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)\n            # 2. then pad the right\n            if self.rank != dist.get_world_size() - 1:\n                # not the last rank, have right context\n                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])\n                right_context = padding_list[1].to(h.device, non_blocking=True)\n            else:\n                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)\n        else:\n            # 1. 
the rank is even, pad the right first\n            if self.rank != dist.get_world_size() - 1:\n                # not the last rank, have right context\n                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])\n                right_context = padding_list[1].to(h.device, non_blocking=True)\n            else:\n                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)\n            # 2. then pad the left\n            if self.rank:\n                # not the first rank, have left context\n                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])\n                left_context = padding_list[0].to(h.device, non_blocking=True)\n            else:\n                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)\n        torch.cuda.synchronize()\n\n        h_with_context = torch.cat([left_context, h, right_context], dim=1)\n        return h_with_context, padding\n    \n    def get_topk(self, q, k, v, top_k=None):\n        # h = [N, F, C]\n        top_k = self.top_k if top_k is None else top_k\n        share_num = int(max(top_k // self.world_size, 0))\n\n        stride = max(q.shape[1] // share_num, 1) if share_num else 1000000\n\n        topk_indices = torch.arange(0, q.shape[1], stride, device=q.device)\n\n        k_to_share, v_to_share =  k[:, topk_indices], v[:, topk_indices]\n\n        gather_k = [torch.zeros_like(k_to_share) for _ in range(self.world_size)]\n        gather_v = [torch.zeros_like(v_to_share) for _ in range(self.world_size)]\n\n        dist.all_gather(gather_k, k_to_share)\n        dist.all_gather(gather_v, v_to_share)\n\n        gather_k = torch.cat(gather_k, dim=1)[:, :top_k]\n        gather_v = torch.cat(gather_v, dim=1)[:, :top_k]\n\n        return gather_k, gather_v\n\n    def gather_context(self, h):\n        self.temporal_n = h.shape[1]\n        stack_list = [torch.zeros_like(h) for _ in range(self.world_size)]\n        dist.all_gather(stack_list, h)\n        return torch.cat(stack_list, dim=1)\n\n    def get_new_forward(self):\n        module = self.module\n        def new_forward(x, encoder_hidden_states=None, attention_mask=None):\n            context=encoder_hidden_states\n\n            temporal_n = x.shape[1]\n            q = module.to_q(x)\n            \n            context = x if context is None else context\n            k, v = module.to_k(context), module.to_v(context)\n            b, _, _ = q.shape\n            q, k, v = map(\n                lambda t: t.unsqueeze(3).reshape(b, t.shape[1], module.heads, -1).permute(0, 2, 1, 3).reshape(b*module.heads, t.shape[1], -1),\n                (q, k, v),\n            )\n\n            global_k, global_v = self.get_topk(q, k, v)\n            num_global = global_k.shape[1]\n\n            padded_k, _ = self.pad_context(k)\n            padded_v, padding = self.pad_context(v)\n\n            padded_k = torch.cat([padded_k, global_k], dim=1)\n            padded_v = torch.cat([padded_v, global_v], dim=1)\n\n            # if self.is_log_node:\n            #     print(\"Total KV num:\", padding*2 + global_k.shape[1], \"Global KV num:\", num_global, \"Padding:\", padding)\n\n            attn_mask = torch.ones(temporal_n, temporal_n + 2*padding + num_global, dtype=q.dtype).to(q.device)\n            for i in range(temporal_n):\n 
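               # mask local attention: query frame i may only see padded KV positions i..i+2*padding\n                # (its local window); the last num_global columns are left unmasked, so every frame\n                # still attends to the globally gathered context frames\n 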
               attn_mask[i, 0: max(0, i)] = float('-inf')\n                attn_mask[i, min(temporal_n+2*padding, i+1+2*padding): temporal_n+2*padding] = float('-inf')\n                \n            if self.dynamic_scale and self.local_phase is not None and self.global_phase is not None:\n                if self.t < self.local_phase['t']:\n                    attn_mask[:, temporal_n+2*padding:] += self.local_phase['global_biase']\n                    attn_mask[:, :temporal_n+2*padding] += self.local_phase['local_biase']\n                if self.t >= self.global_phase['t']:\n                    attn_mask[:, temporal_n+2*padding:] += self.global_phase['global_biase']\n                    attn_mask[:, :temporal_n+2*padding] += self.global_phase['local_biase']\n            out = my_attention(\n                q, padded_k, padded_v,\n                attn_mask=attn_mask, dropout_p=0.0, is_causal=False,\n                scale=self.attn_scale,\n                token_num_scale=self.token_num_scale\n            )\n\n\n            out = (\n                out.unsqueeze(0).reshape(b, module.heads, out.shape[1], -1).permute(0, 2, 1, 3)\n                .reshape(b, out.shape[1], -1)\n            )\n\n            # linear proj\n            hidden_states = module.to_out[0](out)\n            hidden_states = module.to_out[1](hidden_states)\n            \n            return hidden_states\n\n        return new_forward\n    \n\nclass Conv3DPligin(ModulePlugin):\n    def __init__(self, module, module_id, global_state=None):\n        super().__init__(module, module_id, global_state)\n        self.padding = 1\n        self.rank = dist.get_rank()\n        self.adj_groups = self.global_state.get('dist_controller').adj_groups\n\n    def pad_context(self, h):\n        padding = self.padding\n        share_to_left = h[:, :, :padding].contiguous()\n        share_to_right = h[:, :, -padding:].contiguous()\n        if self.rank % 2:\n            # 1. the rank is odd, pad the left first \n            if self.rank:\n                # not the first rank, have left context\n                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])\n                left_context = padding_list[0].to(h.device, non_blocking=True)\n            else:\n                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)\n            # 2. then pad the right\n            if self.rank != dist.get_world_size() - 1:\n                # not the last rank, have right context\n                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])\n                right_context = padding_list[1].to(h.device, non_blocking=True)\n            else:\n                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)\n        else:\n            # 1. 
the rank is even, pad the right first\n            if self.rank != dist.get_world_size() - 1:\n                # not the last rank, have right context\n                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])\n                right_context = padding_list[1].to(h.device, non_blocking=True)\n            else:\n                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)\n            # 2. then pad the left\n            if self.rank:\n                # not the first rank, have left context\n                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]\n                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])\n                left_context = padding_list[0].to(h.device, non_blocking=True)\n            else:\n                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)\n        torch.cuda.synchronize()\n        h_with_context = torch.cat([left_context, h, right_context], dim=2)\n        return h_with_context\n\n    def get_new_forward(self):\n        module = self.module\n        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:\n            hidden_states = self.pad_context(hidden_states)\n            hidden_states = module.old_forward(hidden_states)[:,:,self.padding:-self.padding]\n            return hidden_states\n\n        return new_forward\n\nclass UNetPlugin(ModulePlugin):\n    def __init__(self, module, module_id, global_state=None):\n        super().__init__(module, module_id, global_state)\n\n    def get_new_forward(self):\n        module = self.module\n    \n        def new_forward(*args, **kwargs):\n            self.global_state.set('timestep', args[1].item())\n            return module.old_forward(*args, **kwargs)\n\n        return new_forward"
  },
  {
    "path": "src/video_infinity/wrapper.py",
    "content": "from ..tools import save_generation, GlobalState, DistController\n\nfrom .plugins import torch, ModulePlugin, UNetPlugin, GroupNormPlugin, ConvLayerPlugin, AttentionPlugin, Conv3DPligin, dist\n\nclass DistWrapper(object):\n    def __init__(self, pipe, dist_controller: DistController, config) -> None:\n        super().__init__()\n        self.pipe = pipe\n        self.dist_controller = dist_controller\n        self.config = config\n        self.global_state = GlobalState({\n            \"dist_controller\": dist_controller\n        })\n        self.plugin_mount()\n\n    def switch_plugin(self, plugin_name, enable):\n        if plugin_name not in self.plugins: return\n        for moudule_id in self.plugins[plugin_name]:\n            moudle: ModulePlugin = self.plugins[plugin_name][moudule_id]\n            moudle.set_enable(enable)\n    \n    def config_plugin(self, plugin_name, config):\n        if plugin_name not in self.plugins: return\n        for moudule_id in self.plugins[plugin_name]:\n            moudle: ModulePlugin = self.plugins[plugin_name][moudule_id]\n            moudle.update_config(config)\n\n    \n    def plugin_mount(self):\n        self.plugins = {}\n        self.unet_plugin_mount()\n        self.attn_plugin_mount()\n\n\n        # self.group_norm_plugin_mount()\n        # self.conv_3d_plugin_mount()\n\n        # Conv3d and Conv layer can only be used one at a time\n        self.conv_plugin_mount()\n\n    def group_norm_plugin_mount(self):\n        self.plugins['group_norm'] = {}\n        group_norms = []\n        for module in self.pipe.unet.named_modules():\n            if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'GroupNorm':\n                group_norms.append(module[1])\n        if self.dist_controller.is_master:\n            print(f'Found {len(group_norms)} group norms')\n        for i, group_norm in enumerate(group_norms):\n            plugin_id = 'group_norm', i\n            self.plugins['group_norm'][plugin_id] = GroupNormPlugin(group_norm, plugin_id, self.global_state)\n            \n    def conv_plugin_mount(self):\n        self.plugins['conv_layer'] = {}\n        convs = []\n        for module in self.pipe.unet.named_modules():\n            if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'TemporalConvLayer':\n                convs.append(module[1])\n        if self.dist_controller.is_master:\n            print(f'Found {len(convs)} convs')\n        for i, conv in enumerate(convs):\n            plugin_id = 'conv_layer', i\n            self.plugins['conv_layer'][plugin_id] = ConvLayerPlugin(conv, plugin_id, self.global_state)\n\n    def conv_3d_plugin_mount(self):\n        self.plugins['conv_3d'] = {}\n        conv3d_s = []\n        for module in self.pipe.unet.named_modules():\n            if ('temp_' in module[0] or 'transformer_in' in module[0]) and module[1].__class__.__name__ == 'Conv3d':\n                conv3d_s.append(module[1])\n        if self.dist_controller.is_master:\n            print(f'Found {len(conv3d_s)} conv3d_s')\n        for i, conv in enumerate(conv3d_s):\n            plugin_id = 'conv_3d', i\n            self.plugins['conv_3d'][plugin_id] = Conv3DPligin(conv, plugin_id, self.global_state)\n\n\n    def attn_plugin_mount(self):\n        self.plugins['attn'] = {}\n        attns = []\n        for module in self.pipe.unet.named_modules():\n            if ('temp_' in module[0] or 'transformer_in' in module[0]) and 
module[1].__class__.__name__ == 'Attention':\n                attns.append(module[1])\n        if self.dist_controller.is_master:\n            print(f'Found {len(attns)} attns')\n        for i, attn in enumerate(attns):\n            plugin_id = 'attn', i\n            self.plugins['attn'][plugin_id] = AttentionPlugin(attn, plugin_id, self.global_state)\n\n    def unet_plugin_mount(self):\n        self.plugins['unet'] = UNetPlugin(\n            self.pipe.unet,\n            ('unet', 0),\n            self.global_state\n        )\n    \n    def inference(\n        self,\n        prompts=\"A beagle wearning diving goggles  swimming in the ocean while the camera is moving, coral reefs in the background\",\n        config={},\n        pipe_configs={\n            \"steps\": 50,\n            \"guidance_scale\": 12,\n            \"fps\": 60,\n            \"num_frames\": 24 * 1,\n            \"height\": 320,\n            \"width\": 512,\n            \"export_fps\": 12,\n            \"base_path\": \"./work/output\",\n            \"file_name\": None\n        },\n        plugin_configs={\n            \"attn\":{\n                \"padding\": 24,\n                \"top_k\": 24,\n                \"top_k_chunk_size\": 24,\n                \"attn_scale\": 1.,\n                \"token_num_scale\": True,\n                \"dynamic_scale\": True,\n            },\n            \"conv_3d\": {\n                \"padding\": 1,\n            }, \n            \"conv_layer\": {},\n        },\n        additional_info={},\n    ):\n        self.plugin_mount()\n        generator = torch.Generator(\"cuda\").manual_seed(self.config[\"seed\"] + self.dist_controller.rank)\n        # generator = torch.Generator(\"cuda\").manual_seed(self.config[\"seed\"])\n\n        self.global_state.set(\"plugin_configs\", plugin_configs)\n\n        video_frames = self.pipe(\n            prompts, \n            num_inference_steps=pipe_configs[\"steps\"], \n            guidance_scale=pipe_configs[\"guidance_scale\"],\n            height=pipe_configs['height'], \n            width=pipe_configs['width'], \n            num_frames=pipe_configs['num_frames'], \n            fps=pipe_configs['fps'],\n            generator=generator\n        ).frames[0]\n\n        video_frames = torch.tensor(video_frames, dtype=torch.float16, device=self.dist_controller.device)\n\n        print(f\"Rank {self.dist_controller.rank} finished inference. Result: {video_frames.shape}\")\n        all_frames = [\n            torch.zeros_like(video_frames, dtype=torch.float16) for _ in range(self.dist_controller.world_size)\n        ] if self.dist_controller.is_master else None\n        dist.gather(video_frames, all_frames, dst=0)\n        if self.dist_controller.is_master:\n            all_frames = torch.cat(all_frames, dim=0).cpu().numpy()\n            save_generation(\n                all_frames, \n                {\n                    \"prompt\": prompts,\n                    \"pipe_configs\": pipe_configs,\n                    \"plugin_configs\": plugin_configs,\n                    \"additional_info\": additional_info\n                },\n                config[\"base_path\"],\n                pipe_configs[\"file_name\"]\n            )\n"
  }
]