[
  {
    "path": ".gitignore",
    "content": "__pycache__\n.vscode\nsamples\nxformers\nsrc\nthird_party\nbackup\npretrained_models\n*.nfs*\n./*.png\n./*.mp4\ndemo/tmp\ndemo/outputs"
  },
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright 2023 MagicAnimate Team All rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "README.md",
    "content": "<!-- # magic-edit.github.io -->\n\n<p align=\"center\">\n\n  <h2 align=\"center\">MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h2>\n  <p align=\"center\">\n    <a href=\"https://scholar.google.com/citations?user=-4iADzMAAAAJ&hl=en\"><strong>Zhongcong Xu</strong></a>\n    ·\n    <a href=\"http://jeff95.me/\"><strong>Jianfeng Zhang</strong></a>\n    ·\n    <a href=\"https://scholar.google.com.sg/citations?user=8gm-CYYAAAAJ&hl=en\"><strong>Jun Hao Liew</strong></a>\n    ·\n    <a href=\"https://hanshuyan.github.io/\"><strong>Hanshu Yan</strong></a>\n    ·\n    <a href=\"https://scholar.google.com/citations?user=stQQf7wAAAAJ&hl=en\"><strong>Jia-Wei Liu</strong></a>\n    ·\n    <a href=\"https://zhangchenxu528.github.io/\"><strong>Chenxu Zhang</strong></a>\n    ·\n    <a href=\"https://sites.google.com/site/jshfeng/home\"><strong>Jiashi Feng</strong></a>\n    ·\n    <a href=\"https://sites.google.com/view/showlab\"><strong>Mike Zheng Shou</strong></a>\n    <br>\n    <br>\n        <a href=\"https://arxiv.org/abs/2311.16498\"><img src='https://img.shields.io/badge/arXiv-MagicAnimate-red' alt='Paper PDF'></a>\n        <a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>\n        <a href='https://huggingface.co/spaces/zcxu-eric/magicanimate'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>\n    <br>\n    <b>National University of Singapore &nbsp; | &nbsp;  ByteDance</b>\n  </p>\n  \n  <table align=\"center\">\n    <tr>\n    <td>\n      <img src=\"assets/teaser/t4.gif\">\n    </td>\n    <td>\n      <img src=\"assets/teaser/t2.gif\">\n    </td>\n    </tr>\n  </table>\n\n## 📢 News\n* **[2023.12.4]** Release inference code and gradio demo. 
We are working to improve MagicAnimate; stay tuned!\n* **[2023.11.23]** Release MagicAnimate paper and project page.\n\n## 🏃‍♂️ Getting Started\nDownload the pretrained base models for [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [MSE-finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse).\n\nDownload our MagicAnimate [checkpoints](https://huggingface.co/zcxu-eric/MagicAnimate).\n\nPlease follow the Hugging Face download instructions to download the above models and checkpoints; `git lfs` is recommended.\n\nPlace the base models and checkpoints as follows:\n```bash\nmagic-animate\n|----pretrained_models\n  |----MagicAnimate\n    |----appearance_encoder\n      |----diffusion_pytorch_model.safetensors\n      |----config.json\n    |----densepose_controlnet\n      |----diffusion_pytorch_model.safetensors\n      |----config.json\n    |----temporal_attention\n      |----temporal_attention.ckpt\n  |----sd-vae-ft-mse\n    |----config.json\n    |----diffusion_pytorch_model.safetensors\n  |----stable-diffusion-v1-5\n    |----scheduler\n       |----scheduler_config.json\n    |----text_encoder\n       |----config.json\n       |----pytorch_model.bin\n    |----tokenizer (all)\n    |----unet\n       |----diffusion_pytorch_model.bin\n       |----config.json\n    |----v1-5-pruned-emaonly.safetensors\n|----...\n```\n\n## ⚒️ Installation\nPrerequisites: `python>=3.8`, `CUDA>=11.3`, and `ffmpeg`.\n\nInstall with `conda`: \n```bash\nconda env create -f environment.yaml\nconda activate manimate\n```\nor `pip`:\n```bash\npip3 install -r requirements.txt\n```\n\n## 💃 Inference\nRun inference on a single GPU:\n```bash\nbash scripts/animate.sh\n```\nRun inference with multiple GPUs:\n```bash\nbash scripts/animate_dist.sh\n```\n\n## 🎨 Gradio Demo \n\n#### Online Gradio Demo:\nTry our [online gradio demo](https://huggingface.co/spaces/zcxu-eric/magicanimate) quickly.\n\n#### Local Gradio Demo:\nLaunch local gradio demo on single 
GPU:\n```bash\npython3 -m demo.gradio_animate\n```\nLaunch local gradio demo if you have multiple GPUs:\n```bash\npython3 -m demo.gradio_animate_dist\n```\nThen open the gradio demo in your local browser.\n\n## 🙏 Acknowledgements\nWe would like to thank [AK(@_akhaliq)](https://twitter.com/_akhaliq?lang=en) and the Hugging Face team for their help setting up the online gradio demo.\n\n## 🎓 Citation\nIf you find this codebase useful for your research, please cite it using the following entry.\n```BibTeX\n@inproceedings{xu2023magicanimate,\n    author    = {Xu, Zhongcong and Zhang, Jianfeng and Liew, Jun Hao and Yan, Hanshu and Liu, Jia-Wei and Zhang, Chenxu and Feng, Jiashi and Shou, Mike Zheng},\n    title     = {MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model},\n    booktitle = {arXiv},\n    year      = {2023}\n}\n```\n\n"
  },
  {
    "path": "configs/inference/inference.yaml",
    "content": "unet_additional_kwargs:\n  unet_use_cross_frame_attention: false\n  unet_use_temporal_attention: false\n  use_motion_module: true\n  motion_module_resolutions:\n  - 1\n  - 2\n  - 4\n  - 8\n  motion_module_mid_block: false\n  motion_module_decoder_only: false\n  motion_module_type: Vanilla\n  motion_module_kwargs:\n    num_attention_heads: 8\n    num_transformer_block: 1\n    attention_block_types:\n    - Temporal_Self\n    - Temporal_Self\n    temporal_position_encoding: true\n    temporal_position_encoding_max_len: 24\n    temporal_attention_dim_div: 1\n\nnoise_scheduler_kwargs:\n  beta_start: 0.00085\n  beta_end: 0.012\n  beta_schedule: \"linear\"\n"
  },
  {
    "path": "configs/prompts/animation.yaml",
    "content": "pretrained_model_path: \"pretrained_models/stable-diffusion-v1-5\"\npretrained_vae_path: \"pretrained_models/sd-vae-ft-mse\"\npretrained_controlnet_path: \"pretrained_models/MagicAnimate/densepose_controlnet\"\npretrained_appearance_encoder_path: \"pretrained_models/MagicAnimate/appearance_encoder\"\npretrained_unet_path: \"\"\n\nmotion_module: \"pretrained_models/MagicAnimate/temporal_attention/temporal_attention.ckpt\"\n\nsavename: null\n\nfusion_blocks: \"midup\"\n\nseed:           [1]\nsteps:          25\nguidance_scale: 7.5\n\nsource_image:\n  - \"inputs/applications/source_image/monalisa.png\"\n  - \"inputs/applications/source_image/demo4.png\"\n  - \"inputs/applications/source_image/dalle2.jpeg\"\n  - \"inputs/applications/source_image/dalle8.jpeg\"\n  - \"inputs/applications/source_image/multi1_source.png\"\nvideo_path:\n  - \"inputs/applications/driving/densepose/running.mp4\"\n  - \"inputs/applications/driving/densepose/demo4.mp4\"\n  - \"inputs/applications/driving/densepose/running2.mp4\"\n  - \"inputs/applications/driving/densepose/dancing2.mp4\"\n  - \"inputs/applications/driving/densepose/multi_dancing.mp4\"\n\ninference_config: \"configs/inference/inference.yaml\"\nsize: 512\nL:    16\nS:    1 \nI:    0\nclip: 0\noffset: 0\nmax_length: null\nvideo_type: \"condition\"\ninvert_video: false\nsave_individual_videos: false\n"
  },
  {
    "path": "demo/animate.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport argparse\nimport argparse\nimport datetime\nimport inspect\nimport os\nimport numpy as np\nfrom PIL import Image\nfrom omegaconf import OmegaConf\nfrom collections import OrderedDict\n\nimport torch\n\nfrom diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler\n\nfrom tqdm import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom magicanimate.models.unet_controlnet import UNet3DConditionModel\nfrom magicanimate.models.controlnet import ControlNetModel\nfrom magicanimate.models.appearance_encoder import AppearanceEncoderModel\nfrom magicanimate.models.mutual_self_attention import ReferenceAttentionControl\nfrom magicanimate.pipelines.pipeline_animation import AnimationPipeline\nfrom magicanimate.utils.util import save_videos_grid\nfrom accelerate.utils import set_seed\n\nfrom magicanimate.utils.videoreader import VideoReader\n\nfrom einops import rearrange, repeat\n\nimport csv, pdb, glob\nfrom safetensors import safe_open\nimport math\nfrom pathlib import Path\n\nclass MagicAnimate():\n    def __init__(self, config=\"configs/prompts/animation.yaml\") -> None:\n        print(\"Initializing MagicAnimate Pipeline...\")\n        *_, func_args = inspect.getargvalues(inspect.currentframe())\n        func_args = dict(func_args)\n        \n        config  = OmegaConf.load(config)\n        \n        inference_config = OmegaConf.load(config.inference_config)\n            \n        motion_module = config.motion_module\n       \n        ### >>> create 
animation pipeline >>> ###\n        tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder=\"tokenizer\")\n        text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder=\"text_encoder\")\n        if config.pretrained_unet_path:\n            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))\n        else:\n            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder=\"unet\", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))\n        self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder=\"appearance_encoder\").cuda()\n        self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)\n        self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)\n        if config.pretrained_vae_path is not None:\n            vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)\n        else:\n            vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder=\"vae\")\n\n        ### Load controlnet\n        controlnet   = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)\n\n        vae.to(torch.float16)\n        unet.to(torch.float16)\n        text_encoder.to(torch.float16)\n        controlnet.to(torch.float16)\n        self.appearance_encoder.to(torch.float16)\n        \n        unet.enable_xformers_memory_efficient_attention()\n        self.appearance_encoder.enable_xformers_memory_efficient_attention()\n        controlnet.enable_xformers_memory_efficient_attention()\n\n        self.pipeline = AnimationPipeline(\n            
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,\n            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),\n            # NOTE: UniPCMultistepScheduler\n        ).to(\"cuda\")\n\n        # 1. unet ckpt\n        # 1.1 motion module\n        motion_module_state_dict = torch.load(motion_module, map_location=\"cpu\")\n        if \"global_step\" in motion_module_state_dict: func_args.update({\"global_step\": motion_module_state_dict[\"global_step\"]})\n        motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict\n        try:\n            # extra steps for self-trained models\n            state_dict = OrderedDict()\n            for key in motion_module_state_dict.keys():\n                if key.startswith(\"module.\"):\n                    _key = key.split(\"module.\")[-1]\n                    state_dict[_key] = motion_module_state_dict[key]\n                else:\n                    state_dict[key] = motion_module_state_dict[key]\n            motion_module_state_dict = state_dict\n            del state_dict\n            missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)\n            assert len(unexpected) == 0\n        except:\n            _tmp_ = OrderedDict()\n            for key in motion_module_state_dict.keys():\n                if \"motion_modules\" in key:\n                    if key.startswith(\"unet.\"):\n                        _key = key.split('unet.')[-1]\n                        _tmp_[_key] = motion_module_state_dict[key]\n                    else:\n                        _tmp_[key] = motion_module_state_dict[key]\n            missing, unexpected = unet.load_state_dict(_tmp_, strict=False)\n            assert len(unexpected) == 0\n            del _tmp_\n        del motion_module_state_dict\n\n        self.pipeline.to(\"cuda\")\n        
self.L = config.L\n        \n        print(\"Initialization Done!\")\n        \n    def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):\n            prompt = n_prompt = \"\"\n            random_seed = int(random_seed)\n            step = int(step)\n            guidance_scale = float(guidance_scale)\n            samples_per_video = []\n            # manually set random seed for reproduction\n            if random_seed != -1: \n                torch.manual_seed(random_seed)\n                set_seed(random_seed)\n            else:\n                torch.seed()\n\n            if motion_sequence.endswith('.mp4'):\n                control = VideoReader(motion_sequence).read()\n                if control[0].shape[0] != size:\n                    control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]\n                control = np.array(control)\n            \n            if source_image.shape[0] != size:\n                source_image = np.array(Image.fromarray(source_image).resize((size, size)))\n            H, W, C = source_image.shape\n            \n            init_latents = None\n            original_length = control.shape[0]\n            if control.shape[0] % self.L > 0:\n                control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')\n            generator = torch.Generator(device=torch.device(\"cuda:0\"))\n            generator.manual_seed(torch.initial_seed())\n            sample = self.pipeline(\n                prompt,\n                negative_prompt         = n_prompt,\n                num_inference_steps     = step,\n                guidance_scale          = guidance_scale,\n                width                   = W,\n                height                  = H,\n                video_length            = len(control),\n                controlnet_condition    = control,\n                init_latents            = init_latents,\n          
      generator               = generator,\n                appearance_encoder       = self.appearance_encoder, \n                reference_control_writer = self.reference_control_writer,\n                reference_control_reader = self.reference_control_reader,\n                source_image             = source_image,\n            ).videos\n\n            source_images = np.array([source_image] * original_length)\n            source_images = rearrange(torch.from_numpy(source_images), \"t h w c -> 1 c t h w\") / 255.0\n            samples_per_video.append(source_images)\n            \n            control = control / 255.0\n            control = rearrange(control, \"t h w c -> 1 c t h w\")\n            control = torch.from_numpy(control)\n            samples_per_video.append(control[:, :, :original_length])\n\n            samples_per_video.append(sample[:, :, :original_length])\n\n            samples_per_video = torch.cat(samples_per_video)\n\n            time_str = datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n            savedir = f\"demo/outputs\"\n            animation_path = f\"{savedir}/{time_str}.mp4\"\n\n            os.makedirs(savedir, exist_ok=True)\n            save_videos_grid(samples_per_video, animation_path)\n            \n            return animation_path\n            "
  },
  {
    "path": "demo/animate_dist.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport argparse\nimport argparse\nimport datetime\nimport inspect\nimport os\nimport numpy as np\nfrom PIL import Image\nfrom omegaconf import OmegaConf\nfrom collections import OrderedDict\n\nimport torch\nimport random\nfrom diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler\n\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom magicanimate.models.unet_controlnet import UNet3DConditionModel\nfrom magicanimate.models.controlnet import ControlNetModel\nfrom magicanimate.models.appearance_encoder import AppearanceEncoderModel\nfrom magicanimate.models.mutual_self_attention import ReferenceAttentionControl\nfrom magicanimate.pipelines.pipeline_animation import AnimationPipeline\nfrom magicanimate.utils.util import save_videos_grid\nfrom magicanimate.utils.dist_tools import distributed_init\nfrom accelerate.utils import set_seed\n\nfrom magicanimate.utils.videoreader import VideoReader\n\nfrom einops import rearrange\n\nanimator = None\n\nclass MagicAnimate():\n    def __init__(self, args) -> None:\n        config=args.config\n        device = torch.device(f\"cuda:{args.rank}\")\n        print(\"Initializing MagicAnimate Pipeline...\")\n        *_, func_args = inspect.getargvalues(inspect.currentframe())\n        func_args = dict(func_args)\n        \n        config  = OmegaConf.load(config)\n        \n        inference_config = OmegaConf.load(config.inference_config)\n            \n        motion_module = config.motion_module\n       \n        ### >>> 
create animation pipeline >>> ###\n        tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder=\"tokenizer\")\n        text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder=\"text_encoder\")\n        if config.pretrained_unet_path:\n            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))\n        else:\n            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder=\"unet\", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))\n        self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder=\"appearance_encoder\").to(device)\n        self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)\n        self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)\n        if config.pretrained_vae_path is not None:\n            vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)\n        else:\n            vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder=\"vae\")\n\n        ### Load controlnet\n        controlnet   = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)\n\n        vae.to(torch.float16)\n        unet.to(torch.float16)\n        text_encoder.to(torch.float16)\n        controlnet.to(torch.float16)\n        self.appearance_encoder.to(torch.float16)\n        \n        unet.enable_xformers_memory_efficient_attention()\n        self.appearance_encoder.enable_xformers_memory_efficient_attention()\n        controlnet.enable_xformers_memory_efficient_attention()\n\n        self.pipeline = AnimationPipeline(\n      
      vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,\n            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),\n            # NOTE: UniPCMultistepScheduler\n        )\n\n        # 1. unet ckpt\n        # 1.1 motion module\n        motion_module_state_dict = torch.load(motion_module, map_location=\"cpu\")\n        if \"global_step\" in motion_module_state_dict: func_args.update({\"global_step\": motion_module_state_dict[\"global_step\"]})\n        motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict\n        try:\n            # extra steps for self-trained models\n            state_dict = OrderedDict()\n            for key in motion_module_state_dict.keys():\n                if key.startswith(\"module.\"):\n                    _key = key.split(\"module.\")[-1]\n                    state_dict[_key] = motion_module_state_dict[key]\n                else:\n                    state_dict[key] = motion_module_state_dict[key]\n            motion_module_state_dict = state_dict\n            del state_dict\n            missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)\n            assert len(unexpected) == 0\n        except:\n            _tmp_ = OrderedDict()\n            for key in motion_module_state_dict.keys():\n                if \"motion_modules\" in key:\n                    if key.startswith(\"unet.\"):\n                        _key = key.split('unet.')[-1]\n                        _tmp_[_key] = motion_module_state_dict[key]\n                    else:\n                        _tmp_[key] = motion_module_state_dict[key]\n            missing, unexpected = unet.load_state_dict(_tmp_, strict=False)\n            assert len(unexpected) == 0\n            del _tmp_\n        del motion_module_state_dict\n\n        self.pipeline.to(device)\n        self.L = 
config.L\n        \n        print(\"Initialization Done!\")\n        dist_kwargs = {\"rank\":args.rank, \"world_size\":args.world_size, \"dist\":args.dist}\n        self.predict(args.reference_image, args.motion_sequence, args.random_seed, args.step, args.guidance_scale, args.save_path, dist_kwargs)\n        \n    def predict(self, source_image, motion_sequence, random_seed, step, guidance_scale, save_path, dist_kwargs, size=512):\n            prompt = n_prompt = \"\"\n            samples_per_video = []\n            # manually set random seed for reproduction\n            if random_seed != -1: \n                torch.manual_seed(random_seed)\n                set_seed(random_seed)\n            else:\n                torch.seed()\n\n            if motion_sequence.endswith('.mp4'):\n                control = VideoReader(motion_sequence).read()\n                if control[0].shape[0] != size:\n                    control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]\n                control = np.array(control)\n            if not isinstance(source_image, np.ndarray):\n                source_image = np.array(Image.open(source_image))\n            if source_image.shape[0] != size:\n                source_image = np.array(Image.fromarray(source_image).resize((size, size)))\n            H, W, C = source_image.shape\n            \n            init_latents = None\n            original_length = control.shape[0]\n            if control.shape[0] % self.L > 0:\n                control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')\n            generator = torch.Generator(device=torch.device(\"cuda:0\"))\n            generator.manual_seed(torch.initial_seed())\n            sample = self.pipeline(\n                prompt,\n                negative_prompt         = n_prompt,\n                num_inference_steps     = step,\n                guidance_scale          = guidance_scale,\n                width       
            = W,\n                height                  = H,\n                video_length            = len(control),\n                controlnet_condition    = control,\n                init_latents            = init_latents,\n                generator               = generator,\n                appearance_encoder       = self.appearance_encoder, \n                reference_control_writer = self.reference_control_writer,\n                reference_control_reader = self.reference_control_reader,\n                source_image             = source_image,\n                **dist_kwargs,\n            ).videos\n            if dist_kwargs.get('rank', 0) == 0:\n                source_images = np.array([source_image] * original_length)\n                source_images = rearrange(torch.from_numpy(source_images), \"t h w c -> 1 c t h w\") / 255.0\n                samples_per_video.append(source_images)\n                \n                control = control / 255.0\n                control = rearrange(control, \"t h w c -> 1 c t h w\")\n                control = torch.from_numpy(control)\n                samples_per_video.append(control[:, :, :original_length])\n\n                samples_per_video.append(sample[:, :, :original_length])\n\n                samples_per_video = torch.cat(samples_per_video)\n                \n                save_videos_grid(samples_per_video, save_path)\n                \n\ndef distributed_main(device_id, args):\n    args.rank = device_id\n    args.device_id = device_id\n    if torch.cuda.is_available():\n        torch.cuda.set_device(args.device_id)\n        torch.cuda.init()\n    distributed_init(args)\n    MagicAnimate(args)\n\n\ndef run(args):\n\n    if args.dist:\n        args.world_size = max(1, torch.cuda.device_count())\n        assert args.world_size <= torch.cuda.device_count()\n\n        if args.world_size > 0 and torch.cuda.device_count() > 1:\n            port = random.randint(10000, 20000)\n            args.init_method = 
f\"tcp://localhost:{port}\"\n            torch.multiprocessing.spawn(\n                fn=distributed_main,\n                args=(args,),\n                nprocs=args.world_size,\n            )\n    else:\n        MagicAnimate(args)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", type=str, default=\"configs/prompts/animation.yaml\", required=False)\n    # type=bool would treat any non-empty string (including \"False\") as True, so parse the flag text explicitly\n    parser.add_argument(\"--dist\", type=lambda s: s.lower() in (\"1\", \"true\", \"yes\"), default=True, required=False)\n    parser.add_argument(\"--rank\", type=int, default=0, required=False)\n    parser.add_argument(\"--world_size\", type=int, default=1, required=False)\n    parser.add_argument(\"--reference_image\", type=str, default=None, required=True)\n    parser.add_argument(\"--motion_sequence\", type=str, default=None, required=True)\n    parser.add_argument(\"--random_seed\", type=int, default=1, required=False)\n    parser.add_argument(\"--step\", type=int, default=25, required=False)\n    parser.add_argument(\"--guidance_scale\", type=float, default=7.5, required=False)\n    parser.add_argument(\"--save_path\", type=str, default=None, required=True)\n    args = parser.parse_args()\n    run(args)"
  },
  {
    "path": "demo/gradio_animate.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport argparse\nimport imageio\nimport numpy as np\nimport gradio as gr\nfrom PIL import Image\n\nfrom demo.animate import MagicAnimate\n\nanimator = MagicAnimate()\n\ndef animate(reference_image, motion_sequence_state, seed, steps, guidance_scale):\n    return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale)\n\nwith gr.Blocks() as demo:\n\n    gr.HTML(\n        \"\"\"\n        <div style=\"display: flex; justify-content: center; align-items: center; text-align: center;\">\n        <a href=\"https://github.com/magic-research/magic-animate\" style=\"margin-right: 20px; text-decoration: none; display: flex; align-items: center;\">\n        </a>\n        <div>\n            <h1 >MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h1>\n            <h5 style=\"margin: 0;\">If you like our project, please give us a star ✨ on Github for the latest update.</h5>\n            <div style=\"display: flex; justify-content: center; align-items: center; text-align: center;\">\n                <a href=\"https://arxiv.org/abs/2311.16498\"><img src=\"https://img.shields.io/badge/Arxiv-2311.16498-red\"></a>\n                <a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>\n                <a href='https://github.com/magic-research/magic-animate'><img src='https://img.shields.io/badge/Github-Code-blue'></a>\n            </div>\n        </div>\n        
</div>\n        \"\"\")\n    animation = gr.Video(format=\"mp4\", label=\"Animation Results\", autoplay=True)\n    \n    with gr.Row():\n        reference_image  = gr.Image(label=\"Reference Image\")\n        motion_sequence  = gr.Video(format=\"mp4\", label=\"Motion Sequence\")\n        \n        with gr.Column():\n            random_seed         = gr.Textbox(label=\"Random seed\", value=1, info=\"default: 1 (-1 picks a random seed)\")\n            sampling_steps      = gr.Textbox(label=\"Sampling steps\", value=25, info=\"default: 25\")\n            guidance_scale      = gr.Textbox(label=\"Guidance scale\", value=7.5, info=\"default: 7.5\")\n            submit              = gr.Button(\"Animate\")\n\n    def read_video(video):\n        reader = imageio.get_reader(video)\n        fps = reader.get_meta_data()['fps']\n        return video\n    \n    def read_image(image, size=512):\n        return np.array(Image.fromarray(image).resize((size, size)))\n    \n    # when user uploads a new video\n    motion_sequence.upload(\n        read_video,\n        motion_sequence,\n        motion_sequence\n    )\n    # when `reference_image` is updated\n    reference_image.upload(\n        read_image,\n        reference_image,\n        reference_image\n    )\n    # when the `submit` button is clicked\n    submit.click(\n        animate,\n        [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], \n        animation\n    )\n\n    # Examples\n    gr.Markdown(\"## Examples\")\n    gr.Examples(\n        examples=[\n            [\"inputs/applications/source_image/monalisa.png\", \"inputs/applications/driving/densepose/running.mp4\"], \n            [\"inputs/applications/source_image/demo4.png\", \"inputs/applications/driving/densepose/demo4.mp4\"],\n            [\"inputs/applications/source_image/dalle2.jpeg\", \"inputs/applications/driving/densepose/running2.mp4\"],\n            [\"inputs/applications/source_image/dalle8.jpeg\", 
\"inputs/applications/driving/densepose/dancing2.mp4\"],\n            [\"inputs/applications/source_image/multi1_source.png\", \"inputs/applications/driving/densepose/multi_dancing.mp4\"],\n        ],\n        inputs=[reference_image, motion_sequence],\n        outputs=animation,\n    )\n\n\ndemo.launch()\n"
  },
  {
    "path": "demo/gradio_animate_dist.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport argparse\nimport imageio\nimport os, datetime\nimport numpy as np\nimport gradio as gr\nfrom PIL import Image\nfrom subprocess import PIPE, run\n\nos.makedirs(\"./demo/tmp\", exist_ok=True)\nsavedir = f\"demo/outputs\"\nos.makedirs(savedir, exist_ok=True)\n\ndef animate(reference_image, motion_sequence, seed, steps, guidance_scale):\n    time_str = datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n    animation_path = f\"{savedir}/{time_str}.mp4\"\n    save_path = \"./demo/tmp/input_reference_image.png\"\n    Image.fromarray(reference_image).save(save_path)\n    command = \"python -m demo.animate_dist --reference_image {} --motion_sequence {} --random_seed {} --step {} --guidance_scale {} --save_path {}\".format(\n        save_path,\n        motion_sequence,\n        seed,\n        steps,\n        guidance_scale,\n        animation_path\n    )\n    run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True)\n    return animation_path\n\nwith gr.Blocks() as demo:\n\n    gr.HTML(\n        \"\"\"\n        <div style=\"display: flex; justify-content: center; align-items: center; text-align: center;\">\n        <a href=\"https://github.com/magic-research/magic-animate\" style=\"margin-right: 20px; text-decoration: none; display: flex; align-items: center;\">\n        </a>\n        <div>\n            <h1 >MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h1>\n            <h5 style=\"margin: 0;\">If you like our project, please 
give us a star ✨ on GitHub for the latest update.</h5>\n            <div style=\"display: flex; justify-content: center; align-items: center; text-align: center;\">\n                <a href=\"https://arxiv.org/abs/2311.16498\"><img src=\"https://img.shields.io/badge/Arxiv-2311.16498-red\"></a>\n                <a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>\n                <a href='https://github.com/magic-research/magic-animate'><img src='https://img.shields.io/badge/Github-Code-blue'></a>\n            </div>\n        </div>\n        </div>\n        \"\"\")\n    animation = gr.Video(format=\"mp4\", label=\"Animation Results\", autoplay=True)\n    \n    with gr.Row():\n        reference_image  = gr.Image(label=\"Reference Image\")\n        motion_sequence  = gr.Video(format=\"mp4\", label=\"Motion Sequence\")\n        \n        with gr.Column():\n            random_seed         = gr.Textbox(label=\"Random seed\", value=1, info=\"default: 1\")\n            sampling_steps      = gr.Textbox(label=\"Sampling steps\", value=25, info=\"default: 25\")\n            guidance_scale      = gr.Textbox(label=\"Guidance scale\", value=7.5, info=\"default: 7.5\")\n            submit              = gr.Button(\"Animate\")\n\n    def read_video(video, size=512):\n        # Resize every frame to size x size and re-encode the clip at 25 fps.\n        size = int(size)\n        reader = imageio.get_reader(video)\n        # fps = reader.get_meta_data()['fps']\n        frames = []\n        for img in reader:\n            frames.append(np.array(Image.fromarray(img).resize((size, size))))\n        reader.close()\n        save_path = \"./demo/tmp/input_motion_sequence.mp4\"\n        imageio.mimwrite(save_path, frames, fps=25)\n        return save_path\n    \n    def read_image(image, size=512):\n        img = np.array(Image.fromarray(image).resize((size, size)))\n        return img\n        \n    # when user uploads a new video\n    motion_sequence.upload(\n        read_video,\n        
motion_sequence,\n        motion_sequence\n    )\n    # when the user uploads a new `reference_image`\n    reference_image.upload(\n        read_image,\n        reference_image,\n        reference_image\n    )\n    # when the `submit` button is clicked\n    submit.click(\n        animate,\n        [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],\n        animation\n    )\n\n    # Examples\n    gr.Markdown(\"## Examples\")\n    gr.Examples(\n        examples=[\n            [\"inputs/applications/source_image/monalisa.png\", \"inputs/applications/driving/densepose/running.mp4\"],\n            [\"inputs/applications/source_image/demo4.png\", \"inputs/applications/driving/densepose/demo4.mp4\"],\n            [\"inputs/applications/source_image/dalle2.jpeg\", \"inputs/applications/driving/densepose/running2.mp4\"],\n            [\"inputs/applications/source_image/dalle8.jpeg\", \"inputs/applications/driving/densepose/dancing2.mp4\"],\n            [\"inputs/applications/source_image/multi1_source.png\", \"inputs/applications/driving/densepose/multi_dancing.mp4\"],\n        ],\n        inputs=[reference_image, motion_sequence],\n        outputs=animation,\n    )\n\n# demo.queue(max_size=10)\ndemo.launch()\n"
  },
  {
    "path": "environment.yaml",
    "content": "name: manimate\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - asttokens=2.2.1=pyhd8ed1ab_0\n  - backcall=0.2.0=pyh9f0ad1d_0\n  - backports=1.0=pyhd8ed1ab_3\n  - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0\n  - ca-certificates=2023.7.22=hbcca054_0\n  - comm=0.1.4=pyhd8ed1ab_0\n  - debugpy=1.6.7=py38h6a678d5_0\n  - decorator=5.1.1=pyhd8ed1ab_0\n  - entrypoints=0.4=pyhd8ed1ab_0\n  - executing=1.2.0=pyhd8ed1ab_0\n  - ipykernel=6.25.1=pyh71e2992_0\n  - ipython=8.12.0=pyh41d4057_0\n  - jedi=0.19.0=pyhd8ed1ab_0\n  - jupyter_client=7.3.4=pyhd8ed1ab_0\n  - jupyter_core=4.12.0=py38h578d9bd_0\n  - ld_impl_linux-64=2.38=h1181459_1\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libsodium=1.0.18=h36c2ea0_1\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - matplotlib-inline=0.1.6=pyhd8ed1ab_0\n  - ncurses=6.4=h6a678d5_0\n  - nest-asyncio=1.5.6=pyhd8ed1ab_0\n  - openssl=1.1.1l=h7f98852_0\n  - packaging=23.1=pyhd8ed1ab_0\n  - parso=0.8.3=pyhd8ed1ab_0\n  - pexpect=4.8.0=pyh1a96a4e_2\n  - pickleshare=0.7.5=py_1003\n  - pip=23.2.1=py38h06a4308_0\n  - prompt-toolkit=3.0.39=pyha770c72_0\n  - prompt_toolkit=3.0.39=hd8ed1ab_0\n  - ptyprocess=0.7.0=pyhd3deb0d_0\n  - pure_eval=0.2.2=pyhd8ed1ab_0\n  - pygments=2.16.1=pyhd8ed1ab_0\n  - python=3.8.5=h7579374_1\n  - python-dateutil=2.8.2=pyhd8ed1ab_0\n  - python_abi=3.8=2_cp38\n  - pyzmq=25.1.0=py38h6a678d5_0\n  - readline=8.2=h5eee18b_0\n  - setuptools=68.0.0=py38h06a4308_0\n  - six=1.16.0=pyh6c4a22f_0\n  - sqlite=3.41.2=h5eee18b_0\n  - stack_data=0.6.2=pyhd8ed1ab_0\n  - tk=8.6.12=h1ccaba5_0\n  - tornado=6.1=py38h0a891b7_3\n  - traitlets=5.9.0=pyhd8ed1ab_0\n  - typing_extensions=4.7.1=pyha770c72_0\n  - wcwidth=0.2.6=pyhd8ed1ab_0\n  - wheel=0.38.4=py38h06a4308_0\n  - xz=5.4.2=h5eee18b_0\n  - zeromq=4.3.4=h9c3ff4c_1\n  - zlib=1.2.13=h5eee18b_0\n  - pip:\n      - absl-py==1.4.0\n      - accelerate==0.22.0\n      - 
aiofiles==23.2.1\n      - aiohttp==3.8.5\n      - aiosignal==1.3.1\n      - altair==5.0.1\n      - annotated-types==0.5.0\n      - antlr4-python3-runtime==4.9.3\n      - anyio==3.7.1\n      - async-timeout==4.0.3\n      - attrs==23.1.0\n      - cachetools==5.3.1\n      - certifi==2023.7.22\n      - charset-normalizer==3.2.0\n      - click==8.1.7\n      - cmake==3.27.2\n      - contourpy==1.1.0\n      - cycler==0.11.0\n      - datasets==2.14.4\n      - dill==0.3.7\n      - einops==0.6.1\n      - exceptiongroup==1.1.3\n      - fastapi==0.103.0\n      - ffmpy==0.3.1\n      - filelock==3.12.2\n      - fonttools==4.42.1\n      - frozenlist==1.4.0\n      - fsspec==2023.6.0\n      - google-auth==2.22.0\n      - google-auth-oauthlib==1.0.0\n      - gradio==3.41.2\n      - gradio-client==0.5.0\n      - grpcio==1.57.0\n      - h11==0.14.0\n      - httpcore==0.17.3\n      - httpx==0.24.1\n      - huggingface-hub==0.16.4\n      - idna==3.4\n      - importlib-metadata==6.8.0\n      - importlib-resources==6.0.1\n      - jinja2==3.1.2\n      - joblib==1.3.2\n      - jsonschema==4.19.0\n      - jsonschema-specifications==2023.7.1\n      - kiwisolver==1.4.5\n      - lightning-utilities==0.9.0\n      - lit==16.0.6\n      - markdown==3.4.4\n      - markupsafe==2.1.3\n      - matplotlib==3.7.2\n      - mpmath==1.3.0\n      - multidict==6.0.4\n      - multiprocess==0.70.15\n      - networkx==3.1\n      - numpy==1.24.4\n      - nvidia-cublas-cu11==11.10.3.66\n      - nvidia-cuda-cupti-cu11==11.7.101\n      - nvidia-cuda-nvrtc-cu11==11.7.99\n      - nvidia-cuda-runtime-cu11==11.7.99\n      - nvidia-cudnn-cu11==8.5.0.96\n      - nvidia-cufft-cu11==10.9.0.58\n      - nvidia-curand-cu11==10.2.10.91\n      - nvidia-cusolver-cu11==11.4.0.1\n      - nvidia-cusparse-cu11==11.7.4.91\n      - nvidia-nccl-cu11==2.14.3\n      - nvidia-nvtx-cu11==11.7.91\n      - oauthlib==3.2.2\n      - omegaconf==2.3.0\n      - opencv-python==4.8.0.76\n      - orjson==3.9.5\n      - pandas==2.0.3\n      - 
pillow==9.5.0\n      - pkgutil-resolve-name==1.3.10\n      - protobuf==4.24.2\n      - psutil==5.9.5\n      - pyarrow==13.0.0\n      - pyasn1==0.5.0\n      - pyasn1-modules==0.3.0\n      - pydantic==2.3.0\n      - pydantic-core==2.6.3\n      - pydub==0.25.1\n      - pyparsing==3.0.9\n      - python-multipart==0.0.6\n      - pytorch-lightning==2.0.7\n      - pytz==2023.3\n      - pyyaml==6.0.1\n      - referencing==0.30.2\n      - regex==2023.8.8\n      - requests==2.31.0\n      - requests-oauthlib==1.3.1\n      - rpds-py==0.9.2\n      - rsa==4.9\n      - safetensors==0.3.3\n      - semantic-version==2.10.0\n      - sniffio==1.3.0\n      - starlette==0.27.0\n      - sympy==1.12\n      - tensorboard==2.14.0\n      - tensorboard-data-server==0.7.1\n      - tokenizers==0.13.3\n      - toolz==0.12.0\n      - torchmetrics==1.1.0\n      - tqdm==4.66.1\n      - transformers==4.32.0\n      - triton==2.0.0\n      - tzdata==2023.3\n      - urllib3==1.26.16\n      - uvicorn==0.23.2\n      - websockets==11.0.3\n      - werkzeug==2.3.7\n      - xxhash==3.3.0\n      - yarl==1.9.2\n      - zipp==3.16.2\n      - decord\n      - imageio==2.9.0\n      - imageio-ffmpeg==0.4.3\n      - timm\n      - scipy\n      - scikit-image\n      - av\n      - imgaug\n      - lpips\n      - ffmpeg-python\n      - torch==2.0.1\n      - torchvision==0.15.2\n      - xformers==0.0.22\n      - diffusers==0.21.4\nprefix: /home/tiger/miniconda3/envs/manimate"
  },
  {
    "path": "magicanimate/models/appearance_encoder.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n)\nfrom diffusers.models.lora import LoRALinearLayer\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    PositionNet,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom 
diffusers.models.unet_2d_blocks import (\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass Identity(torch.nn.Module):\n    r\"\"\"A placeholder identity operator that is argument-insensitive.\n\n    Args:\n        args: any argument (unused)\n        kwargs: any keyword argument (unused)\n\n    Shape:\n        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.\n        - Output: :math:`(*)`, same shape as the input.\n\n    Examples::\n\n        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)\n        >>> input = torch.randn(128, 20)\n        >>> output = m(input)\n        >>> print(output.size())\n        torch.Size([128, 20])\n\n    \"\"\"\n    def __init__(self, scale=None, *args, **kwargs) -> None:\n        super(Identity, self).__init__()\n\n    def forward(self, input, *args, **kwargs):\n        return input\n\n\n\nclass _LoRACompatibleLinear(nn.Module):\n    \"\"\"\n    A Linear layer that can be used with LoRA.\n    \"\"\"\n\n    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.lora_layer = lora_layer\n\n    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):\n        self.lora_layer = lora_layer\n\n    def _fuse_lora(self):\n        pass\n\n    def _unfuse_lora(self):\n        pass\n\n    def forward(self, hidden_states, scale=None, lora_scale: int = 1):\n        return hidden_states\n\n\n@dataclass\nclass UNet2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. 
Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or\n            `UNetMidBlock2DSimpleCrossAttn`. 
If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers are skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". 
\"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. 
Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlock2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: 
Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        attention_type: str = \"default\",\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. 
This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], 
flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To avoid cluttering the __init__,\n            # they are set to `cross_attention_dim` here, as this is exactly the required dimension for the only current use\n            # case, `encoder_hid_dim_type == \"text_image_proj\"` (Kandinsky 2.1)\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. 
it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kandinsky 2.1)\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = 
(num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n     
           dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                attention_type=attention_type,\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                
output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                
transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()\n        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()\n        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()\n        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])\n        self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()\n        self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None\n        self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = 
Identity()\n        self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()\n        self.up_blocks[3].attentions[2].proj_out = Identity()\n\n        if attention_type in [\"gated\", \"gated-text-image\"]:\n            positive_len = 768\n            if isinstance(cross_attention_dim, int):\n                positive_len = cross_attention_dim\n            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):\n                positive_len = cross_attention_dim[0]\n\n            feature_type = \"text-only\" if attention_type == \"gated\" else \"text-image\"\n            self.position_net = PositionNet(\n                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type\n            )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only 
`AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnAddedKVProcessor()\n        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot 
call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * 
[slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be at least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = 
attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the 
config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = 
added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. 
pre-process\n        sample = self.conv_in(sample)\n\n        # 2.5 GLIGEN position net\n        if cross_attention_kwargs is not None and cross_attention_kwargs.get(\"gligen\", None) is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            gligen_args = cross_attention_kwargs.pop(\"gligen\")\n            cross_attention_kwargs[\"gligen\"] = {\"objs\": self.position_net(**gligen_args)}\n\n        # 3. down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = 
()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n            )\n            # To support T2I-Adapter-XL\n            if (\n                is_adapter\n                and len(down_block_additional_residuals) > 0\n                and sample.shape == down_block_additional_residuals[0].shape\n            ):\n                sample += down_block_additional_residuals.pop(0)\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        if not return_dict:\n            return (sample,)\n\n        return UNet2DConditionOutput(sample=sample)"
  },
  {
    "path": "magicanimate/models/attention.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils import BaseOutput\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.models.attention import FeedForward, AdaLayerNorm\nfrom diffusers.models.attention import Attention as CrossAttention\n\nfrom einops import rearrange, repeat\n\n@dataclass\nclass Transformer3DModelOutput(BaseOutput):\n    sample: torch.FloatTensor\n\n\nif is_xformers_available():\n    import xformers\n    import xformers.ops\nelse:\n    xformers = None\n\n\nclass Transformer3DModel(ModelMixin, ConfigMixin):\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: 
float = 0.0,\n        norm_num_groups: int = 32,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n\n        unet_use_cross_frame_attention=None,\n        unet_use_temporal_attention=None,\n    ):\n        super().__init__()\n        self.use_linear_projection = use_linear_projection\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n        inner_dim = num_attention_heads * attention_head_dim\n\n        # Define input layers\n        self.in_channels = in_channels\n\n        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n        if use_linear_projection:\n            self.proj_in = nn.Linear(in_channels, inner_dim)\n        else:\n            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)\n\n        # Define transformers blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    dropout=dropout,\n                    cross_attention_dim=cross_attention_dim,\n                    activation_fn=activation_fn,\n                    num_embeds_ada_norm=num_embeds_ada_norm,\n                    attention_bias=attention_bias,\n                    only_cross_attention=only_cross_attention,\n                    upcast_attention=upcast_attention,\n\n                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                    unet_use_temporal_attention=unet_use_temporal_attention,\n                )\n                for d in range(num_layers)\n  
          ]\n        )\n\n        # Define output layers\n        if use_linear_projection:\n            self.proj_out = nn.Linear(inner_dim, in_channels)\n        else:\n            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):\n        # Input\n        assert hidden_states.dim() == 5, f\"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}.\"\n        video_length = hidden_states.shape[2]\n        hidden_states = rearrange(hidden_states, \"b c f h w -> (b f) c h w\")\n        # JH: no need to repeat when a list of prompts is given\n        if encoder_hidden_states.shape[0] != hidden_states.shape[0]:\n            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)\n\n        batch, channel, height, weight = hidden_states.shape\n        residual = hidden_states\n\n        hidden_states = self.norm(hidden_states)\n        if not self.use_linear_projection:\n            hidden_states = self.proj_in(hidden_states)\n            inner_dim = hidden_states.shape[1]\n            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)\n        else:\n            inner_dim = hidden_states.shape[1]\n            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)\n            hidden_states = self.proj_in(hidden_states)\n\n        # Blocks\n        for block in self.transformer_blocks:\n            hidden_states = block(\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                timestep=timestep,\n                video_length=video_length\n            )\n\n        # Output\n        if not self.use_linear_projection:\n            hidden_states = (\n                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 
2).contiguous()\n            )\n            hidden_states = self.proj_out(hidden_states)\n        else:\n            hidden_states = self.proj_out(hidden_states)\n            hidden_states = (\n                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()\n            )\n\n        output = hidden_states + residual\n\n        output = rearrange(output, \"(b f) c h w -> b c f h w\", f=video_length)\n        if not return_dict:\n            return (output,)\n\n        return Transformer3DModelOutput(sample=output)\n\n\nclass BasicTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout=0.0,\n        cross_attention_dim: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        attention_bias: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n\n        unet_use_cross_frame_attention = None,\n        unet_use_temporal_attention = None,\n    ):\n        super().__init__()\n        self.only_cross_attention = only_cross_attention\n        self.use_ada_layer_norm = num_embeds_ada_norm is not None\n        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention\n        self.unet_use_temporal_attention = unet_use_temporal_attention\n\n        # SC-Attn\n        assert unet_use_cross_frame_attention is not None\n        if unet_use_cross_frame_attention:\n            self.attn1 = SparseCausalAttention2D(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                cross_attention_dim=cross_attention_dim if only_cross_attention else None,\n                upcast_attention=upcast_attention,\n            )\n        else:\n            self.attn1 = 
CrossAttention(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                upcast_attention=upcast_attention,\n            )\n        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n\n        # Cross-Attn\n        if cross_attention_dim is not None:\n            self.attn2 = CrossAttention(\n                query_dim=dim,\n                cross_attention_dim=cross_attention_dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                upcast_attention=upcast_attention,\n            )\n        else:\n            self.attn2 = None\n\n        if cross_attention_dim is not None:\n            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n        else:\n            self.norm2 = None\n\n        # Feed-forward\n        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)\n        self.norm3 = nn.LayerNorm(dim)\n        self.use_ada_layer_norm_zero = False\n        \n        # Temp-Attn\n        assert unet_use_temporal_attention is not None\n        if unet_use_temporal_attention:\n            self.attn_temp = CrossAttention(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                upcast_attention=upcast_attention,\n            )\n            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)\n            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n\n    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, 
*args, **kwargs):\n        if not is_xformers_available():\n            print(\"Here is how to install it\")\n            raise ModuleNotFoundError(\n                \"Refer to https://github.com/facebookresearch/xformers for more information on how to install\"\n                \" xformers\",\n                name=\"xformers\",\n            )\n        elif not torch.cuda.is_available():\n            raise ValueError(\n                \"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only\"\n                \" available for GPU \"\n            )\n        else:\n            try:\n                # Make sure we can run the memory efficient attention\n                _ = xformers.ops.memory_efficient_attention(\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                )\n            except Exception as e:\n                raise e\n            self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers\n            if self.attn2 is not None:\n                self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers\n            # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers\n\n    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):\n        # SparseCausal-Attention\n        norm_hidden_states = (\n            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)\n        )\n\n        # if self.only_cross_attention:\n        #     hidden_states = (\n        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states\n        #     )\n        # else:\n        #     hidden_states = 
self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states\n\n        # pdb.set_trace()\n        if self.unet_use_cross_frame_attention:\n            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states\n        else:\n            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states\n\n        if self.attn2 is not None:\n            # Cross-Attention\n            norm_hidden_states = (\n                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n            )\n            hidden_states = (\n                self.attn2(\n                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask\n                )\n                + hidden_states\n            )\n\n        # Feed-forward\n        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states\n\n        # Temporal-Attention\n        if self.unet_use_temporal_attention:\n            d = hidden_states.shape[1]\n            hidden_states = rearrange(hidden_states, \"(b f) d c -> (b d) f c\", f=video_length)\n            norm_hidden_states = (\n                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)\n            )\n            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states\n            hidden_states = rearrange(hidden_states, \"(b d) f c -> (b f) d c\", d=d)\n\n        return hidden_states\n"
  },
  {
    "path": "magicanimate/models/controlnet.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.utils import BaseOutput, logging\nfrom .embeddings import TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.unet_2d_blocks import (\n    CrossAttnDownBlock2D,\n    DownBlock2D,\n    UNetMidBlock2DCrossAttn,\n    get_down_block,\n)\nfrom diffusers.models.unet_2d_condition import UNet2DConditionModel\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass ControlNetOutput(BaseOutput):\n    down_block_res_samples: Tuple[torch.Tensor]\n    mid_block_res_sample: torch.Tensor\n\n\nclass ControlNetConditioningEmbedding(nn.Module):\n    \"\"\"\n    Quoting from https://arxiv.org/abs/2302.05543: \"Stable Diffusion uses a pre-processing method similar to VQ-GAN\n    
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized\n    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the\n    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides\n    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full\n    model) to encode image-space conditions ... into feature maps ...\"\n    \"\"\"\n\n    def __init__(\n        self,\n        conditioning_embedding_channels: int,\n        conditioning_channels: int = 3,\n        block_out_channels: Tuple[int] = (16, 32, 96, 256),\n    ):\n        super().__init__()\n\n        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)\n\n        self.blocks = nn.ModuleList([])\n\n        for i in range(len(block_out_channels) - 1):\n            channel_in = block_out_channels[i]\n            channel_out = block_out_channels[i + 1]\n            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))\n            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))\n\n        self.conv_out = zero_module(\n            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)\n        )\n\n    def forward(self, conditioning):\n        embedding = self.conv_in(conditioning)\n        embedding = F.silu(embedding)\n\n        for block in self.blocks:\n            embedding = block(embedding)\n            embedding = F.silu(embedding)\n\n        embedding = self.conv_out(embedding)\n\n        return embedding\n\n\nclass ControlNetModel(ModelMixin, ConfigMixin):\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        in_channels: int = 4,\n        flip_sin_to_cos: bool = True,\n        freq_shift: 
int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),\n    ):\n        super().__init__()\n\n        # Check inputs\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_kernel = 3\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        time_embed_dim = block_out_channels[0] * 4\n\n        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n        timestep_input_dim = block_out_channels[0]\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n        )\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. 
it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        # control net conditioning embedding\n        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(\n            conditioning_embedding_channels=block_out_channels[0],\n            block_out_channels=conditioning_embedding_out_channels,\n        )\n\n        self.down_blocks = nn.ModuleList([])\n        self.controlnet_down_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        # down\n        output_channel = block_out_channels[0]\n\n        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n        controlnet_block = zero_module(controlnet_block)\n        self.controlnet_down_blocks.append(controlnet_block)\n\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                add_downsample=not is_final_block,\n            
    resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim,\n                num_attention_heads=attention_head_dim[i],\n                downsample_padding=downsample_padding,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n            )\n            self.down_blocks.append(down_block)\n\n            for _ in range(layers_per_block):\n                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n            if not is_final_block:\n                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n        # mid\n        mid_block_channel = block_out_channels[-1]\n\n        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)\n        controlnet_block = zero_module(controlnet_block)\n        self.controlnet_mid_block = controlnet_block\n\n        self.mid_block = UNetMidBlock2DCrossAttn(\n            in_channels=mid_block_channel,\n            temb_channels=time_embed_dim,\n            resnet_eps=norm_eps,\n            resnet_act_fn=act_fn,\n            output_scale_factor=mid_block_scale_factor,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=attention_head_dim[-1],\n            resnet_groups=norm_num_groups,\n            use_linear_projection=use_linear_projection,\n            
upcast_attention=upcast_attention,\n        )\n\n    @classmethod\n    def from_unet(\n        cls,\n        unet: UNet2DConditionModel,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),\n        load_weights_from_unet: bool = True,\n    ):\n        r\"\"\"\n        Instantiate Controlnet class from UNet2DConditionModel.\n\n        Parameters:\n            unet (`UNet2DConditionModel`):\n                UNet model which weights are copied to the ControlNet. Note that all configuration options are also\n                copied where applicable.\n        \"\"\"\n        controlnet = cls(\n            in_channels=unet.config.in_channels,\n            flip_sin_to_cos=unet.config.flip_sin_to_cos,\n            freq_shift=unet.config.freq_shift,\n            down_block_types=unet.config.down_block_types,\n            only_cross_attention=unet.config.only_cross_attention,\n            block_out_channels=unet.config.block_out_channels,\n            layers_per_block=unet.config.layers_per_block,\n            downsample_padding=unet.config.downsample_padding,\n            mid_block_scale_factor=unet.config.mid_block_scale_factor,\n            act_fn=unet.config.act_fn,\n            norm_num_groups=unet.config.norm_num_groups,\n            norm_eps=unet.config.norm_eps,\n            cross_attention_dim=unet.config.cross_attention_dim,\n            attention_head_dim=unet.config.attention_head_dim,\n            use_linear_projection=unet.config.use_linear_projection,\n            class_embed_type=unet.config.class_embed_type,\n            num_class_embeds=unet.config.num_class_embeds,\n            upcast_attention=unet.config.upcast_attention,\n            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,\n            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,\n            
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,\n            conditioning_embedding_out_channels=conditioning_embedding_out_channels,\n        )\n\n        if load_weights_from_unet:\n            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())\n            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())\n            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())\n\n            if controlnet.class_embedding:\n                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())\n\n            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())\n            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())\n\n        return controlnet\n\n    # @property\n    # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors\n    # def attn_processors(self) -> Dict[str, AttentionProcessor]:\n    #     r\"\"\"\n    #     Returns:\n    #         `dict` of attention processors: A dictionary containing all attention processors used in the model with\n    #         indexed by its weight name.\n    #     \"\"\"\n    #     # set recursively\n    #     processors = {}\n\n    #     def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n    #         if hasattr(module, \"set_processor\"):\n    #             processors[f\"{name}.processor\"] = module.processor\n\n    #         for sub_name, child in module.named_children():\n    #             fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n    #         return processors\n\n    #     for name, module in self.named_children():\n    #         fn_recursive_add_processors(name, module, processors)\n\n    #     return processors\n\n    # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    # def set_attn_processor(self, processor: 
Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n    #     r\"\"\"\n    #     Parameters:\n    #         `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):\n    #             The instantiated processor class or a dictionary of processor classes that will be set as the processor\n    #             of **all** `Attention` layers.\n    #         In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:\n\n    #     \"\"\"\n    #     count = len(self.attn_processors.keys())\n\n    #     if isinstance(processor, dict) and len(processor) != count:\n    #         raise ValueError(\n    #             f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n    #             f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n    #         )\n\n    #     def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n    #         if hasattr(module, \"set_processor\"):\n    #             if not isinstance(processor, dict):\n    #                 module.set_processor(processor)\n    #             else:\n    #                 module.set_processor(processor.pop(f\"{name}.processor\"))\n\n    #         for sub_name, child in module.named_children():\n    #             fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n    #     for name, module in self.named_children():\n    #         fn_recursive_attn_processor(name, module, processor)\n\n    # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor\n    # def set_default_attn_processor(self):\n    #     \"\"\"\n    #     Disables custom attention processors and sets the default attention implementation.\n    #     \"\"\"\n    #     self.set_attn_processor(AttnProcessor())\n\n    # Copied from 
diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module will split the input tensor in slices, to compute attention\n        in several steps. This is useful to save some memory in exchange for a small speed decrease.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, halves the input to the attention heads, so attention will be computed in two steps. If\n                `\"max\"`, maximum amount of memory will be saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) 
!= len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        controlnet_cond: torch.FloatTensor,\n        conditioning_scale: float = 1.0,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        return_dict: bool = True,\n    ) -> Union[ControlNetOutput, Tuple]:\n        # check channel 
order\n        channel_order = self.config.controlnet_conditioning_channel_order\n\n        if channel_order == \"rgb\":\n            # in rgb order by default\n            ...\n        elif channel_order == \"bgr\":\n            controlnet_cond = torch.flip(controlnet_cond, dims=[1])\n        else:\n            raise ValueError(f\"unknown `controlnet_conditioning_channel_order`: {channel_order}\")\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=self.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)\n\n        sample += controlnet_cond\n\n        # 3. down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    # cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n            down_block_res_samples += res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                # cross_attention_kwargs=cross_attention_kwargs,\n            )\n\n        # 5. 
Control net blocks\n\n        controlnet_down_block_res_samples = ()\n\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples += (down_block_res_sample,)\n\n        down_block_res_samples = controlnet_down_block_res_samples\n\n        mid_block_res_sample = self.controlnet_mid_block(sample)\n\n        # 6. scaling\n        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]\n        mid_block_res_sample *= conditioning_scale\n\n        if not return_dict:\n            return (down_block_res_samples, mid_block_res_sample)\n\n        return ControlNetOutput(\n            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample\n        )\n\n\ndef zero_module(module):\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module"
  },
  {
    "path": "magicanimate/models/embeddings.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\n\ndef get_timestep_embedding(\n    timesteps: torch.Tensor,\n    embedding_dim: int,\n    flip_sin_to_cos: bool = False,\n    downscale_freq_shift: float = 1,\n    scale: float = 1,\n    max_period: int = 10000,\n):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.\n\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the\n    embeddings. 
:return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    assert len(timesteps.shape) == 1, \"Timesteps should be a 1d-array\"\n\n    half_dim = embedding_dim // 2\n    exponent = -math.log(max_period) * torch.arange(\n        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device\n    )\n    exponent = exponent / (half_dim - downscale_freq_shift)\n\n    emb = torch.exp(exponent)\n    emb = timesteps[:, None].float() * emb[None, :]\n\n    # scale embeddings\n    emb = scale * emb\n\n    # concat sine and cosine embeddings\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)\n\n    # flip sine and cosine embeddings\n    if flip_sin_to_cos:\n        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)\n\n    # zero pad\n    if embedding_dim % 2 == 1:\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return emb\n\n\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):\n    \"\"\"\n    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or\n    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = np.arange(grid_size, dtype=np.float32)\n    grid_w = np.arange(grid_size, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size, grid_size])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token and extra_tokens > 0:\n        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    if embed_dim % 2 != 0:\n        raise ValueError(\"embed_dim must be divisible by 2\")\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, 
grid[1])  # (H*W, D/2)\n\n    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)\n    \"\"\"\n    if embed_dim % 2 != 0:\n        raise ValueError(\"embed_dim must be divisible by 2\")\n\n    omega = np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"2D Image to Patch Embedding\"\"\"\n\n    def __init__(\n        self,\n        height=224,\n        width=224,\n        patch_size=16,\n        in_channels=3,\n        embed_dim=768,\n        layer_norm=False,\n        flatten=True,\n        bias=True,\n    ):\n        super().__init__()\n\n        num_patches = (height // patch_size) * (width // patch_size)\n        self.flatten = flatten\n        self.layer_norm = layer_norm\n\n        self.proj = nn.Conv2d(\n            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias\n        )\n        if layer_norm:\n            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)\n        else:\n            self.norm = None\n\n        pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))\n        self.register_buffer(\"pos_embed\", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)\n\n    def forward(self, latent):\n        latent = self.proj(latent)\n        if self.flatten:\n            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        if self.layer_norm:\n            latent = 
self.norm(latent)\n        return latent + self.pos_embed\n\n\nclass TimestepEmbedding(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        time_embed_dim: int,\n        act_fn: str = \"silu\",\n        out_dim: int = None,\n        post_act_fn: Optional[str] = None,\n        cond_proj_dim=None,\n    ):\n        super().__init__()\n\n        self.linear_1 = nn.Linear(in_channels, time_embed_dim)\n\n        if cond_proj_dim is not None:\n            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)\n        else:\n            self.cond_proj = None\n\n        if act_fn == \"silu\":\n            self.act = nn.SiLU()\n        elif act_fn == \"mish\":\n            self.act = nn.Mish()\n        elif act_fn == \"gelu\":\n            self.act = nn.GELU()\n        else:\n            raise ValueError(f\"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'\")\n\n        if out_dim is not None:\n            time_embed_dim_out = out_dim\n        else:\n            time_embed_dim_out = time_embed_dim\n        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)\n\n        if post_act_fn is None:\n            self.post_act = None\n        elif post_act_fn == \"silu\":\n            self.post_act = nn.SiLU()\n        elif post_act_fn == \"mish\":\n            self.post_act = nn.Mish()\n        elif post_act_fn == \"gelu\":\n            self.post_act = nn.GELU()\n        else:\n            raise ValueError(f\"{post_act_fn} does not exist. 
Make sure to define one of 'silu', 'mish', or 'gelu'\")\n\n    def forward(self, sample, condition=None):\n        if condition is not None:\n            sample = sample + self.cond_proj(condition)\n        sample = self.linear_1(sample)\n\n        if self.act is not None:\n            sample = self.act(sample)\n\n        sample = self.linear_2(sample)\n\n        if self.post_act is not None:\n            sample = self.post_act(sample)\n        return sample\n\n\nclass Timesteps(nn.Module):\n    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):\n        super().__init__()\n        self.num_channels = num_channels\n        self.flip_sin_to_cos = flip_sin_to_cos\n        self.downscale_freq_shift = downscale_freq_shift\n\n    def forward(self, timesteps):\n        t_emb = get_timestep_embedding(\n            timesteps,\n            self.num_channels,\n            flip_sin_to_cos=self.flip_sin_to_cos,\n            downscale_freq_shift=self.downscale_freq_shift,\n        )\n        return t_emb\n\n\nclass GaussianFourierProjection(nn.Module):\n    \"\"\"Gaussian Fourier embeddings for noise levels.\"\"\"\n\n    def __init__(\n        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False\n    ):\n        super().__init__()\n        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)\n        self.log = log\n        self.flip_sin_to_cos = flip_sin_to_cos\n\n        if set_W_to_weight:\n            # to delete later\n            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)\n\n            self.weight = self.W\n\n    def forward(self, x):\n        if self.log:\n            x = torch.log(x)\n\n        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi\n\n        if self.flip_sin_to_cos:\n            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)\n        else:\n            out = 
torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)\n        return out\n\n\nclass ImagePositionalEmbeddings(nn.Module):\n    \"\"\"\n    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the\n    height and width of the latent space.\n\n    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092\n\n    For VQ-diffusion:\n\n    Output vector embeddings are used as input for the transformer.\n\n    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.\n\n    Args:\n        num_embed (`int`):\n            Number of embeddings for the latent pixels embeddings.\n        height (`int`):\n            Height of the latent image i.e. the number of height embeddings.\n        width (`int`):\n            Width of the latent image i.e. the number of width embeddings.\n        embed_dim (`int`):\n            Dimension of the produced vector embeddings. 
Used for the latent pixel, height, and width embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_embed: int,\n        height: int,\n        width: int,\n        embed_dim: int,\n    ):\n        super().__init__()\n\n        self.height = height\n        self.width = width\n        self.num_embed = num_embed\n        self.embed_dim = embed_dim\n\n        self.emb = nn.Embedding(self.num_embed, embed_dim)\n        self.height_emb = nn.Embedding(self.height, embed_dim)\n        self.width_emb = nn.Embedding(self.width, embed_dim)\n\n    def forward(self, index):\n        emb = self.emb(index)\n\n        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))\n\n        # 1 x H x D -> 1 x H x 1 x D\n        height_emb = height_emb.unsqueeze(2)\n\n        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))\n\n        # 1 x W x D -> 1 x 1 x W x D\n        width_emb = width_emb.unsqueeze(1)\n\n        pos_emb = height_emb + width_emb\n\n        # 1 x H x W x D -> 1 x L xD\n        pos_emb = pos_emb.view(1, self.height * self.width, -1)\n\n        emb = emb + pos_emb[:, : emb.shape[1], :]\n\n        return emb\n\n\nclass LabelEmbedding(nn.Module):\n    \"\"\"\n    Embeds class labels into vector representations. 
Also handles label dropout for classifier-free guidance.\n\n    Args:\n        num_classes (`int`): The number of classes.\n        hidden_size (`int`): The size of the vector embeddings.\n        dropout_prob (`float`): The probability of dropping a label.\n    \"\"\"\n\n    def __init__(self, num_classes, hidden_size, dropout_prob):\n        super().__init__()\n        use_cfg_embedding = dropout_prob > 0\n        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)\n        self.num_classes = num_classes\n        self.dropout_prob = dropout_prob\n\n    def token_drop(self, labels, force_drop_ids=None):\n        \"\"\"\n        Drops labels to enable classifier-free guidance.\n        \"\"\"\n        if force_drop_ids is None:\n            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob\n        else:\n            drop_ids = torch.tensor(force_drop_ids == 1)\n        labels = torch.where(drop_ids, self.num_classes, labels)\n        return labels\n\n    def forward(self, labels, force_drop_ids=None):\n        use_dropout = self.dropout_prob > 0\n        if (self.training and use_dropout) or (force_drop_ids is not None):\n            labels = self.token_drop(labels, force_drop_ids)\n        embeddings = self.embedding_table(labels)\n        return embeddings\n\n\nclass CombinedTimestepLabelEmbeddings(nn.Module):\n    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):\n        super().__init__()\n\n        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)\n        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)\n        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)\n\n    def forward(self, timestep, class_labels, hidden_dtype=None):\n        timesteps_proj = self.time_proj(timestep)\n        timesteps_emb = 
self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)\n\n        class_labels = self.class_embedder(class_labels)  # (N, D)\n\n        conditioning = timesteps_emb + class_labels  # (N, D)\n\n        return conditioning"
  },
  {
    "path": "magicanimate/models/motion_module.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/guoyww/AnimateDiff\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.utils import BaseOutput\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.models.attention import FeedForward\nfrom magicanimate.models.orig_attention import CrossAttention\n\nfrom einops import rearrange, repeat\nimport math\n\n\ndef zero_module(module):\n    # Zero out the parameters of a module and return it.\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\n@dataclass\nclass TemporalTransformer3DModelOutput(BaseOutput):\n    sample: torch.FloatTensor\n\n\nif is_xformers_available():\n    import xformers\n    import xformers.ops\nelse:\n    xformers = None\n\n\ndef get_motion_module(\n    in_channels,\n    motion_module_type: str, \n    motion_module_kwargs: dict\n):\n    if motion_module_type == \"Vanilla\":\n        return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)    \n    else:\n        raise ValueError\n\n\nclass VanillaTemporalModule(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        num_attention_heads                = 8,\n        num_transformer_block              = 2,\n        attention_block_types              =( \"Temporal_Self\", \"Temporal_Self\" ),\n        cross_frame_attention_mode         = None,\n        temporal_position_encoding         = False,\n        temporal_position_encoding_max_len = 24,\n        temporal_attention_dim_div         = 1,\n        zero_initialize                    = True,\n    ):\n        
super().__init__()\n        \n        self.temporal_transformer = TemporalTransformer3DModel(\n            in_channels=in_channels,\n            num_attention_heads=num_attention_heads,\n            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,\n            num_layers=num_transformer_block,\n            attention_block_types=attention_block_types,\n            cross_frame_attention_mode=cross_frame_attention_mode,\n            temporal_position_encoding=temporal_position_encoding,\n            temporal_position_encoding_max_len=temporal_position_encoding_max_len,\n        )\n        \n        if zero_initialize:\n            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)\n\n    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):\n        hidden_states = input_tensor\n        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)\n\n        output = hidden_states\n        return output\n\n\nclass TemporalTransformer3DModel(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        num_attention_heads,\n        attention_head_dim,\n\n        num_layers,\n        attention_block_types              = ( \"Temporal_Self\", \"Temporal_Self\", ),        \n        dropout                            = 0.0,\n        norm_num_groups                    = 32,\n        cross_attention_dim                = 768,\n        activation_fn                      = \"geglu\",\n        attention_bias                     = False,\n        upcast_attention                   = False,\n        \n        cross_frame_attention_mode         = None,\n        temporal_position_encoding         = False,\n        temporal_position_encoding_max_len = 24,\n    ):\n        super().__init__()\n\n        inner_dim = num_attention_heads * attention_head_dim\n\n        self.norm = 
torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n        self.proj_in = nn.Linear(in_channels, inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                TemporalTransformerBlock(\n                    dim=inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    attention_block_types=attention_block_types,\n                    dropout=dropout,\n                    norm_num_groups=norm_num_groups,\n                    cross_attention_dim=cross_attention_dim,\n                    activation_fn=activation_fn,\n                    attention_bias=attention_bias,\n                    upcast_attention=upcast_attention,\n                    cross_frame_attention_mode=cross_frame_attention_mode,\n                    temporal_position_encoding=temporal_position_encoding,\n                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,\n                )\n                for d in range(num_layers)\n            ]\n        )\n        self.proj_out = nn.Linear(inner_dim, in_channels)    \n    \n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):\n        assert hidden_states.dim() == 5, f\"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}.\"\n        video_length = hidden_states.shape[2]\n        hidden_states = rearrange(hidden_states, \"b c f h w -> (b f) c h w\")\n\n        batch, channel, height, weight = hidden_states.shape\n        residual = hidden_states\n\n        hidden_states = self.norm(hidden_states)\n        inner_dim = hidden_states.shape[1]\n        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)\n        hidden_states = self.proj_in(hidden_states)\n\n        # Transformer Blocks\n        for block in self.transformer_blocks:\n            hidden_states = 
block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)\n        \n        # output\n        hidden_states = self.proj_out(hidden_states)\n        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()\n\n        output = hidden_states + residual\n        output = rearrange(output, \"(b f) c h w -> b c f h w\", f=video_length)\n        \n        return output\n\n\nclass TemporalTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_attention_heads,\n        attention_head_dim,\n        attention_block_types              = ( \"Temporal_Self\", \"Temporal_Self\", ),\n        dropout                            = 0.0,\n        norm_num_groups                    = 32,\n        cross_attention_dim                = 768,\n        activation_fn                      = \"geglu\",\n        attention_bias                     = False,\n        upcast_attention                   = False,\n        cross_frame_attention_mode         = None,\n        temporal_position_encoding         = False,\n        temporal_position_encoding_max_len = 24,\n    ):\n        super().__init__()\n\n        attention_blocks = []\n        norms = []\n        \n        for block_name in attention_block_types:\n            attention_blocks.append(\n                VersatileAttention(\n                    attention_mode=block_name.split(\"_\")[0],\n                    cross_attention_dim=cross_attention_dim if block_name.endswith(\"_Cross\") else None,\n                    \n                    query_dim=dim,\n                    heads=num_attention_heads,\n                    dim_head=attention_head_dim,\n                    dropout=dropout,\n                    bias=attention_bias,\n                    upcast_attention=upcast_attention,\n        \n                    cross_frame_attention_mode=cross_frame_attention_mode,\n                    
temporal_position_encoding=temporal_position_encoding,\n                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,\n                )\n            )\n            norms.append(nn.LayerNorm(dim))\n            \n        self.attention_blocks = nn.ModuleList(attention_blocks)\n        self.norms = nn.ModuleList(norms)\n\n        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)\n        self.ff_norm = nn.LayerNorm(dim)\n\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):\n        for attention_block, norm in zip(self.attention_blocks, self.norms):\n            norm_hidden_states = norm(hidden_states)\n            hidden_states = attention_block(\n                norm_hidden_states,\n                encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,\n                video_length=video_length,\n            ) + hidden_states\n            \n        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states\n        \n        output = hidden_states  \n        return output\n\n\nclass PositionalEncoding(nn.Module):\n    def __init__(\n        self, \n        d_model, \n        dropout = 0., \n        max_len = 24\n    ):\n        super().__init__()\n        self.dropout = nn.Dropout(p=dropout)\n        position = torch.arange(max_len).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n        pe = torch.zeros(1, max_len, d_model)\n        pe[0, :, 0::2] = torch.sin(position * div_term)\n        pe[0, :, 1::2] = torch.cos(position * div_term)\n        self.register_buffer('pe', pe)\n\n    def forward(self, x):\n        x = x + self.pe[:, :x.size(1)]\n        return self.dropout(x)\n\n\nclass VersatileAttention(CrossAttention):\n    def __init__(\n            self,\n            attention_mode                     = None,\n            cross_frame_attention_mode   
      = None,\n            temporal_position_encoding         = False,\n            temporal_position_encoding_max_len = 24,            \n            *args, **kwargs\n        ):\n        super().__init__(*args, **kwargs)\n        assert attention_mode == \"Temporal\"\n\n        self.attention_mode = attention_mode\n        self.is_cross_attention = kwargs[\"cross_attention_dim\"] is not None\n        \n        self.pos_encoder = PositionalEncoding(\n            kwargs[\"query_dim\"],\n            dropout=0., \n            max_len=temporal_position_encoding_max_len\n        ) if (temporal_position_encoding and attention_mode == \"Temporal\") else None\n\n    def extra_repr(self):\n        return f\"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}\"\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        if self.attention_mode == \"Temporal\":\n            d = hidden_states.shape[1]\n            hidden_states = rearrange(hidden_states, \"(b f) d c -> (b d) f c\", f=video_length)\n            \n            if self.pos_encoder is not None:\n                hidden_states = self.pos_encoder(hidden_states)\n            \n            encoder_hidden_states = repeat(encoder_hidden_states, \"b n c -> (b d) n c\", d=d) if encoder_hidden_states is not None else encoder_hidden_states\n        else:\n            raise NotImplementedError\n\n        encoder_hidden_states = encoder_hidden_states\n\n        if self.group_norm is not None:\n            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = self.to_q(hidden_states)\n        dim = query.shape[-1]\n        query = self.reshape_heads_to_batch_dim(query)\n\n        if self.added_kv_proj_dim is not None:\n            raise NotImplementedError\n\n        encoder_hidden_states = encoder_hidden_states if 
encoder_hidden_states is not None else hidden_states\n        key = self.to_k(encoder_hidden_states)\n        value = self.to_v(encoder_hidden_states)\n\n        key = self.reshape_heads_to_batch_dim(key)\n        value = self.reshape_heads_to_batch_dim(value)\n\n        if attention_mask is not None:\n            if attention_mask.shape[-1] != query.shape[1]:\n                target_length = query.shape[1]\n                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)\n\n        # attention, what we cannot get enough of\n        if self._use_memory_efficient_attention_xformers:\n            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)\n            # Some versions of xformers return output in fp32, cast it back to the dtype of the input\n            hidden_states = hidden_states.to(query.dtype)\n        else:\n            if self._slice_size is None or query.shape[0] // self._slice_size == 1:\n                hidden_states = self._attention(query, key, value, attention_mask)\n            else:\n                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)\n\n        # linear proj\n        hidden_states = self.to_out[0](hidden_states)\n\n        # dropout\n        hidden_states = self.to_out[1](hidden_states)\n\n        if self.attention_mode == \"Temporal\":\n            hidden_states = rearrange(hidden_states, \"(b d) f c -> (b f) d c\", d=d)\n\n        return hidden_states\n"
  },
  {
    "path": "magicanimate/models/mutual_self_attention.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock\nfrom diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D\nfrom .stable_diffusion_controlnet_reference import torch_dfs\n\n\nclass AttentionBase:\n    def __init__(self):\n        self.cur_step = 0\n        self.num_att_layers = -1\n        self.cur_att_layer = 0\n\n    def after_step(self):\n        pass\n\n    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n        self.cur_att_layer += 1\n        if self.cur_att_layer == self.num_att_layers:\n            self.cur_att_layer = 0\n            self.cur_step += 1\n            # after step\n            self.after_step()\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        out = torch.einsum('b i j, b j d -> b i d', attn, v)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)\n        return out\n\n    def reset(self):\n        self.cur_step = 0\n        self.cur_att_layer = 0\n\n\nclass MutualSelfAttentionControl(AttentionBase):\n\n    def __init__(self, total_steps=50, 
hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'):\n        \"\"\"\n        Mutual self-attention control for Stable Diffusion models\n        Args:\n            total_steps: the total number of steps\n        \"\"\"\n        super().__init__()\n        self.total_steps = total_steps\n        self.hijack = hijack_init_state\n        self.with_negative_guidance = with_negative_guidance\n        \n        # alpha: mutual self attention intensity\n        # TODO: make alpha learnable\n        self.alpha = appearance_control_alpha\n        self.GLOBAL_ATTN_QUEUE = []\n        assert mode in ['enqueue', 'dequeue']\n        # store the mode on the instance; mutual_self_attn_wq reads self.MODE\n        self.MODE = mode\n    \n    def attn_batch(self, q, k, v, num_heads, **kwargs):\n        \"\"\"\n        Perform attention for a batch of queries, keys, and values\n        \"\"\"\n        b = q.shape[0] // num_heads\n        q = rearrange(q, \"(b h) n d -> h (b n) d\", h=num_heads)\n        k = rearrange(k, \"(b h) n d -> h (b n) d\", h=num_heads)\n        v = rearrange(v, \"(b h) n d -> h (b n) d\", h=num_heads)\n\n        sim = torch.einsum(\"h i d, h j d -> h i j\", q, k) * kwargs.get(\"scale\")\n        attn = sim.softmax(-1)\n        out = torch.einsum(\"h i j, h j d -> h i d\", attn, v)\n        out = rearrange(out, \"h (b n) d -> b n (h d)\", b=b)\n        return out\n\n    def mutual_self_attn(self, q, k, v, num_heads, **kwargs):\n        q_tgt, q_src = q.chunk(2)\n        k_tgt, k_src = k.chunk(2)\n        v_tgt, v_src = v.chunk(2)\n        \n        # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \\\n        #           self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha)\n        out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs)\n        out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs)\n        out = torch.cat([out_tgt, out_src], dim=0)\n        return out\n    \n    def mutual_self_attn_wq(
self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        # the key/value queue lives in self.GLOBAL_ATTN_QUEUE (see get_queue / set_queue)\n        if self.MODE == 'dequeue' and len(self.GLOBAL_ATTN_QUEUE) > 0:\n            k_src, v_src = self.GLOBAL_ATTN_QUEUE.pop(0)\n            out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs)\n            return out\n        else:\n            self.GLOBAL_ATTN_QUEUE.append([k.clone(), v.clone()])\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n    \n    def get_queue(self):\n        return self.GLOBAL_ATTN_QUEUE\n    \n    def set_queue(self, attn_queue):\n        self.GLOBAL_ATTN_QUEUE = attn_queue\n    \n    def clear_queue(self):\n        self.GLOBAL_ATTN_QUEUE = []\n    \n    def to(self, dtype):\n        # queue entries are [k, v] pairs, so cast each tensor inside every pair\n        self.GLOBAL_ATTN_QUEUE = [[t.to(dtype) for t in p] for p in self.GLOBAL_ATTN_QUEUE]\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Attention forward function\n        \"\"\"\n        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n\nclass ReferenceAttentionControl():\n    \n    def __init__(self, \n                 unet,\n                 mode=\"write\",\n                 do_classifier_free_guidance=False,\n                 attention_auto_machine_weight = float('inf'),\n                 gn_auto_machine_weight = 1.0,\n                 style_fidelity = 1.0,\n                 reference_attn=True,\n                 reference_adain=False,\n                 fusion_blocks=\"midup\",\n                 batch_size=1, \n                 ) -> None:\n        # 10. 
Modify self attention and group norm\n        self.unet = unet\n        assert mode in [\"read\", \"write\"]\n        assert fusion_blocks in [\"midup\", \"full\"]\n        self.reference_attn = reference_attn\n        self.reference_adain = reference_adain\n        self.fusion_blocks = fusion_blocks\n        self.register_reference_hooks(\n            mode, \n            do_classifier_free_guidance,\n            attention_auto_machine_weight,\n            gn_auto_machine_weight,\n            style_fidelity,\n            reference_attn,\n            reference_adain,\n            # pass by keyword: the eighth positional parameter of register_reference_hooks is dtype\n            fusion_blocks=fusion_blocks,\n            batch_size=batch_size, \n        )\n\n    def register_reference_hooks(\n            self, \n            mode, \n            do_classifier_free_guidance,\n            attention_auto_machine_weight,\n            gn_auto_machine_weight,\n            style_fidelity,\n            reference_attn,\n            reference_adain,\n            dtype=torch.float16,\n            batch_size=1, \n            num_images_per_prompt=1, \n            device=torch.device(\"cpu\"), \n            fusion_blocks='midup',\n        ):\n        MODE = mode\n        do_classifier_free_guidance = do_classifier_free_guidance\n        attention_auto_machine_weight = attention_auto_machine_weight\n        gn_auto_machine_weight = gn_auto_machine_weight\n        style_fidelity = style_fidelity\n        reference_attn = reference_attn\n        reference_adain = reference_adain\n        fusion_blocks = fusion_blocks\n        num_images_per_prompt = num_images_per_prompt\n        dtype=dtype\n        if do_classifier_free_guidance:\n            uc_mask = (\n                torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)\n                .to(device)\n                .bool()\n            )\n        else:\n            uc_mask = (\n                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)\n                .to(device)\n        
        .bool()\n            )\n        \n        def hacked_basic_transformer_inner_forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.FloatTensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            cross_attention_kwargs: Dict[str, Any] = None,\n            class_labels: Optional[torch.LongTensor] = None,\n            video_length=None,\n        ):\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm1(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero:\n                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n                )\n            else:\n                norm_hidden_states = self.norm1(hidden_states)\n\n            # 1. 
Self-Attention\n            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n            if self.only_cross_attention:\n                attn_output = self.attn1(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                    attention_mask=attention_mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                if MODE == \"write\":\n                    self.bank.append(norm_hidden_states.clone())\n                    attn_output = self.attn1(\n                        norm_hidden_states,\n                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                        attention_mask=attention_mask,\n                        **cross_attention_kwargs,\n                    )\n                if MODE == \"read\":\n                    self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), \"b t l c -> (b t) l c\")[:hidden_states.shape[0]] for d in self.bank]\n                    hidden_states_uc = self.attn1(norm_hidden_states, \n                                                encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),\n                                                attention_mask=attention_mask) + hidden_states\n                    hidden_states_c = hidden_states_uc.clone()\n                    _uc_mask = uc_mask.clone()\n                    if do_classifier_free_guidance:\n                        if hidden_states.shape[0] != _uc_mask.shape[0]:\n                            _uc_mask = (\n                                torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))\n                                .to(device)\n                                .bool()\n                            )\n                        hidden_states_c[_uc_mask] = self.attn1(\n   
                         norm_hidden_states[_uc_mask],\n                            encoder_hidden_states=norm_hidden_states[_uc_mask],\n                            attention_mask=attention_mask,\n                        ) + hidden_states[_uc_mask]\n                    hidden_states = hidden_states_c.clone()\n                        \n                    self.bank.clear()\n                    if self.attn2 is not None:\n                        # Cross-Attention\n                        norm_hidden_states = (\n                            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                        )\n                        hidden_states = (\n                            self.attn2(\n                                norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask\n                            )\n                            + hidden_states\n                        )\n\n                    # Feed-forward\n                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states\n\n                    # Temporal-Attention\n                    if self.unet_use_temporal_attention:\n                        d = hidden_states.shape[1]\n                        hidden_states = rearrange(hidden_states, \"(b f) d c -> (b d) f c\", f=video_length)\n                        norm_hidden_states = (\n                            self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)\n                        )\n                        hidden_states = self.attn_temp(norm_hidden_states) + hidden_states\n                        hidden_states = rearrange(hidden_states, \"(b d) f c -> (b f) d c\", d=d)\n\n                    return hidden_states\n                \n            if self.use_ada_layer_norm_zero:\n                attn_output = gate_msa.unsqueeze(1) * attn_output\n            hidden_states = attn_output + 
hidden_states\n\n            if self.attn2 is not None:\n                norm_hidden_states = (\n                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                )\n\n                # 2. Cross-Attention\n                attn_output = self.attn2(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=encoder_attention_mask,\n                    **cross_attention_kwargs,\n                )\n                hidden_states = attn_output + hidden_states\n\n            # 3. Feed-forward\n            norm_hidden_states = self.norm3(hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n            ff_output = self.ff(norm_hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n            hidden_states = ff_output + hidden_states\n\n            return hidden_states\n\n        def hacked_mid_forward(self, *args, **kwargs):\n            eps = 1e-6\n            x = self.original_forward(*args, **kwargs)\n            if MODE == \"write\":\n                if gn_auto_machine_weight >= self.gn_weight:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    self.mean_bank.append(mean)\n                    self.var_bank.append(var)\n            if MODE == \"read\":\n                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))\n                    var_acc = sum(self.var_bank) / float(len(self.var_bank))\n                    
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                    x_uc = (((x - mean) / std) * std_acc) + mean_acc\n                    x_c = x_uc.clone()\n                    if do_classifier_free_guidance and style_fidelity > 0:\n                        x_c[uc_mask] = x[uc_mask]\n                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc\n                self.mean_bank = []\n                self.var_bank = []\n            return x\n\n        def hack_CrossAttnDownBlock2D_forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            temb: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.FloatTensor] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        ):\n            eps = 1e-6\n\n            # TODO(Patrick, William) - attention mask is not used\n            output_states = ()\n\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if 
len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):\n            eps = 1e-6\n\n            output_states = ()\n\n            for i, resnet in enumerate(self.resnets):\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        
self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_CrossAttnUpBlock2D_forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n            temb: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.FloatTensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            upsample_size: Optional[int] = None,\n            
attention_mask: Optional[torch.FloatTensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        ):\n            eps = 1e-6\n            # TODO(Patrick, William) - attention mask is not used\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - 
mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):\n            eps = 1e-6\n            for i, resnet in enumerate(self.resnets):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        if self.reference_attn:\n            if self.fusion_blocks == \"midup\":\n                attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]\n            elif self.fusion_blocks == \"full\":\n                attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]            \n            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n\n            for i, module in enumerate(attn_modules):\n                module._original_inner_forward = module.forward\n                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n                module.bank = []\n                module.attn_weight = float(i) / float(len(attn_modules))\n\n        if self.reference_adain:\n            gn_modules = [self.unet.mid_block]\n       
     self.unet.mid_block.gn_weight = 0\n\n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))\n                gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                module.gn_weight = float(w) / float(len(up_blocks))\n                gn_modules.append(module)\n\n            for i, module in enumerate(gn_modules):\n                if getattr(module, \"original_forward\", None) is None:\n                    module.original_forward = module.forward\n                if i == 0:\n                    # mid_block\n                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)\n                elif isinstance(module, CrossAttnDownBlock2D):\n                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)\n                elif isinstance(module, DownBlock2D):\n                    module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)\n                elif isinstance(module, CrossAttnUpBlock2D):\n                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)\n                elif isinstance(module, UpBlock2D):\n                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)\n                module.mean_bank = []\n                module.var_bank = []\n                module.gn_weight *= 2\n    \n    def update(self, writer, dtype=torch.float16):\n        if self.reference_attn:\n            if self.fusion_blocks == \"midup\":\n                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]\n                writer_attn_modules = [module for module in 
(torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]\n            elif self.fusion_blocks == \"full\":\n                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)]\n                writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]\n            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])    \n            writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n            for r, w in zip(reader_attn_modules, writer_attn_modules):\n                r.bank = [v.clone().to(dtype) for v in w.bank]\n                # w.bank.clear()\n        if self.reference_adain:\n            reader_gn_modules = [self.unet.mid_block]\n            \n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                reader_gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                reader_gn_modules.append(module)\n                \n            writer_gn_modules = [writer.unet.mid_block]\n            \n            down_blocks = writer.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                writer_gn_modules.append(module)\n\n            up_blocks = writer.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                writer_gn_modules.append(module)\n            \n            for r, w in zip(reader_gn_modules, writer_gn_modules):\n                if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):\n                    r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]\n                    r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]\n                else:\n                    
r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]\n                    r.var_bank = [v.clone().to(dtype) for v in w.var_bank]\n    \n    def clear(self):\n        if self.reference_attn:\n            if self.fusion_blocks == \"midup\":\n                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]\n            elif self.fusion_blocks == \"full\":\n                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]\n            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n            for r in reader_attn_modules:\n                r.bank.clear()\n        if self.reference_adain:\n            reader_gn_modules = [self.unet.mid_block]\n            \n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                reader_gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                reader_gn_modules.append(module)\n            \n            for r in reader_gn_modules:\n                r.mean_bank.clear()\n                r.var_bank.clear()\n            "
  },
  {
    "path": "magicanimate/models/orig_attention.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.embeddings import ImagePositionalEmbeddings\nfrom diffusers.utils import BaseOutput\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n@dataclass\nclass Transformer2DModelOutput(BaseOutput):\n    \"\"\"\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):\n            Hidden states conditioned on `encoder_hidden_states` input. 
If discrete, returns probability distributions\n            for the unnoised latent pixels.\n    \"\"\"\n\n    sample: torch.FloatTensor\n\n\nif is_xformers_available():\n    import xformers\n    import xformers.ops\nelse:\n    xformers = None\n\n\nclass Transformer2DModel(ModelMixin, ConfigMixin):\n    \"\"\"\n    Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual\n    embeddings) inputs.\n\n    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard\n    transformer action. Finally, reshape to image.\n\n    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional\n    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict\n    classes of unnoised image.\n\n    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised\n    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            Pass if the input is continuous. The number of channels in the input and output.\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.\n        sample_size (`int`, *optional*): Pass if the input is discrete. 
The width of the latent images.\n            Note that this is fixed at training time as it is used for learning a number of position embeddings. See\n            `ImagePositionalEmbeddings`.\n        num_vector_embeds (`int`, *optional*):\n            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.\n            Includes the class for the masked latent pixel.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.\n            The number of diffusion steps used during training. Note that this is fixed at training time as it is used\n            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for\n            up to, but not more than, `num_embeds_ada_norm` steps.\n        attention_bias (`bool`, *optional*):\n            Configure if the TransformerBlocks' attention should contain a bias parameter.\n    \"\"\"\n\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        sample_size: Optional[int] = None,\n        num_vector_embeds: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n    ):\n        super().__init__()\n        self.use_linear_projection = use_linear_projection\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = 
attention_head_dim\n        inner_dim = num_attention_heads * attention_head_dim\n\n        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, height, width)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`\n        # Define whether input is continuous or discrete depending on configuration\n        self.is_input_continuous = in_channels is not None\n        self.is_input_vectorized = num_vector_embeds is not None\n\n        if self.is_input_continuous and self.is_input_vectorized:\n            raise ValueError(\n                f\"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make\"\n                \" sure that either `in_channels` or `num_vector_embeds` is None.\"\n            )\n        elif not self.is_input_continuous and not self.is_input_vectorized:\n            raise ValueError(\n                f\"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make\"\n                \" sure that either `in_channels` or `num_vector_embeds` is not None.\"\n            )\n\n        # 2. 
Define input layers\n        if self.is_input_continuous:\n            self.in_channels = in_channels\n\n            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n            if use_linear_projection:\n                self.proj_in = nn.Linear(in_channels, inner_dim)\n            else:\n                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)\n        elif self.is_input_vectorized:\n            assert sample_size is not None, \"Transformer2DModel over discrete input must provide sample_size\"\n            assert num_vector_embeds is not None, \"Transformer2DModel over discrete input must provide num_embed\"\n\n            self.height = sample_size\n            self.width = sample_size\n            self.num_vector_embeds = num_vector_embeds\n            self.num_latent_pixels = self.height * self.width\n\n            self.latent_image_embedding = ImagePositionalEmbeddings(\n                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width\n            )\n\n        # 3. Define transformers blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    dropout=dropout,\n                    cross_attention_dim=cross_attention_dim,\n                    activation_fn=activation_fn,\n                    num_embeds_ada_norm=num_embeds_ada_norm,\n                    attention_bias=attention_bias,\n                    only_cross_attention=only_cross_attention,\n                    upcast_attention=upcast_attention,\n                )\n                for d in range(num_layers)\n            ]\n        )\n\n        # 4. 
Define output layers\n        if self.is_input_continuous:\n            if use_linear_projection:\n                # Project from the transformer width back to the input channel count\n                self.proj_out = nn.Linear(inner_dim, in_channels)\n            else:\n                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)\n        elif self.is_input_vectorized:\n            self.norm_out = nn.LayerNorm(inner_dim)\n            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)\n\n    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):\n        \"\"\"\n        Args:\n            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.\n                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input\n                hidden_states\n            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.long`, *optional*):\n                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]\n            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample\n            tensor.\n        \"\"\"\n        # 1. 
Input\n        if self.is_input_continuous:\n            batch, channel, height, weight = hidden_states.shape\n            residual = hidden_states\n\n            hidden_states = self.norm(hidden_states)\n            if not self.use_linear_projection:\n                hidden_states = self.proj_in(hidden_states)\n                inner_dim = hidden_states.shape[1]\n                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)\n            else:\n                inner_dim = hidden_states.shape[1]\n                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)\n                hidden_states = self.proj_in(hidden_states)\n        elif self.is_input_vectorized:\n            hidden_states = self.latent_image_embedding(hidden_states)\n\n        # 2. Blocks\n        for block in self.transformer_blocks:\n            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)\n\n        # 3. 
Output\n        if self.is_input_continuous:\n            if not self.use_linear_projection:\n                hidden_states = (\n                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()\n                )\n                hidden_states = self.proj_out(hidden_states)\n            else:\n                hidden_states = self.proj_out(hidden_states)\n                hidden_states = (\n                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()\n                )\n\n            output = hidden_states + residual\n        elif self.is_input_vectorized:\n            hidden_states = self.norm_out(hidden_states)\n            logits = self.out(hidden_states)\n            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)\n            logits = logits.permute(0, 2, 1)\n\n            # log(p(x_0))\n            output = F.log_softmax(logits.double(), dim=1).float()\n\n        if not return_dict:\n            return (output,)\n\n        return Transformer2DModelOutput(sample=output)\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted\n    to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    Uses three q, k, v linear layers to compute attention.\n\n    Parameters:\n        channels (`int`): The number of channels in the input and output.\n        num_head_channels (`int`, *optional*):\n            The number of channels in each head. 
If None, then `num_heads` = 1.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.\n        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.\n        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.\n    \"\"\"\n\n    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore\n\n    def __init__(\n        self,\n        channels: int,\n        num_head_channels: Optional[int] = None,\n        norm_num_groups: int = 32,\n        rescale_output_factor: float = 1.0,\n        eps: float = 1e-5,\n    ):\n        super().__init__()\n        self.channels = channels\n\n        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1\n        self.num_head_size = num_head_channels\n        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)\n\n        # define q,k,v as linear layers\n        self.query = nn.Linear(channels, channels)\n        self.key = nn.Linear(channels, channels)\n        self.value = nn.Linear(channels, channels)\n\n        self.rescale_output_factor = rescale_output_factor\n        self.proj_attn = nn.Linear(channels, channels, bias=True)\n\n        self._use_memory_efficient_attention_xformers = False\n\n    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):\n        if not is_xformers_available():\n            raise ModuleNotFoundError(\n                \"Refer to https://github.com/facebookresearch/xformers for more information on how to install\"\n                \" xformers\",\n                name=\"xformers\",\n            )\n        elif not torch.cuda.is_available():\n            raise ValueError(\n                \"torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only\"\n                \" available for GPU \"\n            )\n        else:\n            try:\n                # Make sure we can run the memory efficient attention\n                _ = xformers.ops.memory_efficient_attention(\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                )\n            except Exception as e:\n                raise e\n            self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers\n\n    def reshape_heads_to_batch_dim(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.num_heads\n        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)\n        return tensor\n\n    def reshape_batch_dim_to_heads(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.num_heads\n        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)\n        return tensor\n\n    def forward(self, hidden_states):\n        residual = hidden_states\n        batch, channel, height, width = hidden_states.shape\n\n        # norm\n        hidden_states = self.group_norm(hidden_states)\n\n        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)\n\n        # proj to q, k, v\n        query_proj = self.query(hidden_states)\n        key_proj = self.key(hidden_states)\n        value_proj = self.value(hidden_states)\n\n        scale = 1 / math.sqrt(self.channels / self.num_heads)\n\n        query_proj = self.reshape_heads_to_batch_dim(query_proj)\n        key_proj = 
self.reshape_heads_to_batch_dim(key_proj)\n        value_proj = self.reshape_heads_to_batch_dim(value_proj)\n\n        if self._use_memory_efficient_attention_xformers:\n            # Memory efficient attention\n            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)\n            hidden_states = hidden_states.to(query_proj.dtype)\n        else:\n            attention_scores = torch.baddbmm(\n                torch.empty(\n                    query_proj.shape[0],\n                    query_proj.shape[1],\n                    key_proj.shape[1],\n                    dtype=query_proj.dtype,\n                    device=query_proj.device,\n                ),\n                query_proj,\n                key_proj.transpose(-1, -2),\n                beta=0,\n                alpha=scale,\n            )\n            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)\n            hidden_states = torch.bmm(attention_probs, value_proj)\n\n        # reshape hidden_states\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n\n        # compute next hidden_states\n        hidden_states = self.proj_attn(hidden_states)\n\n        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)\n\n        # res connect and rescale\n        hidden_states = (hidden_states + residual) / self.rescale_output_factor\n        return hidden_states\n\n\nclass BasicTransformerBlock(nn.Module):\n    r\"\"\"\n    A basic Transformer block.\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The size of the 
encoder_hidden_states vector for cross attention.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm (:\n            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.\n        attention_bias (:\n            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout=0.0,\n        cross_attention_dim: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        attention_bias: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n    ):\n        super().__init__()\n        self.only_cross_attention = only_cross_attention\n        self.use_ada_layer_norm = num_embeds_ada_norm is not None\n\n        # 1. Self-Attn\n        self.attn1 = CrossAttention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            cross_attention_dim=cross_attention_dim if only_cross_attention else None,\n            upcast_attention=upcast_attention,\n        )  # is a self-attention\n        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)\n\n        # 2. 
Cross-Attn\n        if cross_attention_dim is not None:\n            self.attn2 = CrossAttention(\n                query_dim=dim,\n                cross_attention_dim=cross_attention_dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                upcast_attention=upcast_attention,\n            )  # is self-attn if encoder_hidden_states is none\n        else:\n            self.attn2 = None\n\n        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n\n        if cross_attention_dim is not None:\n            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n        else:\n            self.norm2 = None\n\n        # 3. Feed-forward\n        self.norm3 = nn.LayerNorm(dim)\n\n    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):\n        if not is_xformers_available():\n            print(\"Here is how to install it\")\n            raise ModuleNotFoundError(\n                \"Refer to https://github.com/facebookresearch/xformers for more information on how to install\"\n                \" xformers\",\n                name=\"xformers\",\n            )\n        elif not torch.cuda.is_available():\n            raise ValueError(\n                \"torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only\"\n                \" available for GPU \"\n            )\n        else:\n            try:\n                # Make sure we can run the memory efficient attention\n                _ = xformers.ops.memory_efficient_attention(\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                    torch.randn((1, 2, 40), device=\"cuda\"),\n                )\n            except Exception as e:\n                raise e\n            self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers\n            if self.attn2 is not None:\n                self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers\n\n    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):\n        # 1. Self-Attention\n        norm_hidden_states = (\n            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)\n        )\n\n        if self.only_cross_attention:\n            hidden_states = (\n                self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states\n            )\n        else:\n            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states\n\n        if self.attn2 is not None:\n            # 2. Cross-Attention\n            norm_hidden_states = (\n                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n            )\n            hidden_states = (\n                self.attn2(\n                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask\n                )\n                + hidden_states\n            )\n\n        # 3. 
Feed-forward\n        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states\n\n        return hidden_states\n\n\nclass CrossAttention(nn.Module):\n    r\"\"\"\n    A cross attention layer.\n\n    Parameters:\n        query_dim (`int`): The number of channels in the query.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.\n        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.\n        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        bias (`bool`, *optional*, defaults to False):\n            Set to `True` for the query, key, and value linear layers to contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        query_dim: int,\n        cross_attention_dim: Optional[int] = None,\n        heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias=False,\n        upcast_attention: bool = False,\n        upcast_softmax: bool = False,\n        added_kv_proj_dim: Optional[int] = None,\n        norm_num_groups: Optional[int] = None,\n    ):\n        super().__init__()\n        inner_dim = dim_head * heads\n        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim\n        self.upcast_attention = upcast_attention\n        self.upcast_softmax = upcast_softmax\n\n        self.scale = dim_head**-0.5\n\n        self.heads = heads\n        # for slice_size > 0 the attention score computation\n        # is split across the batch axis to save memory\n        # You can set slice_size with `set_attention_slice`\n        self.sliceable_head_dim = heads\n        self._slice_size = None\n        self._use_memory_efficient_attention_xformers = False\n        self.added_kv_proj_dim = 
added_kv_proj_dim\n\n        if norm_num_groups is not None:\n            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)\n        else:\n            self.group_norm = None\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)\n        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)\n        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)\n\n        if self.added_kv_proj_dim is not None:\n            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)\n            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)\n\n        self.to_out = nn.ModuleList([])\n        self.to_out.append(nn.Linear(inner_dim, query_dim))\n        self.to_out.append(nn.Dropout(dropout))\n\n    def reshape_heads_to_batch_dim(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)\n        return tensor\n\n    def reshape_batch_dim_to_heads(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)\n        return tensor\n\n    def set_attention_slice(self, slice_size):\n        if slice_size is not None and slice_size > self.sliceable_head_dim:\n            raise ValueError(f\"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.\")\n\n        self._slice_size = slice_size\n\n    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        encoder_hidden_states = encoder_hidden_states\n\n     
   if self.group_norm is not None:\n            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = self.to_q(hidden_states)\n        dim = query.shape[-1]\n        query = self.reshape_heads_to_batch_dim(query)\n\n        if self.added_kv_proj_dim is not None:\n            key = self.to_k(hidden_states)\n            value = self.to_v(hidden_states)\n            encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)\n            encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)\n\n            key = self.reshape_heads_to_batch_dim(key)\n            value = self.reshape_heads_to_batch_dim(value)\n            encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)\n            encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)\n\n            key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)\n            value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)\n        else:\n            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states\n            key = self.to_k(encoder_hidden_states)\n            value = self.to_v(encoder_hidden_states)\n\n            key = self.reshape_heads_to_batch_dim(key)\n            value = self.reshape_heads_to_batch_dim(value)\n\n        if attention_mask is not None:\n            if attention_mask.shape[-1] != query.shape[1]:\n                target_length = query.shape[1]\n                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)\n\n        # attention, what we cannot get enough of\n        if self._use_memory_efficient_attention_xformers:\n            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)\n            # Some versions 
of xformers return output in fp32, cast it back to the dtype of the input\n            hidden_states = hidden_states.to(query.dtype)\n        else:\n            if self._slice_size is None or query.shape[0] // self._slice_size == 1:\n                hidden_states = self._attention(query, key, value, attention_mask)\n            else:\n                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)\n\n        # linear proj\n        hidden_states = self.to_out[0](hidden_states)\n\n        # dropout\n        hidden_states = self.to_out[1](hidden_states)\n        return hidden_states\n\n    def _attention(self, query, key, value, attention_mask=None):\n        if self.upcast_attention:\n            query = query.float()\n            key = key.float()\n\n        attention_scores = torch.baddbmm(\n            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),\n            query,\n            key.transpose(-1, -2),\n            beta=0,\n            alpha=self.scale,\n        )\n\n        if attention_mask is not None:\n            attention_scores = attention_scores + attention_mask\n\n        if self.upcast_softmax:\n            attention_scores = attention_scores.float()\n\n        attention_probs = attention_scores.softmax(dim=-1)\n\n        # cast back to the original dtype\n        attention_probs = attention_probs.to(value.dtype)\n\n        # compute attention output\n        hidden_states = torch.bmm(attention_probs, value)\n\n        # reshape hidden_states\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n        return hidden_states\n\n    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):\n        batch_size_attention = query.shape[0]\n        hidden_states = torch.zeros(\n            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype\n        )\n        
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]\n        for i in range(hidden_states.shape[0] // slice_size):\n            start_idx = i * slice_size\n            end_idx = (i + 1) * slice_size\n\n            query_slice = query[start_idx:end_idx]\n            key_slice = key[start_idx:end_idx]\n\n            if self.upcast_attention:\n                query_slice = query_slice.float()\n                key_slice = key_slice.float()\n\n            attn_slice = torch.baddbmm(\n                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),\n                query_slice,\n                key_slice.transpose(-1, -2),\n                beta=0,\n                alpha=self.scale,\n            )\n\n            if attention_mask is not None:\n                attn_slice = attn_slice + attention_mask[start_idx:end_idx]\n\n            if self.upcast_softmax:\n                attn_slice = attn_slice.float()\n\n            attn_slice = attn_slice.softmax(dim=-1)\n\n            # cast back to the original dtype\n            attn_slice = attn_slice.to(value.dtype)\n            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])\n\n            hidden_states[start_idx:end_idx] = attn_slice\n\n        # reshape hidden_states\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n        return hidden_states\n\n    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):\n        # TODO attention_mask\n        query = query.contiguous()\n        key = key.contiguous()\n        value = value.contiguous()\n        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n        return hidden_states\n\n\nclass FeedForward(nn.Module):\n    r\"\"\"\n    A feed-forward layer.\n\n    Parameters:\n        dim (`int`): The number of 
channels in the input.\n        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.\n        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        mult: int = 4,\n        dropout: float = 0.0,\n        activation_fn: str = \"geglu\",\n    ):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = dim_out if dim_out is not None else dim\n\n        if activation_fn == \"gelu\":\n            act_fn = GELU(dim, inner_dim)\n        elif activation_fn == \"geglu\":\n            act_fn = GEGLU(dim, inner_dim)\n        elif activation_fn == \"geglu-approximate\":\n            act_fn = ApproximateGELU(dim, inner_dim)\n\n        self.net = nn.ModuleList([])\n        # project in\n        self.net.append(act_fn)\n        # project dropout\n        self.net.append(nn.Dropout(dropout))\n        # project out\n        self.net.append(nn.Linear(inner_dim, dim_out))\n\n    def forward(self, hidden_states):\n        for module in self.net:\n            hidden_states = module(hidden_states)\n        return hidden_states\n\n\nclass GELU(nn.Module):\n    r\"\"\"\n    GELU activation function\n    \"\"\"\n\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out)\n\n    def gelu(self, gate):\n        if gate.device.type != \"mps\":\n            return F.gelu(gate)\n        # mps: gelu is not implemented for float16\n        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)\n\n    def forward(self, hidden_states):\n        hidden_states = self.proj(hidden_states)\n        
hidden_states = self.gelu(hidden_states)\n        return hidden_states\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    r\"\"\"\n    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.\n\n    Parameters:\n        dim_in (`int`): The number of channels in the input.\n        dim_out (`int`): The number of channels in the output.\n    \"\"\"\n\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def gelu(self, gate):\n        if gate.device.type != \"mps\":\n            return F.gelu(gate)\n        # mps: gelu is not implemented for float16\n        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)\n\n    def forward(self, hidden_states):\n        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)\n        return hidden_states * self.gelu(gate)\n\n\nclass ApproximateGELU(nn.Module):\n    \"\"\"\n    The approximate form of Gaussian Error Linear Unit (GELU)\n\n    For more details, see section 2: https://arxiv.org/abs/1606.08415\n    \"\"\"\n\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out)\n\n    def forward(self, x):\n        x = self.proj(x)\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass AdaLayerNorm(nn.Module):\n    \"\"\"\n    Norm layer modified to incorporate timestep embeddings.\n    \"\"\"\n\n    def __init__(self, embedding_dim, num_embeddings):\n        super().__init__()\n        self.emb = nn.Embedding(num_embeddings, embedding_dim)\n        self.silu = nn.SiLU()\n        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)\n        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)\n\n    def forward(self, x, timestep):\n        emb = self.linear(self.silu(self.emb(timestep)))\n        scale, shift = torch.chunk(emb, 2)\n        x = self.norm(x) * (1 + scale) + shift\n        return x\n\n\nclass 
DualTransformer2DModel(nn.Module):\n    \"\"\"\n    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            Pass if the input is continuous. The number of channels in the input and output.\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.\n        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.\n            Note that this is fixed at training time as it is used for learning a number of position embeddings. See\n            `ImagePositionalEmbeddings`.\n        num_vector_embeds (`int`, *optional*):\n            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.\n            Includes the class for the masked latent pixel.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.\n            The number of diffusion steps used during training. Note that this is fixed at training time as it is used\n            to learn a number of embeddings that are added to the hidden states. 
During inference, you can denoise for\n            up to but not more than steps than `num_embeds_ada_norm`.\n        attention_bias (`bool`, *optional*):\n            Configure if the TransformerBlocks' attention should contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        sample_size: Optional[int] = None,\n        num_vector_embeds: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n    ):\n        super().__init__()\n        self.transformers = nn.ModuleList(\n            [\n                Transformer2DModel(\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    in_channels=in_channels,\n                    num_layers=num_layers,\n                    dropout=dropout,\n                    norm_num_groups=norm_num_groups,\n                    cross_attention_dim=cross_attention_dim,\n                    attention_bias=attention_bias,\n                    sample_size=sample_size,\n                    num_vector_embeds=num_vector_embeds,\n                    activation_fn=activation_fn,\n                    num_embeds_ada_norm=num_embeds_ada_norm,\n                )\n                for _ in range(2)\n            ]\n        )\n\n        # Variables that can be set by a pipeline:\n\n        # The ratio of transformer1 to transformer2's output states to be combined during inference\n        self.mix_ratio = 0.5\n\n        # The shape of `encoder_hidden_states` is expected to be\n        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`\n        
self.condition_lengths = [77, 257]\n\n        # Which transformer to use to encode which condition.\n        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`\n        self.transformer_index_for_condition = [1, 0]\n\n    def forward(\n        self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True\n    ):\n        \"\"\"\n        Args:\n            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.\n                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input\n                hidden_states\n            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.long`, *optional*):\n                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.\n            attention_mask (`torch.FloatTensor`, *optional*):\n                Optional attention mask to be applied in CrossAttention\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]\n            if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample\n            tensor.\n        \"\"\"\n        input_states = hidden_states\n\n        encoded_states = []\n        tokens_start = 0\n        # attention_mask is not used yet\n        for i in range(2):\n            # for each of the two transformers, pass the corresponding condition tokens\n            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]\n            transformer_index = self.transformer_index_for_condition[i]\n            encoded_state = self.transformers[transformer_index](\n                input_states,\n                encoder_hidden_states=condition_state,\n                timestep=timestep,\n                return_dict=False,\n            )[0]\n            encoded_states.append(encoded_state - input_states)\n            tokens_start += self.condition_lengths[i]\n\n        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)\n        output_states = output_states + input_states\n\n        if not return_dict:\n            return (output_states,)\n\n        return Transformer2DModelOutput(sample=output_states)"
  },
  {
    "path": "magicanimate/models/resnet.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/guoyww/AnimateDiff\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange\n\n\nclass InflatedConv3d(nn.Conv2d):\n    def forward(self, x):\n        video_length = x.shape[2]\n\n        x = rearrange(x, \"b c f h w -> (b f) c h w\")\n        x = super().forward(x)\n        x = rearrange(x, \"(b f) c h w -> b c f h w\", f=video_length)\n\n        return x\n\n\nclass Upsample3D(nn.Module):\n    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name=\"conv\"):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_conv_transpose = use_conv_transpose\n        self.name = name\n\n        conv = None\n        if use_conv_transpose:\n            raise 
NotImplementedError\n        elif use_conv:\n            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)\n\n    def forward(self, hidden_states, output_size=None):\n        assert hidden_states.shape[1] == self.channels\n\n        if self.use_conv_transpose:\n            raise NotImplementedError\n\n        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16\n        dtype = hidden_states.dtype\n        if dtype == torch.bfloat16:\n            hidden_states = hidden_states.to(torch.float32)\n\n        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984\n        if hidden_states.shape[0] >= 64:\n            hidden_states = hidden_states.contiguous()\n\n        # if `output_size` is passed we force the interpolation output\n        # size and do not make use of `scale_factor=2`\n        if output_size is None:\n            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode=\"nearest\")\n        else:\n            hidden_states = F.interpolate(hidden_states, size=output_size, mode=\"nearest\")\n\n        # If the input is bfloat16, we cast back to bfloat16\n        if dtype == torch.bfloat16:\n            hidden_states = hidden_states.to(dtype)\n\n        hidden_states = self.conv(hidden_states)\n\n        return hidden_states\n\n\nclass Downsample3D(nn.Module):\n    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=\"conv\"):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.padding = padding\n        stride = 2\n        self.name = name\n\n        if use_conv:\n            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)\n        else:\n            raise NotImplementedError\n\n    def forward(self, hidden_states):\n        assert 
hidden_states.shape[1] == self.channels\n        if self.use_conv and self.padding == 0:\n            raise NotImplementedError\n\n        assert hidden_states.shape[1] == self.channels\n        hidden_states = self.conv(hidden_states)\n\n        return hidden_states\n\n\nclass ResnetBlock3D(nn.Module):\n    def __init__(\n        self,\n        *,\n        in_channels,\n        out_channels=None,\n        conv_shortcut=False,\n        dropout=0.0,\n        temb_channels=512,\n        groups=32,\n        groups_out=None,\n        pre_norm=True,\n        eps=1e-6,\n        non_linearity=\"swish\",\n        time_embedding_norm=\"default\",\n        output_scale_factor=1.0,\n        use_in_shortcut=None,\n    ):\n        super().__init__()\n        self.pre_norm = pre_norm\n        self.pre_norm = True\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n        self.time_embedding_norm = time_embedding_norm\n        self.output_scale_factor = output_scale_factor\n\n        if groups_out is None:\n            groups_out = groups\n\n        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n\n        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        if temb_channels is not None:\n            if self.time_embedding_norm == \"default\":\n                time_emb_proj_out_channels = out_channels\n            elif self.time_embedding_norm == \"scale_shift\":\n                time_emb_proj_out_channels = out_channels * 2\n            else:\n                raise ValueError(f\"unknown time_embedding_norm : {self.time_embedding_norm} \")\n\n            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)\n        else:\n            self.time_emb_proj = None\n\n        self.norm2 = 
torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        if non_linearity == \"swish\":\n            self.nonlinearity = lambda x: F.silu(x)\n        elif non_linearity == \"mish\":\n            self.nonlinearity = Mish()\n        elif non_linearity == \"silu\":\n            self.nonlinearity = nn.SiLU()\n\n        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut\n\n        self.conv_shortcut = None\n        if self.use_in_shortcut:\n            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, input_tensor, temb):\n        hidden_states = input_tensor\n\n        hidden_states = self.norm1(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n\n        hidden_states = self.conv1(hidden_states)\n\n        if temb is not None:\n            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]\n\n        if temb is not None and self.time_embedding_norm == \"default\":\n            hidden_states = hidden_states + temb\n\n        hidden_states = self.norm2(hidden_states)\n\n        if temb is not None and self.time_embedding_norm == \"scale_shift\":\n            scale, shift = torch.chunk(temb, 2, dim=1)\n            hidden_states = hidden_states * (1 + scale) + shift\n\n        hidden_states = self.nonlinearity(hidden_states)\n\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            input_tensor = self.conv_shortcut(input_tensor)\n\n        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor\n\n        return output_tensor\n\n\nclass Mish(torch.nn.Module):\n    def forward(self, 
hidden_states):\n        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))"
  },
  {
    "path": "magicanimate/models/stable_diffusion_controlnet_reference.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\n\nfrom diffusers import StableDiffusionControlNetPipeline\nfrom diffusers.models import ControlNetModel\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import cv2\n        >>> import torch\n        >>> import numpy as np\n        >>> from PIL import Image\n        >>> from diffusers import UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> input_image = load_image(\"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\")\n\n        >>> # get canny image\n        >>> image = cv2.Canny(np.array(input_image), 100, 200)\n        >>> image = image[:, :, None]\n        >>> image = np.concatenate([image, image, image], axis=2)\n        >>> canny_image = Image.fromarray(image)\n\n        >>> controlnet = 
ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(\n                \"runwayml/stable-diffusion-v1-5\",\n                controlnet=controlnet,\n                safety_checker=None,\n                torch_dtype=torch.float16\n                ).to('cuda:0')\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n        >>> result_img = pipe(ref_image=input_image,\n                        prompt=\"1girl\",\n                        image=canny_image,\n                        num_inference_steps=20,\n                        reference_attn=True,\n                        reference_adain=True).images[0]\n\n        >>> result_img.show()\n        ```\n\"\"\"\n\n\ndef torch_dfs(model: torch.nn.Module):\n    result = [model]\n    for child in model.children():\n        result += torch_dfs(child)\n    return result\n\n\nclass StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):\n    def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):\n        refimage = refimage.to(device=device, dtype=dtype)\n\n        # encode the reference image into latent space so it can be used alongside the latents\n        if isinstance(generator, list):\n            ref_image_latents = [\n                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            ref_image_latents = torch.cat(ref_image_latents, dim=0)\n        else:\n            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)\n        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents\n\n        # duplicate ref_image_latents for each generation per prompt, using mps friendly method\n        if ref_image_latents.shape[0] < 
batch_size:\n            if not batch_size % ref_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)\n\n        ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)\n        return ref_image_latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[\n            torch.FloatTensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.FloatTensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        attention_auto_machine_weight: float = 1.0,\n        gn_auto_machine_weight: float = 1.0,\n        style_fidelity: float = 0.5,\n        reference_attn: bool = True,\n        reference_adain: bool = True,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If\n                the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If\n                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are\n                specified in init, images must be passed as a list such that each element of the list can be correctly\n                batched for input to a single controlnet.\n            ref_image (`torch.FloatTensor`, `PIL.Image.Image`):\n                The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. 
If\n                the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can\n                also be accepted as an image.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if\n                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.\n            attention_auto_machine_weight (`float`):\n                Weight of using reference query for self attention's context.\n                If attention_auto_machine_weight=1.0, use reference query for all self attention's context.\n            gn_auto_machine_weight (`float`):\n                Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.\n            style_fidelity (`float`):\n                style fidelity of ref_uncond_xt. 
If style_fidelity=1.0, the control is more important;\n                if style_fidelity=0.0, the prompt is more important; otherwise the two are balanced.\n            reference_attn (`bool`):\n                Whether to use reference query for self attention's context.\n            reference_adain (`bool`):\n                Whether to use reference adain.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        assert reference_attn or reference_adain, \"`reference_attn` or `reference_adain` must be True.\"\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. 
Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Preprocess reference image\n        ref_image = self.prepare_image(\n            image=ref_image,\n            width=width,\n            height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            dtype=prompt_embeds.dtype,\n        )\n\n        # 6. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 7. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 8. Prepare reference latent variables\n        ref_image_latents = self.prepare_ref_latents(\n            ref_image,\n            batch_size * num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 10. Modify self attention and group norm\n        MODE = \"write\"\n        uc_mask = (\n            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)\n            .type_as(ref_image_latents)\n            .bool()\n        )\n\n        def hacked_basic_transformer_inner_forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.FloatTensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            cross_attention_kwargs: Dict[str, Any] = None,\n            class_labels: Optional[torch.LongTensor] = None,\n        ):\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm1(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero:\n                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                    hidden_states, timestep, class_labels, 
hidden_dtype=hidden_states.dtype\n                )\n            else:\n                norm_hidden_states = self.norm1(hidden_states)\n\n            # 1. Self-Attention\n            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n            if self.only_cross_attention:\n                attn_output = self.attn1(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                    attention_mask=attention_mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                if MODE == \"write\":\n                    self.bank.append(norm_hidden_states.detach().clone())\n                    attn_output = self.attn1(\n                        norm_hidden_states,\n                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                        attention_mask=attention_mask,\n                        **cross_attention_kwargs,\n                    )\n                if MODE == \"read\":\n                    if attention_auto_machine_weight > self.attn_weight:\n                        attn_output_uc = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),\n                            # attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n                        attn_output_c = attn_output_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            attn_output_c[uc_mask] = self.attn1(\n                                norm_hidden_states[uc_mask],\n                                encoder_hidden_states=norm_hidden_states[uc_mask],\n                                **cross_attention_kwargs,\n                            
)\n                        attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc\n                        self.bank.clear()\n                    else:\n                        attn_output = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                            attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n            if self.use_ada_layer_norm_zero:\n                attn_output = gate_msa.unsqueeze(1) * attn_output\n            hidden_states = attn_output + hidden_states\n\n            if self.attn2 is not None:\n                norm_hidden_states = (\n                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                )\n\n                # 2. Cross-Attention\n                attn_output = self.attn2(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=encoder_attention_mask,\n                    **cross_attention_kwargs,\n                )\n                hidden_states = attn_output + hidden_states\n\n            # 3. 
Feed-forward\n            norm_hidden_states = self.norm3(hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n            ff_output = self.ff(norm_hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n            hidden_states = ff_output + hidden_states\n\n            return hidden_states\n\n        def hacked_mid_forward(self, *args, **kwargs):\n            eps = 1e-6\n            x = self.original_forward(*args, **kwargs)\n            if MODE == \"write\":\n                if gn_auto_machine_weight >= self.gn_weight:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    self.mean_bank.append(mean)\n                    self.var_bank.append(var)\n            if MODE == \"read\":\n                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))\n                    var_acc = sum(self.var_bank) / float(len(self.var_bank))\n                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                    x_uc = (((x - mean) / std) * std_acc) + mean_acc\n                    x_c = x_uc.clone()\n                    if do_classifier_free_guidance and style_fidelity > 0:\n                        x_c[uc_mask] = x[uc_mask]\n                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc\n                self.mean_bank = []\n                self.var_bank = []\n            return x\n\n        def hack_CrossAttnDownBlock2D_forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            temb: 
Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.FloatTensor] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        ):\n            eps = 1e-6\n\n            # TODO(Patrick, William) - attention mask is not used\n            output_states = ()\n\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n     
                   hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):\n            eps = 1e-6\n\n            output_states = ()\n\n            for i, resnet in enumerate(self.resnets):\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_CrossAttnUpBlock2D_forward(\n            self,\n            hidden_states: torch.FloatTensor,\n            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n            temb: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.FloatTensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            upsample_size: Optional[int] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        ):\n            eps = 1e-6\n            # TODO(Patrick, William) - attention mask is not used\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = 
attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n      
      return hidden_states\n\n        def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):\n            eps = 1e-6\n            for i, resnet in enumerate(self.resnets):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                
self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        if reference_attn:\n            attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]\n            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n\n            for i, module in enumerate(attn_modules):\n                module._original_inner_forward = module.forward\n                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n                module.bank = []\n                module.attn_weight = float(i) / float(len(attn_modules))\n\n        if reference_adain:\n            gn_modules = [self.unet.mid_block]\n            self.unet.mid_block.gn_weight = 0\n\n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))\n                gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                module.gn_weight = float(w) / float(len(up_blocks))\n                gn_modules.append(module)\n\n            for i, module in enumerate(gn_modules):\n                if getattr(module, \"original_forward\", None) is None:\n                    module.original_forward = module.forward\n                if i == 0:\n                    # mid_block\n                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)\n                elif isinstance(module, CrossAttnDownBlock2D):\n                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)\n                elif isinstance(module, DownBlock2D):\n                    module.forward = 
hacked_DownBlock2D_forward.__get__(module, DownBlock2D)\n                elif isinstance(module, CrossAttnUpBlock2D):\n                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)\n                elif isinstance(module, UpBlock2D):\n                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)\n                module.mean_bank = []\n                module.var_bank = []\n                module.gn_weight *= 2\n\n        # 11. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=image,\n                    conditioning_scale=controlnet_conditioning_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                
)\n\n                if guess_mode and do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # ref only part\n                noise = randn_tensor(\n                    ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype\n                )\n                ref_xt = self.scheduler.add_noise(\n                    ref_image_latents,\n                    noise,\n                    t.reshape(\n                        1,\n                    ),\n                )\n                ref_xt = self.scheduler.scale_model_input(ref_xt, t)\n\n                MODE = \"write\"\n                self.unet(\n                    ref_xt,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )\n\n                # predict the noise residual\n                MODE = \"read\"\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, 
nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "magicanimate/models/unet.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/guoyww/AnimateDiff\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport os\nimport json\nimport pdb\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom .unet_3d_blocks import (\n    CrossAttnDownBlock3D,\n    CrossAttnUpBlock3D,\n    DownBlock3D,\n    UNetMidBlock3DCrossAttn,\n    UpBlock3D,\n    get_down_block,\n    get_up_block,\n)\nfrom .resnet import InflatedConv3d\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNet3DConditionOutput(BaseOutput):\n    sample: torch.FloatTensor\n\n\nclass UNet3DConditionModel(ModelMixin, ConfigMixin):\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def 
__init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,      \n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"DownBlock3D\",\n        ),\n        mid_block_type: str = \"UNetMidBlock3DCrossAttn\",\n        up_block_types: Tuple[str] = (\n            \"UpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\"\n        ),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        \n        # Additional\n        use_motion_module              = False,\n        motion_module_resolutions      = ( 1,2,4,8 ),\n        motion_module_mid_block        = False,\n        motion_module_decoder_only     = False,\n        motion_module_type             = None,\n        motion_module_kwargs           = {},\n        unet_use_cross_frame_attention = None,\n        unet_use_temporal_attention    = None,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n        time_embed_dim = block_out_channels[0] * 4\n\n        
# input\n        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))\n\n        # time\n        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n        timestep_input_dim = block_out_channels[0]\n\n        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        self.down_blocks = nn.ModuleList([])\n        self.mid_block = None\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            res = 2 ** i\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                
resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim,\n                attn_num_head_channels=attention_head_dim[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n\n                unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                unet_use_temporal_attention=unet_use_temporal_attention,\n                \n                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),\n                motion_module_type=motion_module_type,\n                motion_module_kwargs=motion_module_kwargs,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock3DCrossAttn\":\n            self.mid_block = UNetMidBlock3DCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim,\n                attn_num_head_channels=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n\n                unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                unet_use_temporal_attention=unet_use_temporal_attention,\n                \n                
use_motion_module=use_motion_module and motion_module_mid_block,\n                motion_module_type=motion_module_type,\n                motion_module_kwargs=motion_module_kwargs,\n            )\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n        \n        # count how many layers upsample the videos\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_attention_head_dim = list(reversed(attention_head_dim))\n        only_cross_attention = list(reversed(only_cross_attention))\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            res = 2 ** (3 - i)\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=layers_per_block + 1,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim,\n                attn_num_head_channels=reversed_attention_head_dim[i],\n                dual_cross_attention=dual_cross_attention,\n                
use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n\n                unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                unet_use_temporal_attention=unet_use_temporal_attention,\n\n                use_motion_module=use_motion_module and (res in motion_module_resolutions),\n                motion_module_type=motion_module_type,\n                motion_module_kwargs=motion_module_kwargs,\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)\n        self.conv_act = nn.SiLU()\n        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module will split the input tensor in slices, to compute attention\n        in several steps. This is useful to save some memory in exchange for a small speed decrease.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, halves the input to the attention heads, so attention will be computed in two steps. If\n                `\"max\"`, maximum amount of memory will be saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_slicable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_slicable_dims(module)\n\n        num_slicable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_slicable_layers * [1]\n\n        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any child that exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet3DConditionOutput, Tuple]:\n        r\"\"\"\n        Args:\n            sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor\n            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps\n            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states\n            return_dict (`bool`, *optional*, defaults to `True`):\n                
Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.\n\n        Returns:\n            [`UNet3DConditionOutput`] or `tuple`:\n            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When\n            returning a tuple, the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, 
device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=self.dtype)\n        emb = self.time_embedding(t_emb)\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        # pre-process\n        sample = self.conv_in(sample)\n\n        # down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)\n\n            down_block_res_samples += res_samples\n\n        # mid\n        sample = self.mid_block(\n            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask\n        )\n\n        # up\n    
    for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,\n                )\n\n        # post-process\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNet3DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):\n        if subfolder is not None:\n            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)\n        print(f\"loaded temporal unet's pretrained weights from {pretrained_model_path} ...\")\n\n        config_file = os.path.join(pretrained_model_path, 'config.json')\n        if not os.path.isfile(config_file):\n            
raise RuntimeError(f\"{config_file} does not exist\")\n        with open(config_file, \"r\") as f:\n            config = json.load(f)\n        config[\"_class_name\"] = cls.__name__\n        config[\"down_block_types\"] = [\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"DownBlock3D\"\n        ]\n        config[\"up_block_types\"] = [\n            \"UpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\"\n        ]\n\n        from diffusers.utils import WEIGHTS_NAME\n        # unet_additional_kwargs defaults to None, which cannot be unpacked with **\n        model = cls.from_config(config, **(unet_additional_kwargs or {}))\n        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)\n        if not os.path.isfile(model_file):\n            raise RuntimeError(f\"{model_file} does not exist\")\n        state_dict = torch.load(model_file, map_location=\"cpu\")\n\n        m, u = model.load_state_dict(state_dict, strict=False)\n        print(f\"### missing keys: {len(m)}; \\n### unexpected keys: {len(u)};\")\n        # print(f\"### missing keys:\\n{m}\\n### unexpected keys:\\n{u}\\n\")\n        \n        params = [p.numel() if \"temporal\" in n else 0 for n, p in model.named_parameters()]\n        print(f\"### Temporal Module Parameters: {sum(params) / 1e6} M\")\n        \n        return model\n"
  },
  {
    "path": "magicanimate/models/unet_3d_blocks.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/guoyww/AnimateDiff\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom torch import nn\n\nfrom .attention import Transformer3DModel\nfrom .resnet import Downsample3D, ResnetBlock3D, Upsample3D\nfrom .motion_module import get_motion_module\n\n\ndef get_down_block(\n    down_block_type,\n    num_layers,\n    in_channels,\n    out_channels,\n    temb_channels,\n    add_downsample,\n    resnet_eps,\n    resnet_act_fn,\n    attn_num_head_channels,\n    resnet_groups=None,\n    cross_attention_dim=None,\n    downsample_padding=None,\n    dual_cross_attention=False,\n    use_linear_projection=False,\n    only_cross_attention=False,\n    upcast_attention=False,\n    resnet_time_scale_shift=\"default\",\n    \n    unet_use_cross_frame_attention=None,\n    unet_use_temporal_attention=None,\n    \n    use_motion_module=None,\n    \n    motion_module_type=None,\n    motion_module_kwargs=None,\n):\n    down_block_type = down_block_type[7:] if down_block_type.startswith(\"UNetRes\") else down_block_type\n    if 
down_block_type == \"DownBlock3D\":\n        return DownBlock3D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n\n            use_motion_module=use_motion_module,\n            motion_module_type=motion_module_type,\n            motion_module_kwargs=motion_module_kwargs,\n        )\n    elif down_block_type == \"CrossAttnDownBlock3D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnDownBlock3D\")\n        return CrossAttnDownBlock3D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            cross_attention_dim=cross_attention_dim,\n            attn_num_head_channels=attn_num_head_channels,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n\n            unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n            unet_use_temporal_attention=unet_use_temporal_attention,\n            \n            use_motion_module=use_motion_module,\n            motion_module_type=motion_module_type,\n            motion_module_kwargs=motion_module_kwargs,\n        
)\n    raise ValueError(f\"{down_block_type} does not exist.\")\n\n\ndef get_up_block(\n    up_block_type,\n    num_layers,\n    in_channels,\n    out_channels,\n    prev_output_channel,\n    temb_channels,\n    add_upsample,\n    resnet_eps,\n    resnet_act_fn,\n    attn_num_head_channels,\n    resnet_groups=None,\n    cross_attention_dim=None,\n    dual_cross_attention=False,\n    use_linear_projection=False,\n    only_cross_attention=False,\n    upcast_attention=False,\n    resnet_time_scale_shift=\"default\",\n\n    unet_use_cross_frame_attention=None,\n    unet_use_temporal_attention=None,\n    \n    use_motion_module=None,\n    motion_module_type=None,\n    motion_module_kwargs=None,\n):\n    up_block_type = up_block_type[7:] if up_block_type.startswith(\"UNetRes\") else up_block_type\n    if up_block_type == \"UpBlock3D\":\n        return UpBlock3D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n\n            use_motion_module=use_motion_module,\n            motion_module_type=motion_module_type,\n            motion_module_kwargs=motion_module_kwargs,\n        )\n    elif up_block_type == \"CrossAttnUpBlock3D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnUpBlock3D\")\n        return CrossAttnUpBlock3D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            
resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attn_num_head_channels=attn_num_head_channels,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n\n            unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n            unet_use_temporal_attention=unet_use_temporal_attention,\n\n            use_motion_module=use_motion_module,\n            motion_module_type=motion_module_type,\n            motion_module_kwargs=motion_module_kwargs,\n        )\n    raise ValueError(f\"{up_block_type} does not exist.\")\n\n\nclass UNetMidBlock3DCrossAttn(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attn_num_head_channels=1,\n        output_scale_factor=1.0,\n        cross_attention_dim=1280,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        upcast_attention=False,\n\n        unet_use_cross_frame_attention=None,\n        unet_use_temporal_attention=None,\n\n        use_motion_module=None,\n        \n        motion_module_type=None,\n        motion_module_kwargs=None,\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n        self.attn_num_head_channels = attn_num_head_channels\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        # there is always at least one resnet\n        resnets = [\n            
ResnetBlock3D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n        motion_modules = []\n\n        for _ in range(num_layers):\n            if dual_cross_attention:\n                raise NotImplementedError\n            attentions.append(\n                Transformer3DModel(\n                    attn_num_head_channels,\n                    in_channels // attn_num_head_channels,\n                    in_channels=in_channels,\n                    num_layers=1,\n                    cross_attention_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    use_linear_projection=use_linear_projection,\n                    upcast_attention=upcast_attention,\n\n                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                    unet_use_temporal_attention=unet_use_temporal_attention,\n                )\n            )\n            motion_modules.append(\n                get_motion_module(\n                    in_channels=in_channels,\n                    motion_module_type=motion_module_type, \n                    motion_module_kwargs=motion_module_kwargs,\n                ) if use_motion_module else None\n            )\n            resnets.append(\n                ResnetBlock3D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                   
 time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n        self.motion_modules = nn.ModuleList(motion_modules)\n\n    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):\n            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample\n            hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states\n            hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass CrossAttnDownBlock3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attn_num_head_channels=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        downsample_padding=1,\n        add_downsample=True,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        only_cross_attention=False,\n        upcast_attention=False,\n\n        unet_use_cross_frame_attention=None,\n        unet_use_temporal_attention=None,\n        \n        use_motion_module=None,\n\n        motion_module_type=None,\n        motion_module_kwargs=None,\n    ):\n        super().__init__()\n     
   resnets = []\n        attentions = []\n        motion_modules = []\n\n        self.has_cross_attention = True\n        self.attn_num_head_channels = attn_num_head_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock3D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if dual_cross_attention:\n                raise NotImplementedError\n            attentions.append(\n                Transformer3DModel(\n                    attn_num_head_channels,\n                    out_channels // attn_num_head_channels,\n                    in_channels=out_channels,\n                    num_layers=1,\n                    cross_attention_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    use_linear_projection=use_linear_projection,\n                    only_cross_attention=only_cross_attention,\n                    upcast_attention=upcast_attention,\n\n                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                    unet_use_temporal_attention=unet_use_temporal_attention,\n                )\n            )\n            motion_modules.append(\n                get_motion_module(\n                    in_channels=out_channels,\n                    motion_module_type=motion_module_type, \n                    motion_module_kwargs=motion_module_kwargs,\n                ) if use_motion_module else None\n      
      )\n            \n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n        self.motion_modules = nn.ModuleList(motion_modules)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample3D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):\n        output_states = ()\n\n        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(attn, return_dict=False),\n                    hidden_states,\n                    encoder_hidden_states,\n                )[0]\n                if motion_module is not None:\n                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)\n                \n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(hidden_states, 
encoder_hidden_states=encoder_hidden_states).sample\n                \n                # add motion module\n                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states\n\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass DownBlock3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_downsample=True,\n        downsample_padding=1,\n        \n        use_motion_module=None,\n        motion_module_type=None,\n        motion_module_kwargs=None,\n    ):\n        super().__init__()\n        resnets = []\n        motion_modules = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock3D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            
motion_modules.append(\n                get_motion_module(\n                    in_channels=out_channels,\n                    motion_module_type=motion_module_type, \n                    motion_module_kwargs=motion_module_kwargs,\n                ) if use_motion_module else None\n            )\n            \n        self.resnets = nn.ModuleList(resnets)\n        self.motion_modules = nn.ModuleList(motion_modules)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample3D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):\n        output_states = ()\n\n        for resnet, motion_module in zip(self.resnets, self.motion_modules):\n            if self.training and self.gradient_checkpointing:\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                if motion_module is not None:\n                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n                # add motion module\n                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states\n\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            
for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass CrossAttnUpBlock3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attn_num_head_channels=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        only_cross_attention=False,\n        upcast_attention=False,\n\n        unet_use_cross_frame_attention=None,\n        unet_use_temporal_attention=None,\n        \n        use_motion_module=None,\n\n        motion_module_type=None,\n        motion_module_kwargs=None,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n        motion_modules = []\n\n        self.has_cross_attention = True\n        self.attn_num_head_channels = attn_num_head_channels\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock3D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    
non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if dual_cross_attention:\n                raise NotImplementedError\n            attentions.append(\n                Transformer3DModel(\n                    attn_num_head_channels,\n                    out_channels // attn_num_head_channels,\n                    in_channels=out_channels,\n                    num_layers=1,\n                    cross_attention_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    use_linear_projection=use_linear_projection,\n                    only_cross_attention=only_cross_attention,\n                    upcast_attention=upcast_attention,\n\n                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                    unet_use_temporal_attention=unet_use_temporal_attention,\n                )\n            )\n            motion_modules.append(\n                get_motion_module(\n                    in_channels=out_channels,\n                    motion_module_type=motion_module_type, \n                    motion_module_kwargs=motion_module_kwargs,\n                ) if use_motion_module else None\n            )\n            \n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n        self.motion_modules = nn.ModuleList(motion_modules)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states,\n        res_hidden_states_tuple,\n        temb=None,\n        encoder_hidden_states=None,\n        upsample_size=None,\n        attention_mask=None,\n    ):\n        for resnet, attn, motion_module in 
zip(self.resnets, self.attentions, self.motion_modules):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(attn, return_dict=False),\n                    hidden_states,\n                    encoder_hidden_states,\n                )[0]\n                if motion_module is not None:\n                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)\n            \n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample\n                \n                # add motion module\n                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size)\n\n        return hidden_states\n\n\nclass UpBlock3D(nn.Module):\n    def __init__(\n        
self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_upsample=True,\n\n        use_motion_module=None,\n        motion_module_type=None,\n        motion_module_kwargs=None,\n    ):\n        super().__init__()\n        resnets = []\n        motion_modules = []\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock3D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            motion_modules.append(\n                get_motion_module(\n                    in_channels=out_channels,\n                    motion_module_type=motion_module_type, \n                    motion_module_kwargs=motion_module_kwargs,\n                ) if use_motion_module else None\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.motion_modules = nn.ModuleList(motion_modules)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, 
out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):\n        for resnet, motion_module in zip(self.resnets, self.motion_modules):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                if motion_module is not None:\n                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size)\n\n        return hidden_states"
  },
  {
    "path": "magicanimate/models/unet_controlnet.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport os\nimport json\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom magicanimate.models.unet_3d_blocks import (\n    CrossAttnDownBlock3D,\n    CrossAttnUpBlock3D,\n    DownBlock3D,\n    UNetMidBlock3DCrossAttn,\n    UpBlock3D,\n    get_down_block,\n    get_up_block,\n)\nfrom .resnet import InflatedConv3d\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNet3DConditionOutput(BaseOutput):\n    sample: torch.FloatTensor\n\n\nclass UNet3DConditionModel(ModelMixin, ConfigMixin):\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: 
Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,      \n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"DownBlock3D\",\n        ),\n        mid_block_type: str = \"UNetMidBlock3DCrossAttn\",\n        up_block_types: Tuple[str] = (\n            \"UpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\"\n        ),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        \n        # Additional\n        use_motion_module              = False,\n        motion_module_resolutions      = ( 1,2,4,8 ),\n        motion_module_mid_block        = False,\n        motion_module_decoder_only     = False,\n        motion_module_type             = None,\n        motion_module_kwargs           = {},\n        unet_use_cross_frame_attention = None,\n        unet_use_temporal_attention    = None,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n        time_embed_dim = block_out_channels[0] * 4\n\n        # input\n        self.conv_in = 
InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))\n\n        # time\n        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n        timestep_input_dim = block_out_channels[0]\n\n        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        self.down_blocks = nn.ModuleList([])\n        self.mid_block = None\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            res = 2 ** i\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                
cross_attention_dim=cross_attention_dim,\n                attn_num_head_channels=attention_head_dim[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n\n                unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                unet_use_temporal_attention=unet_use_temporal_attention,\n                \n                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),\n                motion_module_type=motion_module_type,\n                motion_module_kwargs=motion_module_kwargs,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock3DCrossAttn\":\n            self.mid_block = UNetMidBlock3DCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim,\n                attn_num_head_channels=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n\n                unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                unet_use_temporal_attention=unet_use_temporal_attention,\n                \n                use_motion_module=use_motion_module and motion_module_mid_block,\n                
motion_module_type=motion_module_type,\n                motion_module_kwargs=motion_module_kwargs,\n            )\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n        \n        # count how many layers upsample the videos\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_attention_head_dim = list(reversed(attention_head_dim))\n        only_cross_attention = list(reversed(only_cross_attention))\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            res = 2 ** (3 - i)\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=layers_per_block + 1,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim,\n                attn_num_head_channels=reversed_attention_head_dim[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n       
         upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n\n                unet_use_cross_frame_attention=unet_use_cross_frame_attention,\n                unet_use_temporal_attention=unet_use_temporal_attention,\n\n                use_motion_module=use_motion_module and (res in motion_module_resolutions),\n                motion_module_type=motion_module_type,\n                motion_module_kwargs=motion_module_kwargs,\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)\n        self.conv_act = nn.SiLU()\n        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module will split the input tensor into slices to compute attention\n        in several steps. This is useful to save some memory in exchange for a small speed decrease.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, halves the input to the attention heads, so attention will be computed in two steps. If\n                `\"max\"`, maximum amount of memory will be saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_slicable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_slicable_dims(module)\n\n        num_slicable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_slicable_layers * [1]\n\n        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller than or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any child that exposes the set_attention_slice method\n        # receives its slice size\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        # for controlnet\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet3DConditionOutput, Tuple]:\n        r\"\"\"\n        Args:\n            sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor\n            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps\n            
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.\n\n        Returns:\n            [`UNet3DConditionOutput`] or `tuple`:\n            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When\n            returning a tuple, the first element is the sample tensor.\n        \"\"\"\n        # By default, samples have to be at least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = 
torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=self.dtype)\n        emb = self.time_embedding(t_emb)\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        # pre-process\n        sample = self.conv_in(sample)\n\n        # down\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                )\n            else:\n                sample, res_samples = 
downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # mid\n        sample = self.mid_block(\n            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask\n        )\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                )\n            else:\n                sample = 
upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,\n                )\n\n        # post-process\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNet3DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):\n        if subfolder is not None:\n            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)\n        print(f\"loaded temporal unet's pretrained weights from {pretrained_model_path} ...\")\n\n        config_file = os.path.join(pretrained_model_path, 'config.json')\n        if not os.path.isfile(config_file):\n            raise RuntimeError(f\"{config_file} does not exist\")\n        with open(config_file, \"r\") as f:\n            config = json.load(f)\n        config[\"_class_name\"] = cls.__name__\n        config[\"down_block_types\"] = [\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"CrossAttnDownBlock3D\",\n            \"DownBlock3D\"\n        ]\n        config[\"up_block_types\"] = [\n            \"UpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\",\n            \"CrossAttnUpBlock3D\"\n        ]\n        # config[\"mid_block_type\"] = \"UNetMidBlock3DCrossAttn\"\n\n        from diffusers.utils import WEIGHTS_NAME\n        model = cls.from_config(config, **unet_additional_kwargs)\n        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)\n        if not os.path.isfile(model_file):\n            raise RuntimeError(f\"{model_file} does not exist\")\n        state_dict = torch.load(model_file, map_location=\"cpu\")\n\n        m, u = model.load_state_dict(state_dict, 
strict=False)\n        print(f\"### missing keys: {len(m)}; \\n### unexpected keys: {len(u)};\")\n        # print(f\"### missing keys:\\n{m}\\n### unexpected keys:\\n{u}\\n\")\n        \n        params = [p.numel() if \"temporal\" in n else 0 for n, p in model.named_parameters()]\n        print(f\"### Temporal Module Parameters: {sum(params) / 1e6} M\")\n        \n        return model\n"
  },
  {
    "path": "magicanimate/pipelines/animation.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport argparse\nimport datetime\nimport inspect\nimport os\nimport random\nimport numpy as np\n\nfrom PIL import Image\nfrom omegaconf import OmegaConf\nfrom collections import OrderedDict\n\nimport torch\nimport torch.distributed as dist\n\nfrom diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler\n\nfrom tqdm import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom magicanimate.models.unet_controlnet import UNet3DConditionModel\nfrom magicanimate.models.controlnet import ControlNetModel\nfrom magicanimate.models.appearance_encoder import AppearanceEncoderModel\nfrom magicanimate.models.mutual_self_attention import ReferenceAttentionControl\nfrom magicanimate.pipelines.pipeline_animation import AnimationPipeline\nfrom magicanimate.utils.util import save_videos_grid\nfrom magicanimate.utils.dist_tools import distributed_init\nfrom accelerate.utils import set_seed\n\nfrom magicanimate.utils.videoreader import VideoReader\n\nfrom einops import rearrange\n\nfrom pathlib import Path\n\n\ndef main(args):\n\n    *_, func_args = inspect.getargvalues(inspect.currentframe())\n    func_args = dict(func_args)\n    \n    config  = OmegaConf.load(args.config)\n      \n    # Initialize distributed training\n    device = torch.device(f\"cuda:{args.rank}\")\n    dist_kwargs = {\"rank\":args.rank, \"world_size\":args.world_size, \"dist\":args.dist}\n    \n    if config.savename is None:\n        time_str = 
datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n        savedir = f\"samples/{Path(args.config).stem}-{time_str}\"\n    else:\n        savedir = f\"samples/{config.savename}\"\n        \n    if args.dist:\n        dist.broadcast_object_list([savedir], 0)\n        dist.barrier()\n    \n    if args.rank == 0:\n        os.makedirs(savedir, exist_ok=True)\n\n    inference_config = OmegaConf.load(config.inference_config)\n        \n    motion_module = config.motion_module\n    \n    ### >>> create animation pipeline >>> ###\n    tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder=\"text_encoder\")\n    if config.pretrained_unet_path:\n        unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))\n    else:\n        unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder=\"unet\", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))\n    appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder=\"appearance_encoder\").to(device)\n    reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)\n    reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)\n    if config.pretrained_vae_path is not None:\n        vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)\n    else:\n        vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder=\"vae\")\n\n    ### Load controlnet\n    controlnet   = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)\n\n    
unet.enable_xformers_memory_efficient_attention()\n    appearance_encoder.enable_xformers_memory_efficient_attention()\n    controlnet.enable_xformers_memory_efficient_attention()\n\n    vae.to(torch.float16)\n    unet.to(torch.float16)\n    text_encoder.to(torch.float16)\n    appearance_encoder.to(torch.float16)\n    controlnet.to(torch.float16)\n\n    pipeline = AnimationPipeline(\n        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,\n        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),\n        # NOTE: UniPCMultistepScheduler\n    )\n\n    # 1. unet ckpt\n    # 1.1 motion module\n    motion_module_state_dict = torch.load(motion_module, map_location=\"cpu\")\n    if \"global_step\" in motion_module_state_dict: func_args.update({\"global_step\": motion_module_state_dict[\"global_step\"]})\n    motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict\n    try:\n        # extra steps for self-trained models\n        state_dict = OrderedDict()\n        for key in motion_module_state_dict.keys():\n            if key.startswith(\"module.\"):\n                _key = key.split(\"module.\")[-1]\n                state_dict[_key] = motion_module_state_dict[key]\n            else:\n                state_dict[key] = motion_module_state_dict[key]\n        motion_module_state_dict = state_dict\n        del state_dict\n        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)\n        assert len(unexpected) == 0\n    except Exception:\n        # fall back to loading only the motion-module weights\n        _tmp_ = OrderedDict()\n        for key in motion_module_state_dict.keys():\n            if \"motion_modules\" in key:\n                if key.startswith(\"unet.\"):\n                    _key = key.split('unet.')[-1]\n                    _tmp_[_key] = motion_module_state_dict[key]\n                else:\n                    _tmp_[key] = 
motion_module_state_dict[key]\n        missing, unexpected = unet.load_state_dict(_tmp_, strict=False)\n        assert len(unexpected) == 0\n        del _tmp_\n    del motion_module_state_dict\n\n    pipeline.to(device)\n    ### <<< create validation pipeline <<< ###\n    \n    random_seeds = config.get(\"seed\", [-1])\n    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)\n    random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds\n    \n    # input test videos (either source video/ conditions)\n    \n    test_videos = config.video_path\n    source_images = config.source_image\n    num_actual_inference_steps = config.get(\"num_actual_inference_steps\", config.steps)\n\n    # read size, step from yaml file\n    sizes = [config.size] * len(test_videos)\n    steps = [config.S] * len(test_videos)\n\n    config.random_seed = []\n    prompt = n_prompt = \"\"\n    for idx, (source_image, test_video, random_seed, size, step) in tqdm(\n        enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), \n        total=len(test_videos), \n        disable=(args.rank!=0)\n    ):\n        samples_per_video = []\n        samples_per_clip = []\n        # manually set random seed for reproduction\n        if random_seed != -1: \n            torch.manual_seed(random_seed)\n            set_seed(random_seed)\n        else:\n            torch.seed()\n        config.random_seed.append(torch.initial_seed())\n\n        if test_video.endswith('.mp4'):\n            control = VideoReader(test_video).read()\n            if control[0].shape[0] != size:\n                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]\n            if config.max_length is not None:\n                control = control[config.offset: (config.offset+config.max_length)]\n            control = np.array(control)\n        \n        if source_image.endswith(\".mp4\"):\n            source_image = 
np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))\n        else:\n            source_image = np.array(Image.open(source_image).resize((size, size)))\n        H, W, C = source_image.shape\n        \n        print(f\"current seed: {torch.initial_seed()}\")\n        init_latents = None\n        \n        # print(f\"sampling {prompt} ...\")\n        original_length = control.shape[0]\n        if control.shape[0] % config.L > 0:\n            control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')\n        generator = torch.Generator(device=torch.device(\"cuda:0\"))\n        generator.manual_seed(torch.initial_seed())\n        sample = pipeline(\n            prompt,\n            negative_prompt         = n_prompt,\n            num_inference_steps     = config.steps,\n            guidance_scale          = config.guidance_scale,\n            width                   = W,\n            height                  = H,\n            video_length            = len(control),\n            controlnet_condition    = control,\n            init_latents            = init_latents,\n            generator               = generator,\n            num_actual_inference_steps = num_actual_inference_steps,\n            appearance_encoder       = appearance_encoder, \n            reference_control_writer = reference_control_writer,\n            reference_control_reader = reference_control_reader,\n            source_image             = source_image,\n            **dist_kwargs,\n        ).videos\n\n        if args.rank == 0:\n            source_images = np.array([source_image] * original_length)\n            source_images = rearrange(torch.from_numpy(source_images), \"t h w c -> 1 c t h w\") / 255.0\n            samples_per_video.append(source_images)\n            \n            control = control / 255.0\n            control = rearrange(control, \"t h w c -> 1 c t h w\")\n            control = 
torch.from_numpy(control)\n            samples_per_video.append(control[:, :, :original_length])\n\n            samples_per_video.append(sample[:, :, :original_length])\n                \n            samples_per_video = torch.cat(samples_per_video)\n\n            video_name = os.path.basename(test_video)[:-4]\n            source_name = os.path.basename(config.source_image[idx]).split(\".\")[0]\n            save_videos_grid(samples_per_video[-1:], f\"{savedir}/videos/{source_name}_{video_name}.mp4\")\n            save_videos_grid(samples_per_video, f\"{savedir}/videos/{source_name}_{video_name}/grid.mp4\")\n\n            if config.save_individual_videos:\n                save_videos_grid(samples_per_video[1:2], f\"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4\")\n                save_videos_grid(samples_per_video[0:1], f\"{savedir}/videos/{source_name}_{video_name}/orig.mp4\")\n                \n        if args.dist:\n            dist.barrier()\n               \n    if args.rank == 0:\n        OmegaConf.save(config, f\"{savedir}/config.yaml\")\n\n\ndef distributed_main(device_id, args):\n    args.rank = device_id\n    args.device_id = device_id\n    if torch.cuda.is_available():\n        torch.cuda.set_device(args.device_id)\n        torch.cuda.init()\n    distributed_init(args)\n    main(args)\n\n\ndef run(args):\n\n    if args.dist:\n        args.world_size = max(1, torch.cuda.device_count())\n        assert args.world_size <= torch.cuda.device_count()\n\n        if args.world_size > 0 and torch.cuda.device_count() > 1:\n            port = random.randint(10000, 20000)\n            args.init_method = f\"tcp://localhost:{port}\"\n            torch.multiprocessing.spawn(\n                fn=distributed_main,\n                args=(args,),\n                nprocs=args.world_size,\n            )\n    else:\n        main(args)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", type=str, 
required=True)\n    parser.add_argument(\"--dist\", action=\"store_true\", required=False)\n    parser.add_argument(\"--rank\", type=int, default=0, required=False)\n    parser.add_argument(\"--world_size\", type=int, default=1, required=False)\n\n    args = parser.parse_args()\n    run(args)\n"
  },
  {
    "path": "magicanimate/pipelines/context.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main\nimport numpy as np\nfrom typing import Callable, Optional, List\n\n\ndef ordered_halving(val):\n    bin_str = f\"{val:064b}\"\n    bin_flip = bin_str[::-1]\n    as_int = int(bin_flip, 2)\n\n    return as_int / (1 << 64)\n\n\ndef uniform(\n    step: int = ...,\n    num_steps: Optional[int] = None,\n    num_frames: int = ...,\n    context_size: Optional[int] = None,\n    context_stride: int = 3,\n    context_overlap: int = 4,\n    closed_loop: bool = True,\n):\n    if num_frames <= context_size:\n        yield list(range(num_frames))\n        return\n\n    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)\n\n    for context_step in 1 << np.arange(context_stride):\n        pad = int(round(num_frames * ordered_halving(step)))\n        for j in range(\n            int(ordered_halving(step) * context_step) + pad,\n            num_frames + pad + (0 if closed_loop else -context_overlap),\n            (context_size * context_step - context_overlap),\n        ):\n            yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]\n\n\ndef get_context_scheduler(name: str) -> Callable:\n    if name == \"uniform\":\n        return uniform\n    else:\n        raise ValueError(f\"Unknown context_overlap policy {name}\")\n\n\ndef get_total_steps(\n    scheduler,\n    timesteps: List[int],\n    num_steps: Optional[int] = None,\n    num_frames: int = ...,\n    context_size: Optional[int] = None,\n    context_stride: int = 3,\n    context_overlap: int = 4,\n    closed_loop: bool = 
True,\n):\n    return sum(\n        len(\n            list(\n                scheduler(\n                    i,\n                    num_steps,\n                    num_frames,\n                    context_size,\n                    context_stride,\n                    context_overlap,\n                    closed_loop,\n                )\n            )\n        )\n        for i in range(len(timesteps))\n    )\n"
  },
  {
    "path": "magicanimate/pipelines/pipeline_animation.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py\n\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTODO:\n1. support multi-controlnet\n2. [DONE] support DDIM inversion\n3. 
support Prompt-to-prompt\n\"\"\"\n\nimport inspect, math\nfrom typing import Callable, List, Optional, Union\nfrom dataclasses import dataclass\nfrom PIL import Image\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom tqdm import tqdm\nfrom diffusers.utils import is_accelerate_available\nfrom packaging import version\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL\nfrom diffusers.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n)\nfrom diffusers.utils import deprecate, logging, BaseOutput\n\nfrom einops import rearrange\n\nfrom magicanimate.models.unet_controlnet import UNet3DConditionModel\nfrom magicanimate.models.controlnet import ControlNetModel\nfrom magicanimate.models.mutual_self_attention import ReferenceAttentionControl\nfrom magicanimate.pipelines.context import (\n    get_context_scheduler,\n    get_total_steps\n)\nfrom magicanimate.utils.util import get_tensor_interpolation_method\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass AnimationPipelineOutput(BaseOutput):\n    videos: Union[torch.Tensor, np.ndarray]\n\n\nclass AnimationPipeline(DiffusionPipeline):\n    _optional_components = []\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet3DConditionModel,\n        controlnet: ControlNetModel,\n        scheduler: Union[\n            DDIMScheduler,\n            PNDMScheduler,\n            LMSDiscreteScheduler,\n            EulerDiscreteScheduler,\n            EulerAncestralDiscreteScheduler,\n            DPMSolverMultistepScheduler,\n        ],\n    ):\n        super().__init__()\n\n   
     if hasattr(scheduler.config, \"steps_offset\") and scheduler.config.steps_offset != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might led to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if hasattr(scheduler.config, \"clip_sample\") and scheduler.config.clip_sample is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        is_unet_version_less_0_9_0 = hasattr(unet.config, \"_diffusers_version\") and version.parse(\n            version.parse(unet.config._diffusers_version).base_version\n        ) < version.parse(\"0.9.0.dev0\")\n        is_unet_sample_size_less_64 = hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- runwayml/stable-diffusion-v1-5\"\n                \" \\n- runwayml/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. 
If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n\n    def enable_vae_slicing(self):\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        self.vae.disable_slicing()\n\n    def enable_sequential_cpu_offload(self, gpu_id=0):\n        if is_accelerate_available():\n            from accelerate import cpu_offload\n        else:\n            raise ImportError(\"Please install accelerate via `pip install accelerate`\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:\n            if cpu_offloaded_model is not None:\n                cpu_offload(cpu_offloaded_model, device)\n\n\n    @property\n    def _execution_device(self):\n        if self.device != torch.device(\"meta\") or not hasattr(self.unet, \"_hf_hook\"):\n            return self.device\n        for module in self.unet.modules():\n            if (\n                hasattr(module, \"_hf_hook\")\n                and hasattr(module._hf_hook, \"execution_device\")\n                and module._hf_hook.execution_device is not None\n            ):\n                return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    def _encode_prompt(self, prompt, device, 
num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):\n        batch_size = len(prompt) if isinstance(prompt, list) else 1\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n\n        if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n            attention_mask = text_inputs.attention_mask.to(device)\n        else:\n            attention_mask = None\n\n        text_embeddings = self.text_encoder(\n            text_input_ids.to(device),\n            attention_mask=attention_mask,\n        )\n        text_embeddings = text_embeddings[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n 
           elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            uncond_embeddings = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            uncond_embeddings = uncond_embeddings[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n            # For 
classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        return text_embeddings\n\n    def decode_latents(self, latents, rank, decoder_consistency=None):\n        video_length = latents.shape[2]\n        latents = 1 / 0.18215 * latents\n        latents = rearrange(latents, \"b c f h w -> (b f) c h w\")\n        # video = self.vae.decode(latents).sample\n        video = []\n        for frame_idx in tqdm(range(latents.shape[0]), disable=(rank!=0)):\n            if decoder_consistency is not None:\n                video.append(decoder_consistency(latents[frame_idx:frame_idx+1]))\n            else:\n                video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)\n        video = torch.cat(video)\n        video = rearrange(video, \"(b f) c h w -> b c f h w\", f=video_length)\n        video = (video / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        video = video.cpu().float().numpy()\n        return video\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in 
set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(self, prompt, height, width, callback_steps):\n        if not isinstance(prompt, str) and not isinstance(prompt, list):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, clip_length=16):\n        shape = (batch_size, num_channels_latents, clip_length, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n        if latents is None:\n            rand_device = \"cpu\" if device.type == \"mps\" else device\n\n            if isinstance(generator, list):\n                latents = [\n                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)\n                    for i in range(batch_size)\n                ]\n                latents = torch.cat(latents, dim=0).to(device)\n            else:\n                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)\n                \n            latents = latents.repeat(1, 1, video_length//clip_length, 1, 1)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def prepare_condition(self, condition, num_videos_per_prompt, device, dtype, do_classifier_free_guidance):\n        # prepare conditions for controlnet\n        condition = torch.from_numpy(condition.copy()).to(device=device, dtype=dtype) / 255.0\n        condition = torch.stack([condition for _ in range(num_videos_per_prompt)], dim=0)\n        condition = rearrange(condition, 'b f h w c -> (b f) c h w').clone()\n        if do_classifier_free_guidance:\n            condition = torch.cat([condition] * 2)\n        return condition\n\n    def next_step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta=0.,\n        verbose=False\n    ):\n        \"\"\"\n        Inverse sampling for DDIM Inversion\n        \"\"\"\n        if verbose:\n            print(\"timestep: \", timestep)\n        next_step = timestep\n        
timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod\n        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output\n        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir\n        return x_next, pred_x0\n\n    @torch.no_grad()\n    def images2latents(self, images, dtype):\n        \"\"\"\n        Convert RGB image to VAE latents\n        \"\"\"\n        device = self._execution_device\n        images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1\n        images = rearrange(images, \"f h w c -> f c h w\").to(device)\n        latents = []\n        for frame_idx in range(images.shape[0]):\n            latents.append(self.vae.encode(images[frame_idx:frame_idx+1])['latent_dist'].mean * 0.18215)\n        latents = torch.cat(latents)\n        return latents\n\n    @torch.no_grad()\n    def invert(\n        self,\n        image: torch.Tensor,\n        prompt,\n        num_inference_steps=20,\n        num_actual_inference_steps=10,\n        eta=0.0,\n        return_intermediates=False,\n        **kwargs):\n        \"\"\"\n        Adapted from: https://github.com/Yujun-Shi/DragDiffusion/blob/main/drag_pipeline.py#L440\n        Invert a real image into a noise map with deterministic DDIM inversion\n        \"\"\"\n        device = self._execution_device\n        batch_size = image.shape[0]\n        if isinstance(prompt, list):\n            if batch_size == 1:\n                image = image.expand(len(prompt), -1, -1, -1)\n        elif isinstance(prompt, str):\n            if batch_size > 1:\n                prompt = [prompt] * batch_size\n\n        # text embeddings\n        text_input = 
self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            return_tensors=\"pt\"\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]\n        print(\"input text embeddings :\", text_embeddings.shape)\n        # define initial latents\n        # images2latents requires a dtype; use the text-embedding dtype, as __call__ does via latents_dtype\n        latents = self.images2latents(image, text_embeddings.dtype)\n\n        print(\"latents shape: \", latents.shape)\n        # iterative sampling\n        self.scheduler.set_timesteps(num_inference_steps)\n        print(\"Valid timesteps: \", reversed(self.scheduler.timesteps))\n        latents_list = [latents]\n        pred_x0_list = [latents]\n        for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc=\"DDIM Inversion\")):\n\n            if num_actual_inference_steps is not None and i >= num_actual_inference_steps:\n                continue\n            model_inputs = latents\n\n            # predict the noise\n            # NOTE: the u-net here is UNet3D, therefore the model_inputs need to be of shape (b c f h w)\n            model_inputs = rearrange(model_inputs, \"f c h w -> 1 c f h w\")\n            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample\n            noise_pred = rearrange(noise_pred, \"b c f h w -> (b f) c h w\")\n            \n            # compute the next noisy sample x_t-1 -> x_t\n            latents, pred_x0 = self.next_step(noise_pred, t, latents)\n            latents_list.append(latents)\n            pred_x0_list.append(pred_x0)\n\n        if return_intermediates:\n            # return the intermediate latents during inversion\n            return latents, latents_list\n        return latents\n    \n    def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, device ):\n        if interpolation_factor < 2:\n            return latents\n\n        new_latents = torch.zeros(\n                    (latents.shape[0],latents.shape[1],((latents.shape[2]-1) * 
interpolation_factor)+1, latents.shape[3],latents.shape[4]),\n                    device=latents.device,\n                    dtype=latents.dtype,\n                )\n\n        org_video_length = latents.shape[2]\n        rate = [i/interpolation_factor for i in range(interpolation_factor)][1:]\n\n        new_index = 0\n\n        v0 = None\n        v1 = None\n\n        for i0,i1 in zip( range( org_video_length ),range( org_video_length )[1:] ):\n            v0 = latents[:,:,i0,:,:]\n            v1 = latents[:,:,i1,:,:]\n\n            new_latents[:,:,new_index,:,:] = v0\n            new_index += 1\n\n            for f in rate:\n                v = get_tensor_interpolation_method()(v0.to(device=device),v1.to(device=device),f)\n                new_latents[:,:,new_index,:,:] = v.to(latents.device)\n                new_index += 1\n\n        new_latents[:,:,new_index,:,:] = v1\n        new_index += 1\n\n        return new_latents\n    \n    def select_controlnet_res_samples(self, controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f):\n        _down_block_res_samples = []\n        _mid_block_res_sample = []\n        for i in np.concatenate(np.array(context)):\n            _down_block_res_samples.append(controlnet_res_samples_cache_dict[i][0])\n            _mid_block_res_sample.append(controlnet_res_samples_cache_dict[i][1])\n        down_block_res_samples = [[] for _ in range(len(controlnet_res_samples_cache_dict[i][0]))]\n        for res_t in _down_block_res_samples:\n            for i, res in enumerate(res_t):\n                down_block_res_samples[i].append(res)\n        down_block_res_samples = [torch.cat(res) for res in down_block_res_samples]\n        mid_block_res_sample = torch.cat(_mid_block_res_sample)\n        \n        # reshape controlnet output to match the unet3d inputs\n        b = b // 2 if do_classifier_free_guidance else b\n        _down_block_res_samples = []\n        for sample in down_block_res_samples:\n            sample = 
rearrange(sample, '(b f) c h w -> b c f h w', b=b, f=f)\n            if do_classifier_free_guidance:\n                sample = sample.repeat(2, 1, 1, 1, 1)\n            _down_block_res_samples.append(sample)\n        down_block_res_samples = _down_block_res_samples\n        mid_block_res_sample = rearrange(mid_block_res_sample, '(b f) c h w -> b c f h w', b=b, f=f)\n        if do_classifier_free_guidance:\n            mid_block_res_sample = mid_block_res_sample.repeat(2, 1, 1, 1, 1)\n            \n        return down_block_res_samples, mid_block_res_sample\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        video_length: Optional[int],\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"tensor\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: Optional[int] = 1,\n        controlnet_condition: list = None,\n        controlnet_conditioning_scale: float = 1.0,\n        context_frames: int = 16,\n        context_stride: int = 1,\n        context_overlap: int = 4,\n        context_batch_size: int = 1, \n        context_schedule: str = \"uniform\",\n        init_latents: Optional[torch.FloatTensor] = None,\n        num_actual_inference_steps: Optional[int] = None,\n        appearance_encoder = None, \n        reference_control_writer = None,\n        reference_control_reader = None,\n        source_image: str = None,\n        decoder_consistency = None, \n        **kwargs,\n    ):\n        \"\"\"\n        
New args:\n        - controlnet_condition          : condition map (e.g., depth, canny, keypoints) for controlnet\n        - controlnet_conditioning_scale : conditioning scale for controlnet\n        - init_latents                  : initial latents to begin with (used along with invert())\n        - num_actual_inference_steps    : number of actual inference steps (while total steps is num_inference_steps) \n        \"\"\"\n        controlnet = self.controlnet\n\n        # Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # Check inputs. Raise error if not correct\n        self.check_inputs(prompt, height, width, callback_steps)\n\n        # Define call parameters\n        # batch_size = 1 if isinstance(prompt, str) else len(prompt)\n        batch_size = 1\n        if latents is not None:\n            batch_size = latents.shape[0]\n        if isinstance(prompt, list):\n            batch_size = len(prompt)\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # Encode input prompt\n        prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size\n        if negative_prompt is not None:\n            negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size \n        text_embeddings = self._encode_prompt(\n            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt\n        )\n        text_embeddings = torch.cat([text_embeddings] * context_batch_size)\n        \n        reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', batch_size=context_batch_size)\n        reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', batch_size=context_batch_size)\n        \n        is_dist_initialized = kwargs.get(\"dist\", False)\n        rank = kwargs.get(\"rank\", 0)\n        world_size = kwargs.get(\"world_size\", 1)\n\n        # Prepare video\n        assert num_videos_per_prompt == 1   # FIXME: verify if num_videos_per_prompt > 1 works\n        assert batch_size == 1              # FIXME: verify if batch_size > 1 works\n        control = self.prepare_condition(\n                condition=controlnet_condition,\n                device=device,\n                dtype=controlnet.dtype,\n                num_videos_per_prompt=num_videos_per_prompt,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n            )\n        controlnet_uncond_images, controlnet_cond_images = control.chunk(2)\n\n        # Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # Prepare latent variables\n        if init_latents is not None:\n            latents = 
rearrange(init_latents, \"(b f) c h w -> b c f h w\", f=video_length)\n        else:\n            num_channels_latents = self.unet.in_channels\n            latents = self.prepare_latents(\n                batch_size * num_videos_per_prompt,\n                num_channels_latents,\n                video_length,\n                height,\n                width,\n                text_embeddings.dtype,\n                device,\n                generator,\n                latents,\n            )\n        latents_dtype = latents.dtype\n\n        # Prepare extra step kwargs.\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # Prepare text embeddings for controlnet\n        controlnet_text_embeddings = text_embeddings.repeat_interleave(video_length, 0)\n        _, controlnet_text_embeddings_c = controlnet_text_embeddings.chunk(2)\n        \n        controlnet_res_samples_cache_dict = {i:None for i in range(video_length)}\n\n        # For img2img setting\n        if num_actual_inference_steps is None:\n            num_actual_inference_steps = num_inference_steps\n        \n        if isinstance(source_image, str):\n            ref_image_latents = self.images2latents(np.array(Image.open(source_image).resize((width, height)))[None, :], latents_dtype).cuda()\n        elif isinstance(source_image, np.ndarray):\n            ref_image_latents = self.images2latents(source_image[None, :], latents_dtype).cuda()\n        \n        context_scheduler = get_context_scheduler(context_schedule)\n        \n        # Denoising loop\n        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank!=0)):\n            if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:\n                continue\n\n            noise_pred = torch.zeros(\n                (latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:]),\n                device=latents.device,\n                
dtype=latents.dtype,\n            )\n            counter = torch.zeros(\n                (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype\n            )\n\n            appearance_encoder(\n                ref_image_latents.repeat(context_batch_size * (2 if do_classifier_free_guidance else 1), 1, 1, 1),\n                t,\n                encoder_hidden_states=text_embeddings,\n                return_dict=False,\n            )\n            \n            context_queue = list(context_scheduler(\n                0, num_inference_steps, latents.shape[2], context_frames, context_stride, 0\n            ))\n            num_context_batches = math.ceil(len(context_queue) / context_batch_size)\n            for i in range(num_context_batches):\n                context = context_queue[i*context_batch_size: (i+1)*context_batch_size]\n                # expand the latents if we are doing classifier free guidance\n                controlnet_latent_input = (\n                    torch.cat([latents[:, :, c] for c in context])\n                    .to(device)\n                )\n                controlnet_latent_input = self.scheduler.scale_model_input(controlnet_latent_input, t)\n\n                # prepare inputs for controlnet\n                b, c, f, h, w = controlnet_latent_input.shape\n                controlnet_latent_input = rearrange(controlnet_latent_input, \"b c f h w -> (b f) c h w\")\n                \n                # controlnet inference\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    controlnet_latent_input,\n                    t,\n                    encoder_hidden_states=torch.cat([controlnet_text_embeddings_c[c] for c in context]),\n                    controlnet_cond=torch.cat([controlnet_cond_images[c] for c in context]),\n                    conditioning_scale=controlnet_conditioning_scale,\n                    return_dict=False,\n                )\n\n                for j, k in 
enumerate(np.concatenate(np.array(context))):\n                    controlnet_res_samples_cache_dict[k] = ([sample[j:j+1] for sample in down_block_res_samples], mid_block_res_sample[j:j+1])\n\n            context_queue = list(context_scheduler(\n                0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap\n            ))\n\n            num_context_batches = math.ceil(len(context_queue) / context_batch_size)\n            global_context = []\n            for i in range(num_context_batches):\n                global_context.append(context_queue[i*context_batch_size: (i+1)*context_batch_size])\n            \n            for context in global_context[rank::world_size]:\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = (\n                    torch.cat([latents[:, :, c] for c in context])\n                    .to(device)\n                    .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)\n                )\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                b, c, f, h, w = latent_model_input.shape\n                down_block_res_samples, mid_block_res_sample = self.select_controlnet_res_samples(\n                    controlnet_res_samples_cache_dict, \n                    context,\n                    do_classifier_free_guidance,\n                    b, f\n                )\n                \n                reference_control_reader.update(reference_control_writer)\n                \n                # predict the noise residual\n                pred = self.unet(\n                    latent_model_input, \n                    t, \n                    encoder_hidden_states=text_embeddings[:b],\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    return_dict=False,\n            
    )[0]\n                \n                reference_control_reader.clear()\n                \n                pred_uc, pred_c = pred.chunk(2)\n                pred = torch.cat([pred_uc.unsqueeze(0), pred_c.unsqueeze(0)])\n                for j, c in enumerate(context):\n                    noise_pred[:, :, c] = noise_pred[:, :, c] + pred[:, j]\n                    counter[:, :, c] = counter[:, :, c] + 1\n                    \n            if is_dist_initialized:\n                noise_pred_gathered = [torch.zeros_like(noise_pred) for _ in range(world_size)]\n                if rank == 0:\n                    dist.gather(tensor=noise_pred, gather_list=noise_pred_gathered, dst=0)\n                else:\n                    dist.gather(tensor=noise_pred, gather_list=[], dst=0)\n                dist.barrier()\n\n                if rank == 0:\n                    for k in range(1, world_size):\n                        for context in global_context[k::world_size]:\n                            for j, c in enumerate(context):\n                                noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_gathered[k][:, :, c] \n                                counter[:, :, c] = counter[:, :, c] + 1\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n            \n            if is_dist_initialized:\n                dist.broadcast(latents, 0)\n                dist.barrier()\n            \n            reference_control_writer.clear()\n\n        interpolation_factor = 1\n        latents = self.interpolate_latents(latents, interpolation_factor, device)\n        # Post-processing\n        video 
= self.decode_latents(latents, rank, decoder_consistency=decoder_consistency)\n\n        if is_dist_initialized:\n            dist.barrier()\n\n        # Convert to tensor\n        if output_type == \"tensor\":\n            video = torch.from_numpy(video)\n\n        if not return_dict:\n            return video\n        \n        return AnimationPipelineOutput(videos=video)\n"
  },
  {
    "path": "magicanimate/utils/dist_tools.py",
    "content": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport os\nimport socket\nimport warnings\nimport torch\nfrom torch import distributed as dist\n\n\ndef distributed_init(args):\n\n    if dist.is_initialized():\n        warnings.warn(\"Distributed is already initialized, cannot initialize twice!\")\n        args.rank = dist.get_rank()\n    else:\n        print(\n            f\"Distributed Init (Rank {args.rank}): \"\n            f\"{args.init_method}\"\n        )\n        dist.init_process_group(\n            backend='nccl',\n            init_method=args.init_method,\n            world_size=args.world_size,\n            rank=args.rank,\n        )\n        print(\n            f\"Initialized Host {socket.gethostname()} as Rank \"\n            f\"{args.rank}\"\n        )\n\n        if \"MASTER_ADDR\" not in os.environ or \"MASTER_PORT\" not in os.environ:\n            # Set for onboxdataloader support\n            split = args.init_method.split(\"//\")\n            assert len(split) == 2, (\n                \"host url for distributed should be split by '//' \"\n                + \"into exactly two elements\"\n            )\n\n            split = split[1].split(\":\")\n            assert (\n                len(split) == 2\n            ), \"host url should be of the form <host_url>:<host_port>\"\n            os.environ[\"MASTER_ADDR\"] = split[0]\n            os.environ[\"MASTER_PORT\"] = split[1]\n\n        # perform a dummy all-reduce to initialize the NCCL communicator\n        dist.all_reduce(torch.zeros(1).cuda())\n\n        
suppress_output(is_master())\n        args.rank = dist.get_rank()\n    return args.rank\n\n\ndef get_rank():\n    if not dist.is_available():\n        return 0\n    if not dist.is_nccl_available():\n        return 0\n    if not dist.is_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_master():\n    return get_rank() == 0\n\n\ndef synchronize():\n    if dist.is_initialized():\n        dist.barrier()\n\n\ndef suppress_output(is_master):\n    \"\"\"Suppress printing on the current device. Force printing with `force=True`.\"\"\"\n    import builtins as __builtin__\n\n    builtin_print = __builtin__.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop(\"force\", False)\n        if is_master or force:\n            builtin_print(*args, **kwargs)\n\n    __builtin__.print = print\n\n    import warnings\n\n    builtin_warn = warnings.warn\n\n    def warn(*args, **kwargs):\n        force = kwargs.pop(\"force\", False)\n        if is_master or force:\n            builtin_warn(*args, **kwargs)\n\n    # Log warnings only once\n    warnings.warn = warn\n    warnings.simplefilter(\"once\", UserWarning)"
  },
  {
    "path": "magicanimate/utils/util.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Adapted from https://github.com/guoyww/AnimateDiff\nimport os\nimport imageio\nimport numpy as np\n\nimport torch\nimport torchvision\n\nfrom PIL import Image\nfrom typing import Union\nfrom tqdm import tqdm\nfrom einops import rearrange\n\n\ndef save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25):\n    videos = rearrange(videos, \"b c t h w -> t b c h w\")\n    outputs = []\n    for x in videos:\n        x = torchvision.utils.make_grid(x, nrow=n_rows)\n        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)\n        if rescale:\n            x = (x + 1.0) / 2.0  # -1,1 -> 0,1\n        x = (x * 255).numpy().astype(np.uint8)\n        outputs.append(x)\n\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    imageio.mimsave(path, outputs, fps=fps)\n\ndef save_images_grid(images: torch.Tensor, path: str):\n    assert images.shape[2] == 1 # no time dimension\n    images = images.squeeze(2)\n    grid = torchvision.utils.make_grid(images)\n    grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8)\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    Image.fromarray(grid).save(path)\n\n# DDIM Inversion\n@torch.no_grad()\ndef init_prompt(prompt, pipeline):\n    uncond_input = pipeline.tokenizer(\n        [\"\"], padding=\"max_length\", max_length=pipeline.tokenizer.model_max_length,\n        return_tensors=\"pt\"\n    )\n    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]\n    text_input = pipeline.tokenizer(\n        [prompt],\n        padding=\"max_length\",\n        max_length=pipeline.tokenizer.model_max_length,\n        
truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]\n    context = torch.cat([uncond_embeddings, text_embeddings])\n\n    return context\n\n\ndef next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,\n              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):\n    timestep, next_timestep = min(\n        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep\n    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod\n    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]\n    beta_prod_t = 1 - alpha_prod_t\n    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5\n    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output\n    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction\n    return next_sample\n\n\ndef get_noise_pred_single(latents, t, context, unet):\n    noise_pred = unet(latents, t, encoder_hidden_states=context)[\"sample\"]\n    return noise_pred\n\n\n@torch.no_grad()\ndef ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):\n    context = init_prompt(prompt, pipeline)\n    uncond_embeddings, cond_embeddings = context.chunk(2)\n    all_latent = [latent]\n    latent = latent.clone().detach()\n    for i in tqdm(range(num_inv_steps)):\n        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]\n        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)\n        latent = next_step(noise_pred, t, latent, ddim_scheduler)\n        all_latent.append(latent)\n    return all_latent\n\n\n@torch.no_grad()\ndef ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=\"\"):\n    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, 
num_inv_steps, prompt)\n    return ddim_latents\n\n\ndef video2images(path, step=4, length=16, start=0):\n    reader = imageio.get_reader(path)\n    frames = []\n    for frame in reader:\n        frames.append(np.array(frame))\n    frames = frames[start::step][:length]\n    return frames\n\n\ndef images2video(video, path, fps=8):\n    imageio.mimsave(path, video, fps=fps)\n    return\n\n\ntensor_interpolation = None\n\ndef get_tensor_interpolation_method():\n    return tensor_interpolation\n\ndef set_tensor_interpolation_method(is_slerp):\n    global tensor_interpolation\n    tensor_interpolation = slerp if is_slerp else linear\n\ndef linear(v1, v2, t):\n    return (1.0 - t) * v1 + t * v2\n\ndef slerp(\n    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995\n) -> torch.Tensor:\n    u0 = v0 / v0.norm()\n    u1 = v1 / v1.norm()\n    dot = (u0 * u1).sum()\n    if dot.abs() > DOT_THRESHOLD:\n        #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')\n        return (1.0 - t) * v0 + t * v1\n    omega = dot.acos()\n    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()"
  },
  {
    "path": "magicanimate/utils/videoreader.py",
    "content": "# *************************************************************************\n# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-\n# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-\n# ytedance Inc..  \n# *************************************************************************\n\n# Copyright 2022 ByteDance and/or its affiliates.\n#\n# Copyright (2022) PV3D Authors\n#\n# ByteDance, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from ByteDance or\n# its affiliates is strictly prohibited.\nimport av, gc\nimport torch\nimport warnings\nimport numpy as np\n\n\n_CALLED_TIMES = 0\n_GC_COLLECTION_INTERVAL = 20\n\n\n# remove warnings\nav.logging.set_level(av.logging.ERROR)\n\n\nclass VideoReader():\n    \"\"\"\n    Simple wrapper around PyAV that exposes a few useful functions for\n    dealing with video reading. 
PyAV is a pythonic binding for the ffmpeg libraries.\n    Acknowledgement: Code is borrowed from Bruno Korbar\n    \"\"\"\n    def __init__(self, video, num_frames=float(\"inf\"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):\n        \"\"\"\n        Arguments:\n            video (str): path or bytes of the video to be loaded\n        \"\"\"\n        self.container = av.open(video)\n        self.num_frames = num_frames\n        self.bi_frame = bi_frame\n        \n        self.resampler = None\n        if audio_resample_rate is not None:\n            self.resampler = av.AudioResampler(rate=audio_resample_rate)\n        \n        if self.container.streams.video:\n            # enable multi-threaded video decoding\n            if decode_lossy:\n                warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)\n                self.container.streams.video[0].thread_type = 'AUTO'\n            self.video_stream = self.container.streams.video[0]\n        else:\n            self.video_stream = None\n        \n        self.fps = self._get_video_frame_rate()\n\n    def seek(self, pts, backward=True, any_frame=False):\n        stream = self.video_stream\n        self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream)\n\n    def _occasional_gc(self):\n        # there are a lot of reference cycles in PyAV, so need to manually call\n        # the garbage collector from time to time\n        global _CALLED_TIMES, _GC_COLLECTION_INTERVAL\n        _CALLED_TIMES += 1\n        if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:\n            gc.collect()\n\n    def _read_video(self, offset):\n        self._occasional_gc()\n\n        pts = self.container.duration * offset\n        time_ = pts / float(av.time_base)\n        self.container.seek(int(pts))\n\n        video_frames = []\n        count = 0\n        for _, frame in enumerate(self._iter_frames()):\n            
if frame.pts * frame.time_base >= time_:\n                video_frames.append(frame)\n                if count >= self.num_frames - 1:\n                    break\n                count += 1\n        return video_frames\n\n    def _iter_frames(self):\n        for packet in self.container.demux(self.video_stream):\n            for frame in packet.decode():\n                yield frame\n\n    def _compute_video_stats(self):\n        if self.video_stream is None or self.container is None:\n            # callers unpack (start_pts, time_base, num_of_frames)\n            return 0, 0, 0\n        num_of_frames = self.container.streams.video[0].frames\n        if num_of_frames == 0:\n            # frame count missing from the container: estimate it from fps * duration\n            num_of_frames = int(self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base))\n        self.seek(0, backward=False)\n        count = 0\n        start_pts = 0  # stays 0 if no frame can be decoded\n        time_base = 512\n        for p in self.container.decode(video=0):\n            count = count + 1\n            if count == 1:\n                start_pts = p.pts\n            elif count == 2:\n                time_base = p.pts - start_pts\n                break\n        return start_pts, time_base, num_of_frames\n    \n    def _get_video_frame_rate(self):\n        return float(self.container.streams.video[0].guessed_rate)\n    \n    def sample(self, debug=False):\n        \n        if self.container is None:\n            raise RuntimeError('video stream not found')\n        sample = dict()\n        _, _, total_num_frames = self._compute_video_stats()\n        offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item()\n        video_frames = self._read_video(offset/total_num_frames)\n        video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])\n        sample[\"frames\"] = video_frames\n        sample[\"frame_idx\"] = [offset]\n\n        if self.bi_frame:\n            frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)]\n            frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)]\n            
frames.sort()\n            video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]])\n            Ts = [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)]\n            sample[\"frames\"] = video_frames\n            sample[\"real_t\"] = torch.tensor(Ts, dtype=torch.float32)\n            sample[\"frame_idx\"] = [offset+min(frames), offset+max(frames)]\n\n        return sample\n\n    def read_frames(self, frame_indices):\n        self.num_frames = frame_indices[1] - frame_indices[0]\n        video_frames = self._read_video(frame_indices[0]/self.get_num_frames())\n        video_frames = np.array([\n            np.uint8(video_frames[0].to_rgb().to_ndarray()),\n            np.uint8(video_frames[-1].to_rgb().to_ndarray())\n        ])\n        return video_frames\n\n    def read(self):\n        video_frames = self._read_video(0)\n        video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])\n        return video_frames\n    \n    def get_num_frames(self):\n        _, _, total_num_frames = self._compute_video_stats()\n        return total_num_frames"
  },
  {
    "path": "requirements.txt",
    "content": "absl-py==1.4.0\naccelerate==0.22.0\naiofiles==23.2.1\naiohttp==3.8.5\naiosignal==1.3.1\naltair==5.0.1\nannotated-types==0.5.0\nantlr4-python3-runtime==4.9.3\nanyio==3.7.1\nasync-timeout==4.0.3\nattrs==23.1.0\ncachetools==5.3.1\ncertifi==2023.7.22\ncharset-normalizer==3.2.0\nclick==8.1.7\ncmake==3.27.2\ncontourpy==1.1.0\ncycler==0.11.0\ndatasets==2.14.4\ndill==0.3.7\neinops==0.6.1\nexceptiongroup==1.1.3\nfastapi==0.103.0\nffmpy==0.3.1\nfilelock==3.12.2\nfonttools==4.42.1\nfrozenlist==1.4.0\nfsspec==2023.6.0\ngoogle-auth==2.22.0\ngoogle-auth-oauthlib==1.0.0\ngradio==3.41.2\ngradio-client==0.5.0\ngrpcio==1.57.0\nh11==0.14.0\nhttpcore==0.17.3\nhttpx==0.24.1\nhuggingface-hub==0.16.4\nidna==3.4\nimportlib-metadata==6.8.0\nimportlib-resources==6.0.1\njinja2==3.1.2\njoblib==1.3.2\njsonschema==4.19.0\njsonschema-specifications==2023.7.1\nkiwisolver==1.4.5\nlightning-utilities==0.9.0\nlit==16.0.6\nmarkdown==3.4.4\nmarkupsafe==2.1.3\nmatplotlib==3.7.2\nmpmath==1.3.0\nmultidict==6.0.4\nmultiprocess==0.70.15\nnetworkx==3.1\nnumpy==1.24.4\nnvidia-cublas-cu11==11.10.3.66\nnvidia-cuda-cupti-cu11==11.7.101\nnvidia-cuda-nvrtc-cu11==11.7.99\nnvidia-cuda-runtime-cu11==11.7.99\nnvidia-cudnn-cu11==8.5.0.96\nnvidia-cufft-cu11==10.9.0.58\nnvidia-curand-cu11==10.2.10.91\nnvidia-cusolver-cu11==11.4.0.1\nnvidia-cusparse-cu11==11.7.4.91\nnvidia-nccl-cu11==2.14.3\nnvidia-nvtx-cu11==11.7.91\noauthlib==3.2.2\nomegaconf==2.3.0\nopencv-python==4.8.0.76\norjson==3.9.5\npandas==2.0.3\npillow==9.5.0\npkgutil-resolve-name==1.3.10\nprotobuf==4.24.2\npsutil==5.9.5\npyarrow==13.0.0\npyasn1==0.5.0\npyasn1-modules==0.3.0\npydantic==2.3.0\npydantic-core==2.6.3\npydub==0.25.1\npyparsing==3.0.9\npython-multipart==0.0.6\npytorch-lightning==2.0.7\npytz==2023.3\npyyaml==6.0.1\nreferencing==0.30.2\nregex==2023.8.8\nrequests==2.31.0\nrequests-oauthlib==1.3.1\nrpds-py==0.9.2\nrsa==4.9\nsafetensors==0.3.3\nsemantic-version==2.10.0\nsniffio==1.3.0\nstarlette==0.27.0\nsympy==1.12\ntensorboard==2.14.0\nt
ensorboard-data-server==0.7.1\ntokenizers==0.13.3\ntoolz==0.12.0\ntorchmetrics==1.1.0\ntqdm==4.66.1\ntransformers==4.32.0\ntriton==2.0.0\ntzdata==2023.3\nurllib3==1.26.16\nuvicorn==0.23.2\nwebsockets==11.0.3\nwerkzeug==2.3.7\nxxhash==3.3.0\nyarl==1.9.2\nzipp==3.16.2\ndecord\nimageio==2.9.0\nimageio-ffmpeg==0.4.3\ntimm\nscipy\nscikit-image\nav\nimgaug\nlpips\nffmpeg-python\ntorch==2.0.1\ntorchvision==0.15.2\nxformers==0.0.22\ndiffusers==0.21.4\n"
  },
  {
    "path": "scripts/animate.sh",
    "content": "python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml\n"
  },
  {
    "path": "scripts/animate_dist.sh",
    "content": "python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml --dist\n"
  }
]