Repository: magic-research/magic-animate Branch: main Commit: d2bc3bc3c9cc Files: 31 Total size: 417.0 KB Directory structure: gitextract_ig9mhcve/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── inference/ │ │ └── inference.yaml │ └── prompts/ │ └── animation.yaml ├── demo/ │ ├── animate.py │ ├── animate_dist.py │ ├── gradio_animate.py │ └── gradio_animate_dist.py ├── environment.yaml ├── magicanimate/ │ ├── models/ │ │ ├── appearance_encoder.py │ │ ├── attention.py │ │ ├── controlnet.py │ │ ├── embeddings.py │ │ ├── motion_module.py │ │ ├── mutual_self_attention.py │ │ ├── orig_attention.py │ │ ├── resnet.py │ │ ├── stable_diffusion_controlnet_reference.py │ │ ├── unet.py │ │ ├── unet_3d_blocks.py │ │ └── unet_controlnet.py │ ├── pipelines/ │ │ ├── animation.py │ │ ├── context.py │ │ └── pipeline_animation.py │ └── utils/ │ ├── dist_tools.py │ ├── util.py │ └── videoreader.py ├── requirements.txt └── scripts/ ├── animate.sh └── animate_dist.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__ .vscode samples xformers src third_party backup pretrained_models *.nfs* ./*.png ./*.mp4 demo/tmp demo/outputs ================================================ FILE: LICENSE ================================================ BSD 3-Clause License Copyright 2023 MagicAnimate Team All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================

MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model

Zhongcong Xu · Jianfeng Zhang · Jun Hao Liew · Hanshu Yan · Jia-Wei Liu · Chenxu Zhang · Jiashi Feng · Mike Zheng Shou

Paper PDF Project Page
National University of Singapore   |   ByteDance

## 📢 News * **[2023.12.4]** Release inference code and gradio demo. We are working to improve MagicAnimate, stay tuned! * **[2023.11.23]** Release MagicAnimate paper and project page. ## 🏃‍♂️ Getting Started Download the pretrained base models for [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [MSE-finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse). Download our MagicAnimate [checkpoints](https://huggingface.co/zcxu-eric/MagicAnimate). Please follow the huggingface download instructions to download the above models and checkpoints, `git lfs` is recommended. Place the based models and checkpoints as follows: ```bash magic-animate |----pretrained_models |----MagicAnimate |----appearance_encoder |----diffusion_pytorch_model.safetensors |----config.json |----densepose_controlnet |----diffusion_pytorch_model.safetensors |----config.json |----temporal_attention |----temporal_attention.ckpt |----sd-vae-ft-mse |----config.json |----diffusion_pytorch_model.safetensors |----stable-diffusion-v1-5 |----scheduler |----scheduler_config.json |----text_encoder |----config.json |----pytorch_model.bin |----tokenizer (all) |----unet |----diffusion_pytorch_model.bin |----config.json |----v1-5-pruned-emaonly.safetensors |----... ``` ## ⚒️ Installation prerequisites: `python>=3.8`, `CUDA>=11.3`, and `ffmpeg`. Install with `conda`: ```bash conda env create -f environment.yaml conda activate manimate ``` or `pip`: ```bash pip3 install -r requirements.txt ``` ## 💃 Inference Run inference on single GPU: ```bash bash scripts/animate.sh ``` Run inference with multiple GPUs: ```bash bash scripts/animate_dist.sh ``` ## 🎨 Gradio Demo #### Online Gradio Demo: Try our [online gradio demo](https://huggingface.co/spaces/zcxu-eric/magicanimate) quickly. #### Local Gradio Demo: Launch local gradio demo on single GPU: ```bash python3 -m demo.gradio_animate ``` Launch local gradio demo if you have multiple GPUs: ```bash python3 -m demo.gradio_animate_dist ``` Then open gradio demo in local browser. ## 🙏 Acknowledgements We would like to thank [AK(@_akhaliq)](https://twitter.com/_akhaliq?lang=en) and huggingface team for the help of setting up oneline gradio demo. ## 🎓 Citation If you find this codebase useful for your research, please use the following entry. ```BibTeX @inproceedings{xu2023magicanimate, author = {Xu, Zhongcong and Zhang, Jianfeng and Liew, Jun Hao and Yan, Hanshu and Liu, Jia-Wei and Zhang, Chenxu and Feng, Jiashi and Shou, Mike Zheng}, title = {MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model}, booktitle = {arXiv}, year = {2023} } ``` ================================================ FILE: configs/inference/inference.yaml ================================================ unet_additional_kwargs: unet_use_cross_frame_attention: false unet_use_temporal_attention: false use_motion_module: true motion_module_resolutions: - 1 - 2 - 4 - 8 motion_module_mid_block: false motion_module_decoder_only: false motion_module_type: Vanilla motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: - Temporal_Self - Temporal_Self temporal_position_encoding: true temporal_position_encoding_max_len: 24 temporal_attention_dim_div: 1 noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" ================================================ FILE: configs/prompts/animation.yaml ================================================ pretrained_model_path: "pretrained_models/stable-diffusion-v1-5" pretrained_vae_path: "pretrained_models/sd-vae-ft-mse" pretrained_controlnet_path: "pretrained_models/MagicAnimate/densepose_controlnet" pretrained_appearance_encoder_path: "pretrained_models/MagicAnimate/appearance_encoder" pretrained_unet_path: "" motion_module: "pretrained_models/MagicAnimate/temporal_attention/temporal_attention.ckpt" savename: null fusion_blocks: "midup" seed: [1] steps: 25 guidance_scale: 7.5 source_image: - "inputs/applications/source_image/monalisa.png" - "inputs/applications/source_image/demo4.png" - "inputs/applications/source_image/dalle2.jpeg" - "inputs/applications/source_image/dalle8.jpeg" - "inputs/applications/source_image/multi1_source.png" video_path: - "inputs/applications/driving/densepose/running.mp4" - "inputs/applications/driving/densepose/demo4.mp4" - "inputs/applications/driving/densepose/running2.mp4" - "inputs/applications/driving/densepose/dancing2.mp4" - "inputs/applications/driving/densepose/multi_dancing.mp4" inference_config: "configs/inference/inference.yaml" size: 512 L: 16 S: 1 I: 0 clip: 0 offset: 0 max_length: null video_type: "condition" invert_video: false save_individual_videos: false ================================================ FILE: demo/animate.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import argparse import argparse import datetime import inspect import os import numpy as np from PIL import Image from omegaconf import OmegaConf from collections import OrderedDict import torch from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler from tqdm import tqdm from transformers import CLIPTextModel, CLIPTokenizer from magicanimate.models.unet_controlnet import UNet3DConditionModel from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.appearance_encoder import AppearanceEncoderModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl from magicanimate.pipelines.pipeline_animation import AnimationPipeline from magicanimate.utils.util import save_videos_grid from accelerate.utils import set_seed from magicanimate.utils.videoreader import VideoReader from einops import rearrange, repeat import csv, pdb, glob from safetensors import safe_open import math from pathlib import Path class MagicAnimate(): def __init__(self, config="configs/prompts/animation.yaml") -> None: print("Initializing MagicAnimate Pipeline...") *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) config = OmegaConf.load(config) inference_config = OmegaConf.load(config.inference_config) motion_module = config.motion_module ### >>> create animation pipeline >>> ### tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") if config.pretrained_unet_path: unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) else: unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda() self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) if config.pretrained_vae_path is not None: vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) else: vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") ### Load controlnet controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) vae.to(torch.float16) unet.to(torch.float16) text_encoder.to(torch.float16) controlnet.to(torch.float16) self.appearance_encoder.to(torch.float16) unet.enable_xformers_memory_efficient_attention() self.appearance_encoder.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() self.pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), # NOTE: UniPCMultistepScheduler ).to("cuda") # 1. unet ckpt # 1.1 motion module motion_module_state_dict = torch.load(motion_module, map_location="cpu") if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict try: # extra steps for self-trained models state_dict = OrderedDict() for key in motion_module_state_dict.keys(): if key.startswith("module."): _key = key.split("module.")[-1] state_dict[_key] = motion_module_state_dict[key] else: state_dict[key] = motion_module_state_dict[key] motion_module_state_dict = state_dict del state_dict missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) assert len(unexpected) == 0 except: _tmp_ = OrderedDict() for key in motion_module_state_dict.keys(): if "motion_modules" in key: if key.startswith("unet."): _key = key.split('unet.')[-1] _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] missing, unexpected = unet.load_state_dict(_tmp_, strict=False) assert len(unexpected) == 0 del _tmp_ del motion_module_state_dict self.pipeline.to("cuda") self.L = config.L print("Initialization Done!") def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512): prompt = n_prompt = "" random_seed = int(random_seed) step = int(step) guidance_scale = float(guidance_scale) samples_per_video = [] # manually set random seed for reproduction if random_seed != -1: torch.manual_seed(random_seed) set_seed(random_seed) else: torch.seed() if motion_sequence.endswith('.mp4'): control = VideoReader(motion_sequence).read() if control[0].shape[0] != size: control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] control = np.array(control) if source_image.shape[0] != size: source_image = np.array(Image.fromarray(source_image).resize((size, size))) H, W, C = source_image.shape init_latents = None original_length = control.shape[0] if control.shape[0] % self.L > 0: control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') generator = torch.Generator(device=torch.device("cuda:0")) generator.manual_seed(torch.initial_seed()) sample = self.pipeline( prompt, negative_prompt = n_prompt, num_inference_steps = step, guidance_scale = guidance_scale, width = W, height = H, video_length = len(control), controlnet_condition = control, init_latents = init_latents, generator = generator, appearance_encoder = self.appearance_encoder, reference_control_writer = self.reference_control_writer, reference_control_reader = self.reference_control_reader, source_image = source_image, ).videos source_images = np.array([source_image] * original_length) source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 samples_per_video.append(source_images) control = control / 255.0 control = rearrange(control, "t h w c -> 1 c t h w") control = torch.from_numpy(control) samples_per_video.append(control[:, :, :original_length]) samples_per_video.append(sample[:, :, :original_length]) samples_per_video = torch.cat(samples_per_video) time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") savedir = f"demo/outputs" animation_path = f"{savedir}/{time_str}.mp4" os.makedirs(savedir, exist_ok=True) save_videos_grid(samples_per_video, animation_path) return animation_path ================================================ FILE: demo/animate_dist.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import argparse import argparse import datetime import inspect import os import numpy as np from PIL import Image from omegaconf import OmegaConf from collections import OrderedDict import torch import random from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler from transformers import CLIPTextModel, CLIPTokenizer from magicanimate.models.unet_controlnet import UNet3DConditionModel from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.appearance_encoder import AppearanceEncoderModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl from magicanimate.pipelines.pipeline_animation import AnimationPipeline from magicanimate.utils.util import save_videos_grid from magicanimate.utils.dist_tools import distributed_init from accelerate.utils import set_seed from magicanimate.utils.videoreader import VideoReader from einops import rearrange animator = None class MagicAnimate(): def __init__(self, args) -> None: config=args.config device = torch.device(f"cuda:{args.rank}") print("Initializing MagicAnimate Pipeline...") *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) config = OmegaConf.load(config) inference_config = OmegaConf.load(config.inference_config) motion_module = config.motion_module ### >>> create animation pipeline >>> ### tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") if config.pretrained_unet_path: unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) else: unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device) self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) if config.pretrained_vae_path is not None: vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) else: vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") ### Load controlnet controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) vae.to(torch.float16) unet.to(torch.float16) text_encoder.to(torch.float16) controlnet.to(torch.float16) self.appearance_encoder.to(torch.float16) unet.enable_xformers_memory_efficient_attention() self.appearance_encoder.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() self.pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), # NOTE: UniPCMultistepScheduler ) # 1. unet ckpt # 1.1 motion module motion_module_state_dict = torch.load(motion_module, map_location="cpu") if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict try: # extra steps for self-trained models state_dict = OrderedDict() for key in motion_module_state_dict.keys(): if key.startswith("module."): _key = key.split("module.")[-1] state_dict[_key] = motion_module_state_dict[key] else: state_dict[key] = motion_module_state_dict[key] motion_module_state_dict = state_dict del state_dict missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) assert len(unexpected) == 0 except: _tmp_ = OrderedDict() for key in motion_module_state_dict.keys(): if "motion_modules" in key: if key.startswith("unet."): _key = key.split('unet.')[-1] _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] missing, unexpected = unet.load_state_dict(_tmp_, strict=False) assert len(unexpected) == 0 del _tmp_ del motion_module_state_dict self.pipeline.to(device) self.L = config.L print("Initialization Done!") dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist} self.predict(args.reference_image, args.motion_sequence, args.random_seed, args.step, args.guidance_scale, args.save_path, dist_kwargs) def predict(self, source_image, motion_sequence, random_seed, step, guidance_scale, save_path, dist_kwargs, size=512): prompt = n_prompt = "" samples_per_video = [] # manually set random seed for reproduction if random_seed != -1: torch.manual_seed(random_seed) set_seed(random_seed) else: torch.seed() if motion_sequence.endswith('.mp4'): control = VideoReader(motion_sequence).read() if control[0].shape[0] != size: control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] control = np.array(control) if not isinstance(source_image, np.ndarray): source_image = np.array(Image.open(source_image)) if source_image.shape[0] != size: source_image = np.array(Image.fromarray(source_image).resize((size, size))) H, W, C = source_image.shape init_latents = None original_length = control.shape[0] if control.shape[0] % self.L > 0: control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') generator = torch.Generator(device=torch.device("cuda:0")) generator.manual_seed(torch.initial_seed()) sample = self.pipeline( prompt, negative_prompt = n_prompt, num_inference_steps = step, guidance_scale = guidance_scale, width = W, height = H, video_length = len(control), controlnet_condition = control, init_latents = init_latents, generator = generator, appearance_encoder = self.appearance_encoder, reference_control_writer = self.reference_control_writer, reference_control_reader = self.reference_control_reader, source_image = source_image, **dist_kwargs, ).videos if dist_kwargs.get('rank', 0) == 0: source_images = np.array([source_image] * original_length) source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 samples_per_video.append(source_images) control = control / 255.0 control = rearrange(control, "t h w c -> 1 c t h w") control = torch.from_numpy(control) samples_per_video.append(control[:, :, :original_length]) samples_per_video.append(sample[:, :, :original_length]) samples_per_video = torch.cat(samples_per_video) save_videos_grid(samples_per_video, save_path) def distributed_main(device_id, args): args.rank = device_id args.device_id = device_id if torch.cuda.is_available(): torch.cuda.set_device(args.device_id) torch.cuda.init() distributed_init(args) MagicAnimate(args) def run(args): if args.dist: args.world_size = max(1, torch.cuda.device_count()) assert args.world_size <= torch.cuda.device_count() if args.world_size > 0 and torch.cuda.device_count() > 1: port = random.randint(10000, 20000) args.init_method = f"tcp://localhost:{port}" torch.multiprocessing.spawn( fn=distributed_main, args=(args,), nprocs=args.world_size, ) else: MagicAnimate(args) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="configs/prompts/animation.yaml", required=False) parser.add_argument("--dist", type=bool, default=True, required=False) parser.add_argument("--rank", type=int, default=0, required=False) parser.add_argument("--world_size", type=int, default=1, required=False) parser.add_argument("--reference_image", type=str, default=None, required=True) parser.add_argument("--motion_sequence", type=str, default=None, required=True) parser.add_argument("--random_seed", type=int, default=1, required=False) parser.add_argument("--step", type=int, default=25, required=False) parser.add_argument("--guidance_scale", type=float, default=7.5, required=False) parser.add_argument("--save_path", type=str, default=None, required=True) args = parser.parse_args() run(args) ================================================ FILE: demo/gradio_animate.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import argparse import imageio import numpy as np import gradio as gr from PIL import Image from demo.animate import MagicAnimate animator = MagicAnimate() def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale): return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale) with gr.Blocks() as demo: gr.HTML( """

MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model

If you like our project, please give us a star ✨ on Github for the latest update.
Project Page
""") animation = gr.Video(format="mp4", label="Animation Results", autoplay=True) with gr.Row(): reference_image = gr.Image(label="Reference Image") motion_sequence = gr.Video(format="mp4", label="Motion Sequence") with gr.Column(): random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1") sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25") guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5") submit = gr.Button("Animate") def read_video(video): reader = imageio.get_reader(video) fps = reader.get_meta_data()['fps'] return video def read_image(image, size=512): return np.array(Image.fromarray(image).resize((size, size))) # when user uploads a new video motion_sequence.upload( read_video, motion_sequence, motion_sequence ) # when `first_frame` is updated reference_image.upload( read_image, reference_image, reference_image ) # when the `submit` button is clicked submit.click( animate, [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], animation ) # Examples gr.Markdown("## Examples") gr.Examples( examples=[ ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"], ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"], ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"], ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"], ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"], ], inputs=[reference_image, motion_sequence], outputs=animation, ) demo.launch() ================================================ FILE: demo/gradio_animate_dist.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import argparse import imageio import os, datetime import numpy as np import gradio as gr from PIL import Image from subprocess import PIPE, run os.makedirs("./demo/tmp", exist_ok=True) savedir = f"demo/outputs" os.makedirs(savedir, exist_ok=True) def animate(reference_image, motion_sequence, seed, steps, guidance_scale): time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") animation_path = f"{savedir}/{time_str}.mp4" save_path = "./demo/tmp/input_reference_image.png" Image.fromarray(reference_image).save(save_path) command = "python -m demo.animate_dist --reference_image {} --motion_sequence {} --random_seed {} --step {} --guidance_scale {} --save_path {}".format( save_path, motion_sequence, seed, steps, guidance_scale, animation_path ) run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True) return animation_path with gr.Blocks() as demo: gr.HTML( """

MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model

If you like our project, please give us a star ✨ on Github for the latest update.
Project Page
""") animation = gr.Video(format="mp4", label="Animation Results", autoplay=True) with gr.Row(): reference_image = gr.Image(label="Reference Image") motion_sequence = gr.Video(format="mp4", label="Motion Sequence") with gr.Column(): random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1") sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25") guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5") submit = gr.Button("Animate") def read_video(video, size=512): size = int(size) reader = imageio.get_reader(video) # fps = reader.get_meta_data()['fps'] frames = [] for img in reader: frames.append(np.array(Image.fromarray(img).resize((size, size)))) save_path = "./demo/tmp/input_motion_sequence.mp4" imageio.mimwrite(save_path, frames, fps=25) return save_path def read_image(image, size=512): img = np.array(Image.fromarray(image).resize((size, size))) return img # when user uploads a new video motion_sequence.upload( read_video, motion_sequence, motion_sequence ) # when `first_frame` is updated reference_image.upload( read_image, reference_image, reference_image ) # when the `submit` button is clicked submit.click( animate, [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], animation ) # Examples gr.Markdown("## Examples") gr.Examples( examples=[ ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"], ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"], ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"], ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"], ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"], ], inputs=[reference_image, motion_sequence], outputs=animation, ) # demo.queue(max_size=10) demo.launch() ================================================ FILE: environment.yaml ================================================ name: manimate channels: - conda-forge - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - asttokens=2.2.1=pyhd8ed1ab_0 - backcall=0.2.0=pyh9f0ad1d_0 - backports=1.0=pyhd8ed1ab_3 - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 - ca-certificates=2023.7.22=hbcca054_0 - comm=0.1.4=pyhd8ed1ab_0 - debugpy=1.6.7=py38h6a678d5_0 - decorator=5.1.1=pyhd8ed1ab_0 - entrypoints=0.4=pyhd8ed1ab_0 - executing=1.2.0=pyhd8ed1ab_0 - ipykernel=6.25.1=pyh71e2992_0 - ipython=8.12.0=pyh41d4057_0 - jedi=0.19.0=pyhd8ed1ab_0 - jupyter_client=7.3.4=pyhd8ed1ab_0 - jupyter_core=4.12.0=py38h578d9bd_0 - ld_impl_linux-64=2.38=h1181459_1 - libffi=3.3=he6710b0_2 - libgcc-ng=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libsodium=1.0.18=h36c2ea0_1 - libstdcxx-ng=11.2.0=h1234567_1 - matplotlib-inline=0.1.6=pyhd8ed1ab_0 - ncurses=6.4=h6a678d5_0 - nest-asyncio=1.5.6=pyhd8ed1ab_0 - openssl=1.1.1l=h7f98852_0 - packaging=23.1=pyhd8ed1ab_0 - parso=0.8.3=pyhd8ed1ab_0 - pexpect=4.8.0=pyh1a96a4e_2 - pickleshare=0.7.5=py_1003 - pip=23.2.1=py38h06a4308_0 - prompt-toolkit=3.0.39=pyha770c72_0 - prompt_toolkit=3.0.39=hd8ed1ab_0 - ptyprocess=0.7.0=pyhd3deb0d_0 - pure_eval=0.2.2=pyhd8ed1ab_0 - pygments=2.16.1=pyhd8ed1ab_0 - python=3.8.5=h7579374_1 - python-dateutil=2.8.2=pyhd8ed1ab_0 - python_abi=3.8=2_cp38 - pyzmq=25.1.0=py38h6a678d5_0 - readline=8.2=h5eee18b_0 - setuptools=68.0.0=py38h06a4308_0 - six=1.16.0=pyh6c4a22f_0 - sqlite=3.41.2=h5eee18b_0 - stack_data=0.6.2=pyhd8ed1ab_0 - tk=8.6.12=h1ccaba5_0 - tornado=6.1=py38h0a891b7_3 - traitlets=5.9.0=pyhd8ed1ab_0 - typing_extensions=4.7.1=pyha770c72_0 - wcwidth=0.2.6=pyhd8ed1ab_0 - wheel=0.38.4=py38h06a4308_0 - xz=5.4.2=h5eee18b_0 - zeromq=4.3.4=h9c3ff4c_1 - zlib=1.2.13=h5eee18b_0 - pip: - absl-py==1.4.0 - accelerate==0.22.0 - aiofiles==23.2.1 - aiohttp==3.8.5 - aiosignal==1.3.1 - altair==5.0.1 - annotated-types==0.5.0 - antlr4-python3-runtime==4.9.3 - anyio==3.7.1 - async-timeout==4.0.3 - attrs==23.1.0 - cachetools==5.3.1 - certifi==2023.7.22 - charset-normalizer==3.2.0 - click==8.1.7 - cmake==3.27.2 - contourpy==1.1.0 - cycler==0.11.0 - datasets==2.14.4 - dill==0.3.7 - einops==0.6.1 - exceptiongroup==1.1.3 - fastapi==0.103.0 - ffmpy==0.3.1 - filelock==3.12.2 - fonttools==4.42.1 - frozenlist==1.4.0 - fsspec==2023.6.0 - google-auth==2.22.0 - google-auth-oauthlib==1.0.0 - gradio==3.41.2 - gradio-client==0.5.0 - grpcio==1.57.0 - h11==0.14.0 - httpcore==0.17.3 - httpx==0.24.1 - huggingface-hub==0.16.4 - idna==3.4 - importlib-metadata==6.8.0 - importlib-resources==6.0.1 - jinja2==3.1.2 - joblib==1.3.2 - jsonschema==4.19.0 - jsonschema-specifications==2023.7.1 - kiwisolver==1.4.5 - lightning-utilities==0.9.0 - lit==16.0.6 - markdown==3.4.4 - markupsafe==2.1.3 - matplotlib==3.7.2 - mpmath==1.3.0 - multidict==6.0.4 - multiprocess==0.70.15 - networkx==3.1 - numpy==1.24.4 - nvidia-cublas-cu11==11.10.3.66 - nvidia-cuda-cupti-cu11==11.7.101 - nvidia-cuda-nvrtc-cu11==11.7.99 - nvidia-cuda-runtime-cu11==11.7.99 - nvidia-cudnn-cu11==8.5.0.96 - nvidia-cufft-cu11==10.9.0.58 - nvidia-curand-cu11==10.2.10.91 - nvidia-cusolver-cu11==11.4.0.1 - nvidia-cusparse-cu11==11.7.4.91 - nvidia-nccl-cu11==2.14.3 - nvidia-nvtx-cu11==11.7.91 - oauthlib==3.2.2 - omegaconf==2.3.0 - opencv-python==4.8.0.76 - orjson==3.9.5 - pandas==2.0.3 - pillow==9.5.0 - pkgutil-resolve-name==1.3.10 - protobuf==4.24.2 - psutil==5.9.5 - pyarrow==13.0.0 - pyasn1==0.5.0 - pyasn1-modules==0.3.0 - pydantic==2.3.0 - pydantic-core==2.6.3 - pydub==0.25.1 - pyparsing==3.0.9 - python-multipart==0.0.6 - pytorch-lightning==2.0.7 - pytz==2023.3 - pyyaml==6.0.1 - referencing==0.30.2 - regex==2023.8.8 - requests==2.31.0 - requests-oauthlib==1.3.1 - rpds-py==0.9.2 - rsa==4.9 - safetensors==0.3.3 - semantic-version==2.10.0 - sniffio==1.3.0 - starlette==0.27.0 - sympy==1.12 - tensorboard==2.14.0 - tensorboard-data-server==0.7.1 - tokenizers==0.13.3 - toolz==0.12.0 - torchmetrics==1.1.0 - tqdm==4.66.1 - transformers==4.32.0 - triton==2.0.0 - tzdata==2023.3 - urllib3==1.26.16 - uvicorn==0.23.2 - websockets==11.0.3 - werkzeug==2.3.7 - xxhash==3.3.0 - yarl==1.9.2 - zipp==3.16.2 - decord - imageio==2.9.0 - imageio-ffmpeg==0.4.3 - timm - scipy - scikit-image - av - imgaug - lpips - ffmpeg-python - torch==2.0.1 - torchvision==0.15.2 - xformers==0.0.22 - diffusers==0.21.4 prefix: /home/tiger/miniconda3/envs/manimate ================================================ FILE: magicanimate/models/appearance_encoder.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput, logging from diffusers.models.activations import get_activation from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) from diffusers.models.lora import LoRALinearLayer from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, PositionNet, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import ( UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, get_down_block, get_up_block, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name class Identity(torch.nn.Module): r"""A placeholder identity operator that is argument-insensitive. Args: args: any argument (unused) kwargs: any keyword argument (unused) Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. Examples:: >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 20]) """ def __init__(self, scale=None, *args, **kwargs) -> None: super(Identity, self).__init__() def forward(self, input, *args, **kwargs): return input class _LoRACompatibleLinear(nn.Module): """ A Linear layer that can be used with LoRA. """ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): super().__init__(*args, **kwargs) self.lora_layer = lora_layer def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): self.lora_layer = lora_layer def _fuse_lora(self): pass def _unfuse_lora(self): pass def forward(self, hidden_states, scale=None, lora_scale: int = 1): return hidden_states @dataclass class UNet2DConditionOutput(BaseOutput): """ The output of [`UNet2DConditionModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: torch.FloatTensor = None class AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.up_blocks.append(up_block) prev_output_channel = output_channel self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear() self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear() self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear() self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()]) self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity() self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity() self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity() self.up_blocks[3].attentions[2].proj_out = Identity() if attention_type in ["gated", "gated-text-image"]: positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = PositionNet( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) # To support T2I-Adapter-XL if ( is_adapter and len(down_block_additional_residuals) > 0 and sample.shape == down_block_additional_residuals[0].shape ): sample += down_block_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) ================================================ FILE: magicanimate/models/attention.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available from diffusers.models.attention import FeedForward, AdaLayerNorm from diffusers.models.attention import Attention as CrossAttention from einops import rearrange, repeat @dataclass class Transformer3DModelOutput(BaseOutput): sample: torch.FloatTensor if is_xformers_available(): import xformers import xformers.ops else: xformers = None class Transformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # Define input layers self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) # Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) for d in range(num_layers) ] ) # 4. Define output layers if use_linear_projection: self.proj_out = nn.Linear(in_channels, inner_dim) else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # Input assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." video_length = hidden_states.shape[2] hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") # JH: need not repeat when a list of prompts are given if encoder_hidden_states.shape[0] != hidden_states.shape[0]: encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) # Blocks for block in self.transformer_blocks: hidden_states = block( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, video_length=video_length ) # Output if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) output = hidden_states + residual output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) if not return_dict: return (output,) return Transformer3DModelOutput(sample=output) class BasicTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, unet_use_cross_frame_attention = None, unet_use_temporal_attention = None, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm = num_embeds_ada_norm is not None self.unet_use_cross_frame_attention = unet_use_cross_frame_attention self.unet_use_temporal_attention = unet_use_temporal_attention # SC-Attn assert unet_use_cross_frame_attention is not None if unet_use_cross_frame_attention: self.attn1 = SparseCausalAttention2D( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) else: self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) # Cross-Attn if cross_attention_dim is not None: self.attn2 = CrossAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) else: self.attn2 = None if cross_attention_dim is not None: self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) else: self.norm2 = None # Feed-forward self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.norm3 = nn.LayerNorm(dim) self.use_ada_layer_norm_zero = False # Temp-Attn assert unet_use_temporal_attention is not None if unet_use_temporal_attention: self.attn_temp = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) nn.init.zeros_(self.attn_temp.to_out[0].weight.data) self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): if not is_xformers_available(): print("Here is how to install it") raise ModuleNotFoundError( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers", name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" " available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers if self.attn2 is not None: self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): # SparseCausal-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) # if self.only_cross_attention: # hidden_states = ( # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states # ) # else: # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states # pdb.set_trace() if self.unet_use_cross_frame_attention: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states else: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states if self.attn2 is not None: # Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) hidden_states = ( self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) + hidden_states ) # Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states # Temporal-Attention if self.unet_use_temporal_attention: d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) norm_hidden_states = ( self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) ) hidden_states = self.attn_temp(norm_hidden_states) + hidden_states hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states ================================================ FILE: magicanimate/models/controlnet.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import BaseOutput, logging from .embeddings import TimestepEmbedding, Timesteps from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block, ) from diffusers.models.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ControlNetOutput(BaseOutput): down_block_res_samples: Tuple[torch.Tensor] mid_block_res_sample: torch.Tensor class ControlNetConditioningEmbedding(nn.Module): """ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full model) to encode image-space conditions ... into feature maps ..." """ def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int] = (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class ControlNetModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), ): super().__init__() # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, ) self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=attention_head_dim[i], downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) # mid mid_block_channel = block_out_channels[-1] controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, num_attention_heads=attention_head_dim[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) @classmethod def from_unet( cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, ): r""" Instantiate Controlnet class from UNet2DConditionModel. Parameters: unet (`UNet2DConditionModel`): UNet model which weights are copied to the ControlNet. Note that all configuration options are also copied where applicable. """ controlnet = cls( in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, down_block_types=unet.config.down_block_types, only_cross_attention=unet.config.only_cross_attention, block_out_channels=unet.config.block_out_channels, layers_per_block=unet.config.layers_per_block, downsample_padding=unet.config.downsample_padding, mid_block_scale_factor=unet.config.mid_block_scale_factor, act_fn=unet.config.act_fn, norm_num_groups=unet.config.norm_num_groups, norm_eps=unet.config.norm_eps, cross_attention_dim=unet.config.cross_attention_dim, attention_head_dim=unet.config.attention_head_dim, use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, upcast_attention=unet.config.upcast_attention, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, conditioning_embedding_out_channels=conditioning_embedding_out_channels, ) if load_weights_from_unet: controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) if controlnet.class_embedding: controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) return controlnet # @property # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # def attn_processors(self) -> Dict[str, AttentionProcessor]: # r""" # Returns: # `dict` of attention processors: A dictionary containing all attention processors used in the model with # indexed by its weight name. # """ # # set recursively # processors = {} # def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): # if hasattr(module, "set_processor"): # processors[f"{name}.processor"] = module.processor # for sub_name, child in module.named_children(): # fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) # return processors # for name, module in self.named_children(): # fn_recursive_add_processors(name, module, processors) # return processors # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # r""" # Parameters: # `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): # The instantiated processor class or a dictionary of processor classes that will be set as the processor # of **all** `Attention` layers. # In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: # """ # count = len(self.attn_processors.keys()) # if isinstance(processor, dict) and len(processor) != count: # raise ValueError( # f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" # f" number of attention layers: {count}. Please make sure to pass {count} processor classes." # ) # def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): # if hasattr(module, "set_processor"): # if not isinstance(processor, dict): # module.set_processor(processor) # else: # module.set_processor(processor.pop(f"{name}.processor")) # for sub_name, child in module.named_children(): # fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) # for name, module in self.named_children(): # fn_recursive_attn_processor(name, module, processor) # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # def set_default_attn_processor(self): # """ # Disables custom attention processors and sets the default attention implementation. # """ # self.set_attn_processor(AttnProcessor()) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module will split the input tensor in slices, to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.FloatTensor, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple]: # check channel order channel_order = self.config.controlnet_conditioning_channel_order if channel_order == "rgb": # in rgb order by default ... elif channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) else: raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # 2. pre-process sample = self.conv_in(sample) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample += controlnet_cond # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, # cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, # cross_attention_kwargs=cross_attention_kwargs, ) # 5. Control net blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample *= conditioning_scale if not return_dict: return (down_block_res_samples, mid_block_res_sample) return ControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module ================================================ FILE: magicanimate/models/embeddings.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Optional import numpy as np import torch from torch import nn def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000, ): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( start=0, end=half_dim, dtype=torch.float32, device=timesteps.device ) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb class PatchEmbed(nn.Module): """2D Image to Patch Embedding""" def __init__( self, height=224, width=224, patch_size=16, in_channels=3, embed_dim=768, layer_norm=False, flatten=True, bias=True, ): super().__init__() num_patches = (height // patch_size) * (width // patch_size) self.flatten = flatten self.layer_norm = layer_norm self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias ) if layer_norm: self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) else: self.norm = None pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) def forward(self, latent): latent = self.proj(latent) if self.flatten: latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC if self.layer_norm: latent = self.norm(latent) return latent + self.pos_embed class TimestepEmbedding(nn.Module): def __init__( self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, post_act_fn: Optional[str] = None, cond_proj_dim=None, ): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) else: self.cond_proj = None if act_fn == "silu": self.act = nn.SiLU() elif act_fn == "mish": self.act = nn.Mish() elif act_fn == "gelu": self.act = nn.GELU() else: raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") if out_dim is not None: time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) if post_act_fn is None: self.post_act = None elif post_act_fn == "silu": self.post_act = nn.SiLU() elif post_act_fn == "mish": self.post_act = nn.Mish() elif post_act_fn == "gelu": self.post_act = nn.GELU() else: raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) sample = self.linear_1(sample) if self.act is not None: sample = self.act(sample) sample = self.linear_2(sample) if self.post_act is not None: sample = self.post_act(sample) return sample class Timesteps(nn.Module): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift def forward(self, timesteps): t_emb = get_timestep_embedding( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, ) return t_emb class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__( self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) self.log = log self.flip_sin_to_cos = flip_sin_to_cos if set_W_to_weight: # to delete later self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) self.weight = self.W def forward(self, x): if self.log: x = torch.log(x) x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi if self.flip_sin_to_cos: out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) else: out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return out class ImagePositionalEmbeddings(nn.Module): """ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the height and width of the latent space. For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 For VQ-diffusion: Output vector embeddings are used as input for the transformer. Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. Args: num_embed (`int`): Number of embeddings for the latent pixels embeddings. height (`int`): Height of the latent image i.e. the number of height embeddings. width (`int`): Width of the latent image i.e. the number of width embeddings. embed_dim (`int`): Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. """ def __init__( self, num_embed: int, height: int, width: int, embed_dim: int, ): super().__init__() self.height = height self.width = width self.num_embed = num_embed self.embed_dim = embed_dim self.emb = nn.Embedding(self.num_embed, embed_dim) self.height_emb = nn.Embedding(self.height, embed_dim) self.width_emb = nn.Embedding(self.width, embed_dim) def forward(self, index): emb = self.emb(index) height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) # 1 x H x D -> 1 x H x 1 x D height_emb = height_emb.unsqueeze(2) width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) # 1 x W x D -> 1 x 1 x W x D width_emb = width_emb.unsqueeze(1) pos_emb = height_emb + width_emb # 1 x H x W x D -> 1 x L xD pos_emb = pos_emb.view(1, self.height * self.width, -1) emb = emb + pos_emb[:, : emb.shape[1], :] return emb class LabelEmbedding(nn.Module): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. Args: num_classes (`int`): The number of classes. hidden_size (`int`): The size of the vector embeddings. dropout_prob (`float`): The probability of dropping a label. """ def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = dropout_prob > 0 self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) self.num_classes = num_classes self.dropout_prob = dropout_prob def token_drop(self, labels, force_drop_ids=None): """ Drops labels to enable classifier-free guidance. """ if force_drop_ids is None: drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob else: drop_ids = torch.tensor(force_drop_ids == 1) labels = torch.where(drop_ids, self.num_classes, labels) return labels def forward(self, labels, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) def forward(self, timestep, class_labels, hidden_dtype=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) class_labels = self.class_embedder(class_labels) # (N, D) conditioning = timesteps_emb + class_labels # (N, D) return conditioning ================================================ FILE: magicanimate/models/motion_module.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/guoyww/AnimateDiff from dataclasses import dataclass import torch import torch.nn.functional as F from torch import nn from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available from diffusers.models.attention import FeedForward from magicanimate.models.orig_attention import CrossAttention from einops import rearrange, repeat import math def zero_module(module): # Zero out the parameters of a module and return it. for p in module.parameters(): p.detach().zero_() return module @dataclass class TemporalTransformer3DModelOutput(BaseOutput): sample: torch.FloatTensor if is_xformers_available(): import xformers import xformers.ops else: xformers = None def get_motion_module( in_channels, motion_module_type: str, motion_module_kwargs: dict ): if motion_module_type == "Vanilla": return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) else: raise ValueError class VanillaTemporalModule(nn.Module): def __init__( self, in_channels, num_attention_heads = 8, num_transformer_block = 2, attention_block_types =( "Temporal_Self", "Temporal_Self" ), cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, temporal_attention_dim_div = 1, zero_initialize = True, ): super().__init__() self.temporal_transformer = TemporalTransformer3DModel( in_channels=in_channels, num_attention_heads=num_attention_heads, attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, num_layers=num_transformer_block, attention_block_types=attention_block_types, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) if zero_initialize: self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): hidden_states = input_tensor hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) output = hidden_states return output class TemporalTransformer3DModel(nn.Module): def __init__( self, in_channels, num_attention_heads, attention_head_dim, num_layers, attention_block_types = ( "Temporal_Self", "Temporal_Self", ), dropout = 0.0, norm_num_groups = 32, cross_attention_dim = 768, activation_fn = "geglu", attention_bias = False, upcast_attention = False, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( [ TemporalTransformerBlock( dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, attention_block_types=attention_block_types, dropout=dropout, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, upcast_attention=upcast_attention, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) for d in range(num_layers) ] ) self.proj_out = nn.Linear(inner_dim, in_channels) def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." video_length = hidden_states.shape[2] hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) # Transformer Blocks for block in self.transformer_blocks: hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) # output hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) return output class TemporalTransformerBlock(nn.Module): def __init__( self, dim, num_attention_heads, attention_head_dim, attention_block_types = ( "Temporal_Self", "Temporal_Self", ), dropout = 0.0, norm_num_groups = 32, cross_attention_dim = 768, activation_fn = "geglu", attention_bias = False, upcast_attention = False, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, ): super().__init__() attention_blocks = [] norms = [] for block_name in attention_block_types: attention_blocks.append( VersatileAttention( attention_mode=block_name.split("_")[0], cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) norms.append(nn.LayerNorm(dim)) self.attention_blocks = nn.ModuleList(attention_blocks) self.norms = nn.ModuleList(norms) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.ff_norm = nn.LayerNorm(dim) def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): for attention_block, norm in zip(self.attention_blocks, self.norms): norm_hidden_states = norm(hidden_states) hidden_states = attention_block( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, video_length=video_length, ) + hidden_states hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states output = hidden_states return output class PositionalEncoding(nn.Module): def __init__( self, d_model, dropout = 0., max_len = 24 ): super().__init__() self.dropout = nn.Dropout(p=dropout) position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(1, max_len, d_model) pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward(self, x): x = x + self.pe[:, :x.size(1)] return self.dropout(x) class VersatileAttention(CrossAttention): def __init__( self, attention_mode = None, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, *args, **kwargs ): super().__init__(*args, **kwargs) assert attention_mode == "Temporal" self.attention_mode = attention_mode self.is_cross_attention = kwargs["cross_attention_dim"] is not None self.pos_encoder = PositionalEncoding( kwargs["query_dim"], dropout=0., max_len=temporal_position_encoding_max_len ) if (temporal_position_encoding and attention_mode == "Temporal") else None def extra_repr(self): return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): batch_size, sequence_length, _ = hidden_states.shape if self.attention_mode == "Temporal": d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) if self.pos_encoder is not None: hidden_states = self.pos_encoder(hidden_states) encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states else: raise NotImplementedError encoder_hidden_states = encoder_hidden_states if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) dim = query.shape[-1] query = self.reshape_heads_to_batch_dim(query) if self.added_kv_proj_dim is not None: raise NotImplementedError encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) if attention_mask is not None: if attention_mask.shape[-1] != query.shape[1]: target_length = query.shape[1] attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) # attention, what we cannot get enough of if self._use_memory_efficient_attention_xformers: hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) # Some versions of xformers return output in fp32, cast it back to the dtype of the input hidden_states = hidden_states.to(query.dtype) else: if self._slice_size is None or query.shape[0] // self._slice_size == 1: hidden_states = self._attention(query, key, value, attention_mask) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) if self.attention_mode == "Temporal": hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states ================================================ FILE: magicanimate/models/mutual_self_attention.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import torch import torch.nn.functional as F from einops import rearrange from typing import Any, Callable, Dict, List, Optional, Tuple, Union from diffusers.models.attention import BasicTransformerBlock from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from .stable_diffusion_controlnet_reference import torch_dfs class AttentionBase: def __init__(self): self.cur_step = 0 self.num_att_layers = -1 self.cur_att_layer = 0 def after_step(self): pass def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) self.cur_att_layer += 1 if self.cur_att_layer == self.num_att_layers: self.cur_att_layer = 0 self.cur_step += 1 # after step self.after_step() return out def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): out = torch.einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads) return out def reset(self): self.cur_step = 0 self.cur_att_layer = 0 class MutualSelfAttentionControl(AttentionBase): def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'): """ Mutual self-attention control for Stable-Diffusion MODEl Args: total_steps: the total number of steps """ super().__init__() self.total_steps = total_steps self.hijack = hijack_init_state self.with_negative_guidance = with_negative_guidance # alpha: mutual self attention intensity # TODO: make alpha learnable self.alpha = appearance_control_alpha self.GLOBAL_ATTN_QUEUE = [] assert mode in ['enqueue', 'dequeue'] MODE = mode def attn_batch(self, q, k, v, num_heads, **kwargs): """ Performing attention for a batch of queries, keys, and values """ b = q.shape[0] // num_heads q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") attn = sim.softmax(-1) out = torch.einsum("h i j, h j d -> h i d", attn, v) out = rearrange(out, "h (b n) d -> b n (h d)", b=b) return out def mutual_self_attn(self, q, k, v, num_heads, **kwargs): q_tgt, q_src = q.chunk(2) k_tgt, k_src = k.chunk(2) v_tgt, v_src = v.chunk(2) # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \ # self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha) out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs) out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs) out = torch.cat([out_tgt, out_src], dim=0) return out def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): if self.MODE == 'dequeue' and len(self.kv_queue) > 0: k_src, v_src = self.kv_queue.pop(0) out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs) return out else: self.kv_queue.append([k.clone(), v.clone()]) return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) def get_queue(self): return self.GLOBAL_ATTN_QUEUE def set_queue(self, attn_queue): self.GLOBAL_ATTN_QUEUE = attn_queue def clear_queue(self): self.GLOBAL_ATTN_QUEUE = [] def to(self, dtype): self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE] def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Attention forward function """ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) class ReferenceAttentionControl(): def __init__(self, unet, mode="write", do_classifier_free_guidance=False, attention_auto_machine_weight = float('inf'), gn_auto_machine_weight = 1.0, style_fidelity = 1.0, reference_attn=True, reference_adain=False, fusion_blocks="midup", batch_size=1, ) -> None: # 10. Modify self attention and group norm self.unet = unet assert mode in ["read", "write"] assert fusion_blocks in ["midup", "full"] self.reference_attn = reference_attn self.reference_adain = reference_adain self.fusion_blocks = fusion_blocks self.register_reference_hooks( mode, do_classifier_free_guidance, attention_auto_machine_weight, gn_auto_machine_weight, style_fidelity, reference_attn, reference_adain, fusion_blocks, batch_size=batch_size, ) def register_reference_hooks( self, mode, do_classifier_free_guidance, attention_auto_machine_weight, gn_auto_machine_weight, style_fidelity, reference_attn, reference_adain, dtype=torch.float16, batch_size=1, num_images_per_prompt=1, device=torch.device("cpu"), fusion_blocks='midup', ): MODE = mode do_classifier_free_guidance = do_classifier_free_guidance attention_auto_machine_weight = attention_auto_machine_weight gn_auto_machine_weight = gn_auto_machine_weight style_fidelity = style_fidelity reference_attn = reference_attn reference_adain = reference_adain fusion_blocks = fusion_blocks num_images_per_prompt = num_images_per_prompt dtype=dtype if do_classifier_free_guidance: uc_mask = ( torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16) .to(device) .bool() ) else: uc_mask = ( torch.Tensor([0] * batch_size * num_images_per_prompt * 2) .to(device) .bool() ) def hacked_basic_transformer_inner_forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, video_length=None, ): if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) else: norm_hidden_states = self.norm1(hidden_states) # 1. Self-Attention cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if self.only_cross_attention: attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) else: if MODE == "write": self.bank.append(norm_hidden_states.clone()) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) if MODE == "read": self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank] hidden_states_uc = self.attn1(norm_hidden_states, encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), attention_mask=attention_mask) + hidden_states hidden_states_c = hidden_states_uc.clone() _uc_mask = uc_mask.clone() if do_classifier_free_guidance: if hidden_states.shape[0] != _uc_mask.shape[0]: _uc_mask = ( torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2)) .to(device) .bool() ) hidden_states_c[_uc_mask] = self.attn1( norm_hidden_states[_uc_mask], encoder_hidden_states=norm_hidden_states[_uc_mask], attention_mask=attention_mask, ) + hidden_states[_uc_mask] hidden_states = hidden_states_c.clone() self.bank.clear() if self.attn2 is not None: # Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) hidden_states = ( self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) + hidden_states ) # Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states # Temporal-Attention if self.unet_use_temporal_attention: d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) norm_hidden_states = ( self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) ) hidden_states = self.attn_temp(norm_hidden_states) + hidden_states hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) # 2. Cross-Attention attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states return hidden_states def hacked_mid_forward(self, *args, **kwargs): eps = 1e-6 x = self.original_forward(*args, **kwargs) if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append(mean) self.var_bank.append(var) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) var_acc = sum(self.var_bank) / float(len(self.var_bank)) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 x_uc = (((x - mean) / std) * std_acc) + mean_acc x_c = x_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: x_c[uc_mask] = x[uc_mask] x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc self.mean_bank = [] self.var_bank = [] return x def hack_CrossAttnDownBlock2D_forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): eps = 1e-6 # TODO(Patrick, William) - attention mask is not used output_states = () for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc output_states = output_states + (hidden_states,) if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states def hacked_DownBlock2D_forward(self, hidden_states, temb=None): eps = 1e-6 output_states = () for i, resnet in enumerate(self.resnets): hidden_states = resnet(hidden_states, temb) if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc output_states = output_states + (hidden_states,) if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states def hacked_CrossAttnUpBlock2D_forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): eps = 1e-6 # TODO(Patrick, William) - attention mask is not used for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): eps = 1e-6 for i, resnet in enumerate(self.resnets): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states if self.reference_attn: if self.fusion_blocks == "midup": attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] elif self.fusion_blocks == "full": attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) for i, module in enumerate(attn_modules): module._original_inner_forward = module.forward module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) module.bank = [] module.attn_weight = float(i) / float(len(attn_modules)) if self.reference_adain: gn_modules = [self.unet.mid_block] self.unet.mid_block.gn_weight = 0 down_blocks = self.unet.down_blocks for w, module in enumerate(down_blocks): module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) gn_modules.append(module) up_blocks = self.unet.up_blocks for w, module in enumerate(up_blocks): module.gn_weight = float(w) / float(len(up_blocks)) gn_modules.append(module) for i, module in enumerate(gn_modules): if getattr(module, "original_forward", None) is None: module.original_forward = module.forward if i == 0: # mid_block module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) elif isinstance(module, CrossAttnDownBlock2D): module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) elif isinstance(module, DownBlock2D): module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) elif isinstance(module, CrossAttnUpBlock2D): module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) elif isinstance(module, UpBlock2D): module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) module.mean_bank = [] module.var_bank = [] module.gn_weight *= 2 def update(self, writer, dtype=torch.float16): if self.reference_attn: if self.fusion_blocks == "midup": reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)] writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)] elif self.fusion_blocks == "full": reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)] writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)] reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) for r, w in zip(reader_attn_modules, writer_attn_modules): r.bank = [v.clone().to(dtype) for v in w.bank] # w.bank.clear() if self.reference_adain: reader_gn_modules = [self.unet.mid_block] down_blocks = self.unet.down_blocks for w, module in enumerate(down_blocks): reader_gn_modules.append(module) up_blocks = self.unet.up_blocks for w, module in enumerate(up_blocks): reader_gn_modules.append(module) writer_gn_modules = [writer.unet.mid_block] down_blocks = writer.unet.down_blocks for w, module in enumerate(down_blocks): writer_gn_modules.append(module) up_blocks = writer.unet.up_blocks for w, module in enumerate(up_blocks): writer_gn_modules.append(module) for r, w in zip(reader_gn_modules, writer_gn_modules): if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list): r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank] r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank] else: r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank] r.var_bank = [v.clone().to(dtype) for v in w.var_bank] def clear(self): if self.reference_attn: if self.fusion_blocks == "midup": reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] elif self.fusion_blocks == "full": reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)] reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) for r in reader_attn_modules: r.bank.clear() if self.reference_adain: reader_gn_modules = [self.unet.mid_block] down_blocks = self.unet.down_blocks for w, module in enumerate(down_blocks): reader_gn_modules.append(module) up_blocks = self.unet.up_blocks for w, module in enumerate(up_blocks): reader_gn_modules.append(module) for r in reader_gn_modules: r.mean_bank.clear() r.var_bank.clear() ================================================ FILE: magicanimate/models/orig_attention.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.models.embeddings import ImagePositionalEmbeddings from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available @dataclass class Transformer2DModelOutput(BaseOutput): """ Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. """ sample: torch.FloatTensor if is_xformers_available(): import xformers import xformers.ops else: xformers = None class Transformer2DModel(ModelMixin, ConfigMixin): """ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual embeddings) inputs. When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. Finally, reshape to image. When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict classes of unnoised image. Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = in_channels is not None self.is_input_vectorized = num_vector_embeds is not None if self.is_input_continuous and self.is_input_vectorized: raise ValueError( f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" " sure that either `in_channels` or `num_vector_embeds` is None." ) elif not self.is_input_continuous and not self.is_input_vectorized: raise ValueError( f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" " sure that either `in_channels` or `num_vector_embeds` is not None." ) # 2. Define input layers if self.is_input_continuous: self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" self.height = sample_size self.width = sample_size self.num_vector_embeds = num_vector_embeds self.num_latent_pixels = self.height * self.width self.latent_image_embedding = ImagePositionalEmbeddings( num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width ) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) for d in range(num_layers) ] ) # 4. Define output layers if self.is_input_continuous: if use_linear_projection: self.proj_out = nn.Linear(in_channels, inner_dim) else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # 1. Input if self.is_input_continuous: batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep) # 3. Output if self.is_input_continuous: if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) output = hidden_states + residual elif self.is_input_vectorized: hidden_states = self.norm_out(hidden_states) logits = self.out(hidden_states) # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) logits = logits.permute(0, 2, 1) # log(p(x_0)) output = F.log_softmax(logits.double(), dim=1).float() if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. Uses three q, k, v linear layers to compute attention. Parameters: channels (`int`): The number of channels in the input and output. num_head_channels (`int`, *optional*): The number of channels in each head. If None, then `num_heads` = 1. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore def __init__( self, channels: int, num_head_channels: Optional[int] = None, norm_num_groups: int = 32, rescale_output_factor: float = 1.0, eps: float = 1e-5, ): super().__init__() self.channels = channels self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 self.num_head_size = num_head_channels self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) # define q,k,v as linear layers self.query = nn.Linear(channels, channels) self.key = nn.Linear(channels, channels) self.value = nn.Linear(channels, channels) self.rescale_output_factor = rescale_output_factor self.proj_attn = nn.Linear(channels, channels, 1) self._use_memory_efficient_attention_xformers = False def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): if not is_xformers_available(): raise ModuleNotFoundError( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers", name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" " available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.num_heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.num_heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def forward(self, hidden_states): residual = hidden_states batch, channel, height, width = hidden_states.shape # norm hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) # proj to q, k, v query_proj = self.query(hidden_states) key_proj = self.key(hidden_states) value_proj = self.value(hidden_states) scale = 1 / math.sqrt(self.channels / self.num_heads) query_proj = self.reshape_heads_to_batch_dim(query_proj) key_proj = self.reshape_heads_to_batch_dim(key_proj) value_proj = self.reshape_heads_to_batch_dim(value_proj) if self._use_memory_efficient_attention_xformers: # Memory efficient attention hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) hidden_states = hidden_states.to(query_proj.dtype) else: attention_scores = torch.baddbmm( torch.empty( query_proj.shape[0], query_proj.shape[1], key_proj.shape[1], dtype=query_proj.dtype, device=query_proj.device, ), query_proj, key_proj.transpose(-1, -2), beta=0, alpha=scale, ) attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) hidden_states = torch.bmm(attention_probs, value_proj) # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) # compute next hidden_states hidden_states = self.proj_attn(hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm = num_embeds_ada_norm is not None # 1. Self-Attn self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) # 2. Cross-Attn if cross_attention_dim is not None: self.attn2 = CrossAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) # is self-attn if encoder_hidden_states is none else: self.attn2 = None self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) if cross_attention_dim is not None: self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) else: self.norm2 = None # 3. Feed-forward self.norm3 = nn.LayerNorm(dim) def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): if not is_xformers_available(): print("Here is how to install it") raise ModuleNotFoundError( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers", name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" " available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers if self.attn2 is not None: self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None): # 1. Self-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) if self.only_cross_attention: hidden_states = ( self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states ) else: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states if self.attn2 is not None: # 2. Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) hidden_states = ( self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) + hidden_states ) # 3. Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states return hidden_states class CrossAttention(nn.Module): r""" A cross attention layer. Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, ): super().__init__() inner_dim = dim_head * heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.scale = dim_head**-0.5 self.heads = heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads self._slice_size = None self._use_memory_efficient_attention_xformers = False self.added_kv_proj_dim = added_kv_proj_dim if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) else: self.group_norm = None self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim)) self.to_out.append(nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def set_attention_slice(self, slice_size): if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") self._slice_size = slice_size def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape encoder_hidden_states = encoder_hidden_states if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) dim = query.shape[-1] query = self.reshape_heads_to_batch_dim(query) if self.added_kv_proj_dim is not None: key = self.to_k(hidden_states) value = self.to_v(hidden_states) encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) else: encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) if attention_mask is not None: if attention_mask.shape[-1] != query.shape[1]: target_length = query.shape[1] attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) # attention, what we cannot get enough of if self._use_memory_efficient_attention_xformers: hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) # Some versions of xformers return output in fp32, cast it back to the dtype of the input hidden_states = hidden_states.to(query.dtype) else: if self._slice_size is None or query.shape[0] // self._slice_size == 1: hidden_states = self._attention(query, key, value, attention_mask) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) return hidden_states def _attention(self, query, key, value, attention_mask=None): if self.upcast_attention: query = query.float() key = key.float() attention_scores = torch.baddbmm( torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), query, key.transpose(-1, -2), beta=0, alpha=self.scale, ) if attention_mask is not None: attention_scores = attention_scores + attention_mask if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) # cast back to the original dtype attention_probs = attention_probs.to(value.dtype) # compute attention output hidden_states = torch.bmm(attention_probs, value) # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): batch_size_attention = query.shape[0] hidden_states = torch.zeros( (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype ) slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] if self.upcast_attention: query_slice = query_slice.float() key_slice = key_slice.float() attn_slice = torch.baddbmm( torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), query_slice, key_slice.transpose(-1, -2), beta=0, alpha=self.scale, ) if attention_mask is not None: attn_slice = attn_slice + attention_mask[start_idx:end_idx] if self.upcast_softmax: attn_slice = attn_slice.float() attn_slice = attn_slice.softmax(dim=-1) # cast back to the original dtype attn_slice = attn_slice.to(value.dtype) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): # TODO attention_mask query = query.contiguous() key = key.contiguous() value = value.contiguous() hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states class FeedForward(nn.Module): r""" A feed-forward layer. Parameters: dim (`int`): The number of channels in the input. dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. """ def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) elif activation_fn == "geglu": act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": act_fn = ApproximateGELU(dim, inner_dim) self.net = nn.ModuleList([]) # project in self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out)) def forward(self, hidden_states): for module in self.net: hidden_states = module(hidden_states) return hidden_states class GELU(nn.Module): r""" GELU activation function """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out) def gelu(self, gate): if gate.device.type != "mps": return F.gelu(gate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) hidden_states = self.gelu(hidden_states) return hidden_states # feedforward class GEGLU(nn.Module): r""" A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": return F.gelu(gate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) class ApproximateGELU(nn.Module): """ The approximate form of Gaussian Error Linear Unit (GELU) For more details, see section 2: https://arxiv.org/abs/1606.08415 """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out) def forward(self, x): x = self.proj(x) return x * torch.sigmoid(1.702 * x) class AdaLayerNorm(nn.Module): """ Norm layer modified to incorporate timestep embeddings. """ def __init__(self, embedding_dim, num_embeddings): super().__init__() self.emb = nn.Embedding(num_embeddings, embedding_dim) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, embedding_dim * 2) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) def forward(self, x, timestep): emb = self.linear(self.silu(self.emb(timestep))) scale, shift = torch.chunk(emb, 2) x = self.norm(x) * (1 + scale) + shift return x class DualTransformer2DModel(nn.Module): """ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input and output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. """ def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, ): super().__init__() self.transformers = nn.ModuleList( [ Transformer2DModel( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, in_channels=in_channels, num_layers=num_layers, dropout=dropout, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attention_bias=attention_bias, sample_size=sample_size, num_vector_embeds=num_vector_embeds, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, ) for _ in range(2) ] ) # Variables that can be set by a pipeline: # The ratio of transformer1 to transformer2's output states to be combined during inference self.mix_ratio = 0.5 # The shape of `encoder_hidden_states` is expected to be # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` self.condition_lengths = [77, 257] # Which transformer to use to encode which condition. # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` self.transformer_index_for_condition = [1, 0] def forward( self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True ): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. attention_mask (`torch.FloatTensor`, *optional*): Optional attention mask to be applied in CrossAttention return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ input_states = hidden_states encoded_states = [] tokens_start = 0 # attention_mask is not used yet for i in range(2): # for each of the two transformers, pass the corresponding condition tokens condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] transformer_index = self.transformer_index_for_condition[i] encoded_state = self.transformers[transformer_index]( input_states, encoder_hidden_states=condition_state, timestep=timestep, return_dict=False, )[0] encoded_states.append(encoded_state - input_states) tokens_start += self.condition_lengths[i] output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) output_states = output_states + input_states if not return_dict: return (output_states,) return Transformer2DModelOutput(sample=output_states) ================================================ FILE: magicanimate/models/resnet.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/guoyww/AnimateDiff # Copyright 2023 The HuggingFace Team. All rights reserved. # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange class InflatedConv3d(nn.Conv2d): def forward(self, x): video_length = x.shape[2] x = rearrange(x, "b c f h w -> (b f) c h w") x = super().forward(x) x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) return x class Upsample3D(nn.Module): def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name conv = None if use_conv_transpose: raise NotImplementedError elif use_conv: self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) def forward(self, hidden_states, output_size=None): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: raise NotImplementedError # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 dtype = hidden_states.dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: hidden_states = hidden_states.contiguous() # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if output_size is None: hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") # If the input is bfloat16, we cast back to bfloat16 if dtype == torch.bfloat16: hidden_states = hidden_states.to(dtype) hidden_states = self.conv(hidden_states) return hidden_states class Downsample3D(nn.Module): def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.padding = padding stride = 2 self.name = name if use_conv: self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: raise NotImplementedError def forward(self, hidden_states): assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: raise NotImplementedError assert hidden_states.shape[1] == self.channels hidden_states = self.conv(hidden_states) return hidden_states class ResnetBlock3D(nn.Module): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, groups_out=None, pre_norm=True, eps=1e-6, non_linearity="swish", time_embedding_norm="default", output_scale_factor=1.0, use_in_shortcut=None, ): super().__init__() self.pre_norm = pre_norm self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.time_embedding_norm = time_embedding_norm self.output_scale_factor = output_scale_factor if groups_out is None: groups_out = groups self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": time_emb_proj_out_channels = out_channels elif self.time_embedding_norm == "scale_shift": time_emb_proj_out_channels = out_channels * 2 else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) else: self.time_emb_proj = None self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) elif non_linearity == "mish": self.nonlinearity = Mish() elif non_linearity == "silu": self.nonlinearity = nn.SiLU() self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, input_tensor, temb): hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor class Mish(torch.nn.Module): def forward(self, hidden_states): return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) ================================================ FILE: magicanimate/models/stable_diffusion_controlnet_reference.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch from diffusers import StableDiffusionControlNetPipeline from diffusers.models import ControlNetModel from diffusers.models.attention import BasicTransformerBlock from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging from diffusers.utils.torch_utils import is_compiled_module, randn_tensor logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import cv2 >>> import torch >>> import numpy as np >>> from PIL import Image >>> from diffusers import UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") >>> # get canny image >>> image = cv2.Canny(np.array(input_image), 100, 200) >>> image = image[:, :, None] >>> image = np.concatenate([image, image, image], axis=2) >>> canny_image = Image.fromarray(image) >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16 ).to('cuda:0') >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) >>> result_img = pipe(ref_image=input_image, prompt="1girl", image=canny_image, num_inference_steps=20, reference_attn=True, reference_adain=True).images[0] >>> result_img.show() ``` """ def torch_dfs(model: torch.nn.Module): result = [model] for child in model.children(): result += torch_dfs(child) return result class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline): def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): refimage = refimage.to(device=device, dtype=dtype) # encode the mask image into latents space so we can concatenate it to the latents if isinstance(generator, list): ref_image_latents = [ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) for i in range(batch_size) ] ref_image_latents = torch.cat(ref_image_latents, dim=0) else: ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) ref_image_latents = self.vae.config.scaling_factor * ref_image_latents # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method if ref_image_latents.shape[0] < batch_size: if not batch_size % ref_image_latents.shape[0] == 0: raise ValueError( "The passed images and the required batch size don't match. Images are supposed to be duplicated" f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." ) ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents # aligning device to prevent device errors when concating it with the latent model input ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) return ref_image_latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, attention_auto_machine_weight: float = 1.0, gn_auto_machine_weight: float = 1.0, style_fidelity: float = 0.5, reference_attn: bool = True, reference_adain: bool = True, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. ref_image (`torch.FloatTensor`, `PIL.Image.Image`): The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can also be accepted as an image. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. attention_auto_machine_weight (`float`): Weight of using reference query for self attention's context. If attention_auto_machine_weight=1.0, use reference query for all self attention's context. gn_auto_machine_weight (`float`): Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins. style_fidelity (`float`): style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important, elif style_fidelity=0.0, prompt more important, else balanced. reference_attn (`bool`): Whether to use reference query for self attention's context. reference_adain (`bool`): Whether to use reference adain. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] for image_ in image: image_ = self.prepare_image( image=image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) images.append(image_) image = images height, width = image[0].shape[-2:] else: assert False # 5. Preprocess reference image ref_image = self.prepare_image( image=ref_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=prompt_embeds.dtype, ) # 6. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 7. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 8. Prepare reference latent variables ref_image_latents = self.prepare_ref_latents( ref_image, batch_size * num_images_per_prompt, prompt_embeds.dtype, device, generator, do_classifier_free_guidance, ) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. Modify self attention and group norm MODE = "write" uc_mask = ( torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) .type_as(ref_image_latents) .bool() ) def hacked_basic_transformer_inner_forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) else: norm_hidden_states = self.norm1(hidden_states) # 1. Self-Attention cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if self.only_cross_attention: attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) else: if MODE == "write": self.bank.append(norm_hidden_states.detach().clone()) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) if MODE == "read": if attention_auto_machine_weight > self.attn_weight: attn_output_uc = self.attn1( norm_hidden_states, encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), # attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output_c = attn_output_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: attn_output_c[uc_mask] = self.attn1( norm_hidden_states[uc_mask], encoder_hidden_states=norm_hidden_states[uc_mask], **cross_attention_kwargs, ) attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc self.bank.clear() else: attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) # 2. Cross-Attention attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states return hidden_states def hacked_mid_forward(self, *args, **kwargs): eps = 1e-6 x = self.original_forward(*args, **kwargs) if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append(mean) self.var_bank.append(var) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) var_acc = sum(self.var_bank) / float(len(self.var_bank)) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 x_uc = (((x - mean) / std) * std_acc) + mean_acc x_c = x_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: x_c[uc_mask] = x[uc_mask] x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc self.mean_bank = [] self.var_bank = [] return x def hack_CrossAttnDownBlock2D_forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): eps = 1e-6 # TODO(Patrick, William) - attention mask is not used output_states = () for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask] hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc output_states = output_states + (hidden_states,) if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states def hacked_DownBlock2D_forward(self, hidden_states, temb=None): eps = 1e-6 output_states = () for i, resnet in enumerate(self.resnets): hidden_states = resnet(hidden_states, temb) if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask] hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc output_states = output_states + (hidden_states,) if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states def hacked_CrossAttnUpBlock2D_forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): eps = 1e-6 # TODO(Patrick, William) - attention mask is not used for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask] hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): eps = 1e-6 for i, resnet in enumerate(self.resnets): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) if MODE == "write": if gn_auto_machine_weight >= self.gn_weight: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) self.mean_bank.append([mean]) self.var_bank.append([var]) if MODE == "read": if len(self.mean_bank) > 0 and len(self.var_bank) > 0: var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc hidden_states_c = hidden_states_uc.clone() if do_classifier_free_guidance and style_fidelity > 0: hidden_states_c[uc_mask] = hidden_states[uc_mask] hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc if MODE == "read": self.mean_bank = [] self.var_bank = [] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states if reference_attn: attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) for i, module in enumerate(attn_modules): module._original_inner_forward = module.forward module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) module.bank = [] module.attn_weight = float(i) / float(len(attn_modules)) if reference_adain: gn_modules = [self.unet.mid_block] self.unet.mid_block.gn_weight = 0 down_blocks = self.unet.down_blocks for w, module in enumerate(down_blocks): module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) gn_modules.append(module) up_blocks = self.unet.up_blocks for w, module in enumerate(up_blocks): module.gn_weight = float(w) / float(len(up_blocks)) gn_modules.append(module) for i, module in enumerate(gn_modules): if getattr(module, "original_forward", None) is None: module.original_forward = module.forward if i == 0: # mid_block module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) elif isinstance(module, CrossAttnDownBlock2D): module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) elif isinstance(module, DownBlock2D): module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) elif isinstance(module, CrossAttnUpBlock2D): module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) elif isinstance(module, UpBlock2D): module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) module.mean_bank = [] module.var_bank = [] module.gn_weight *= 2 # 11. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # ref only part noise = randn_tensor( ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype ) ref_xt = self.scheduler.add_noise( ref_image_latents, noise, t.reshape( 1, ), ) ref_xt = self.scheduler.scale_model_input(ref_xt, t) MODE = "write" self.unet( ref_xt, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, ) # predict the noise residual MODE = "read" noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: magicanimate/models/unet.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/guoyww/AnimateDiff # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List, Optional, Tuple, Union import os import json import pdb import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from .unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, DownBlock3D, UNetMidBlock3DCrossAttn, UpBlock3D, get_down_block, get_up_block, ) from .resnet import InflatedConv3d logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor class UNet3DConditionModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), mid_block_type: str = "UNetMidBlock3DCrossAttn", up_block_types: Tuple[str] = ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", # Additional use_motion_module = False, motion_module_resolutions = ( 1,2,4,8 ), motion_module_mid_block = False, motion_module_decoder_only = False, motion_module_type = None, motion_module_kwargs = {}, unet_use_cross_frame_attention = None, unet_use_temporal_attention = None, ): super().__init__() self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 # input self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) # time self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock3DCrossAttn": self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the videos self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): res = 2 ** (3 - i) is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module will split the input tensor in slices, to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_slicable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_slicable_dims(module) num_slicable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_slicable_layers * [1] slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # pre-process sample = self.conv_in(sample) # down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) down_block_res_samples += res_samples # mid sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) # up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, ) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) @classmethod def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...") config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) config["_class_name"] = cls.__name__ config["down_block_types"] = [ "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D" ] config["up_block_types"] = [ "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ] from diffusers.utils import WEIGHTS_NAME model = cls.from_config(config, **unet_additional_kwargs) model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") state_dict = torch.load(model_file, map_location="cpu") m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") return model ================================================ FILE: magicanimate/models/unet_3d_blocks.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/guoyww/AnimateDiff # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from torch import nn from .attention import Transformer3DModel from .resnet import Downsample3D, ResnetBlock3D, Upsample3D from .motion_module import get_motion_module def get_down_block( down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock3D": return DownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") return CrossAttnDownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") return CrossAttnUpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) raise ValueError(f"{up_block_type} does not exist.") class UNetMidBlock3DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] motion_modules = [] for _ in range(num_layers): if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=in_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states hidden_states = resnet(hidden_states, temb) return hidden_states class CrossAttnDownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): output_states = () for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, )[0] if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class DownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] motion_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () for resnet, motion_module in zip(self.resnets, self.motion_modules): if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class CrossAttnUpBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None, attention_mask=None, ): for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, )[0] if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class UpBlock3D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] motion_modules = [] for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): for resnet, motion_module in zip(self.resnets, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states ================================================ FILE: magicanimate/models/unet_controlnet.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List, Optional, Tuple, Union import os import json import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from magicanimate.models.unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, DownBlock3D, UNetMidBlock3DCrossAttn, UpBlock3D, get_down_block, get_up_block, ) from .resnet import InflatedConv3d logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor class UNet3DConditionModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), mid_block_type: str = "UNetMidBlock3DCrossAttn", up_block_types: Tuple[str] = ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", # Additional use_motion_module = False, motion_module_resolutions = ( 1,2,4,8 ), motion_module_mid_block = False, motion_module_decoder_only = False, motion_module_type = None, motion_module_kwargs = {}, unet_use_cross_frame_attention = None, unet_use_temporal_attention = None, ): super().__init__() self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 # input self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) # time self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock3DCrossAttn": self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the videos self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): res = 2 ** (3 - i) is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module will split the input tensor in slices, to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_slicable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_slicable_dims(module) num_slicable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_slicable_layers * [1] slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, # for controlnet down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # pre-process sample = self.conv_in(sample) # down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # mid sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) if is_controlnet: sample = sample + mid_block_additional_residual # up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, ) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) @classmethod def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...") config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) config["_class_name"] = cls.__name__ config["down_block_types"] = [ "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D" ] config["up_block_types"] = [ "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ] # config["mid_block_type"] = "UNetMidBlock3DCrossAttn" from diffusers.utils import WEIGHTS_NAME model = cls.from_config(config, **unet_additional_kwargs) model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") state_dict = torch.load(model_file, map_location="cpu") m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") return model ================================================ FILE: magicanimate/pipelines/animation.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import argparse import datetime import inspect import os import random import numpy as np from PIL import Image from omegaconf import OmegaConf from collections import OrderedDict import torch import torch.distributed as dist from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler from tqdm import tqdm from transformers import CLIPTextModel, CLIPTokenizer from magicanimate.models.unet_controlnet import UNet3DConditionModel from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.appearance_encoder import AppearanceEncoderModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl from magicanimate.pipelines.pipeline_animation import AnimationPipeline from magicanimate.utils.util import save_videos_grid from magicanimate.utils.dist_tools import distributed_init from accelerate.utils import set_seed from magicanimate.utils.videoreader import VideoReader from einops import rearrange from pathlib import Path def main(args): *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) config = OmegaConf.load(args.config) # Initialize distributed training device = torch.device(f"cuda:{args.rank}") dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist} if config.savename is None: time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") savedir = f"samples/{Path(args.config).stem}-{time_str}" else: savedir = f"samples/{config.savename}" if args.dist: dist.broadcast_object_list([savedir], 0) dist.barrier() if args.rank == 0: os.makedirs(savedir, exist_ok=True) inference_config = OmegaConf.load(config.inference_config) motion_module = config.motion_module ### >>> create animation pipeline >>> ### tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") if config.pretrained_unet_path: unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) else: unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device) reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) if config.pretrained_vae_path is not None: vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) else: vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") ### Load controlnet controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) unet.enable_xformers_memory_efficient_attention() appearance_encoder.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() vae.to(torch.float16) unet.to(torch.float16) text_encoder.to(torch.float16) appearance_encoder.to(torch.float16) controlnet.to(torch.float16) pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), # NOTE: UniPCMultistepScheduler ) # 1. unet ckpt # 1.1 motion module motion_module_state_dict = torch.load(motion_module, map_location="cpu") if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict try: # extra steps for self-trained models state_dict = OrderedDict() for key in motion_module_state_dict.keys(): if key.startswith("module."): _key = key.split("module.")[-1] state_dict[_key] = motion_module_state_dict[key] else: state_dict[key] = motion_module_state_dict[key] motion_module_state_dict = state_dict del state_dict missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) assert len(unexpected) == 0 except: _tmp_ = OrderedDict() for key in motion_module_state_dict.keys(): if "motion_modules" in key: if key.startswith("unet."): _key = key.split('unet.')[-1] _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] missing, unexpected = unet.load_state_dict(_tmp_, strict=False) assert len(unexpected) == 0 del _tmp_ del motion_module_state_dict pipeline.to(device) ### <<< create validation pipeline <<< ### random_seeds = config.get("seed", [-1]) random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds # input test videos (either source video/ conditions) test_videos = config.video_path source_images = config.source_image num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps) # read size, step from yaml file sizes = [config.size] * len(test_videos) steps = [config.S] * len(test_videos) config.random_seed = [] prompt = n_prompt = "" for idx, (source_image, test_video, random_seed, size, step) in tqdm( enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), total=len(test_videos), disable=(args.rank!=0) ): samples_per_video = [] samples_per_clip = [] # manually set random seed for reproduction if random_seed != -1: torch.manual_seed(random_seed) set_seed(random_seed) else: torch.seed() config.random_seed.append(torch.initial_seed()) if test_video.endswith('.mp4'): control = VideoReader(test_video).read() if control[0].shape[0] != size: control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] if config.max_length is not None: control = control[config.offset: (config.offset+config.max_length)] control = np.array(control) if source_image.endswith(".mp4"): source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size))) else: source_image = np.array(Image.open(source_image).resize((size, size))) H, W, C = source_image.shape print(f"current seed: {torch.initial_seed()}") init_latents = None # print(f"sampling {prompt} ...") original_length = control.shape[0] if control.shape[0] % config.L > 0: control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge') generator = torch.Generator(device=torch.device("cuda:0")) generator.manual_seed(torch.initial_seed()) sample = pipeline( prompt, negative_prompt = n_prompt, num_inference_steps = config.steps, guidance_scale = config.guidance_scale, width = W, height = H, video_length = len(control), controlnet_condition = control, init_latents = init_latents, generator = generator, num_actual_inference_steps = num_actual_inference_steps, appearance_encoder = appearance_encoder, reference_control_writer = reference_control_writer, reference_control_reader = reference_control_reader, source_image = source_image, **dist_kwargs, ).videos if args.rank == 0: source_images = np.array([source_image] * original_length) source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 samples_per_video.append(source_images) control = control / 255.0 control = rearrange(control, "t h w c -> 1 c t h w") control = torch.from_numpy(control) samples_per_video.append(control[:, :, :original_length]) samples_per_video.append(sample[:, :, :original_length]) samples_per_video = torch.cat(samples_per_video) video_name = os.path.basename(test_video)[:-4] source_name = os.path.basename(config.source_image[idx]).split(".")[0] save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4") save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4") if config.save_individual_videos: save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4") save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4") if args.dist: dist.barrier() if args.rank == 0: OmegaConf.save(config, f"{savedir}/config.yaml") def distributed_main(device_id, args): args.rank = device_id args.device_id = device_id if torch.cuda.is_available(): torch.cuda.set_device(args.device_id) torch.cuda.init() distributed_init(args) main(args) def run(args): if args.dist: args.world_size = max(1, torch.cuda.device_count()) assert args.world_size <= torch.cuda.device_count() if args.world_size > 0 and torch.cuda.device_count() > 1: port = random.randint(10000, 20000) args.init_method = f"tcp://localhost:{port}" torch.multiprocessing.spawn( fn=distributed_main, args=(args,), nprocs=args.world_size, ) else: main(args) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--dist", action="store_true", required=False) parser.add_argument("--rank", type=int, default=0, required=False) parser.add_argument("--world_size", type=int, default=1, required=False) args = parser.parse_args() run(args) ================================================ FILE: magicanimate/pipelines/context.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main import numpy as np from typing import Callable, Optional, List def ordered_halving(val): bin_str = f"{val:064b}" bin_flip = bin_str[::-1] as_int = int(bin_flip, 2) return as_int / (1 << 64) def uniform( step: int = ..., num_steps: Optional[int] = None, num_frames: int = ..., context_size: Optional[int] = None, context_stride: int = 3, context_overlap: int = 4, closed_loop: bool = True, ): if num_frames <= context_size: yield list(range(num_frames)) return context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) for context_step in 1 << np.arange(context_stride): pad = int(round(num_frames * ordered_halving(step))) for j in range( int(ordered_halving(step) * context_step) + pad, num_frames + pad + (0 if closed_loop else -context_overlap), (context_size * context_step - context_overlap), ): yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] def get_context_scheduler(name: str) -> Callable: if name == "uniform": return uniform else: raise ValueError(f"Unknown context_overlap policy {name}") def get_total_steps( scheduler, timesteps: List[int], num_steps: Optional[int] = None, num_frames: int = ..., context_size: Optional[int] = None, context_stride: int = 3, context_overlap: int = 4, closed_loop: bool = True, ): return sum( len( list( scheduler( i, num_steps, num_frames, context_size, context_stride, context_overlap, ) ) ) for i in range(len(timesteps)) ) ================================================ FILE: magicanimate/pipelines/pipeline_animation.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ TODO: 1. support multi-controlnet 2. [DONE] support DDIM inversion 3. support Prompt-to-prompt """ import inspect, math from typing import Callable, List, Optional, Union from dataclasses import dataclass from PIL import Image import numpy as np import torch import torch.distributed as dist from tqdm import tqdm from diffusers.utils import is_accelerate_available from packaging import version from transformers import CLIPTextModel, CLIPTokenizer from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from diffusers.utils import deprecate, logging, BaseOutput from einops import rearrange from magicanimate.models.unet_controlnet import UNet3DConditionModel from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl from magicanimate.pipelines.context import ( get_context_scheduler, get_total_steps ) from magicanimate.utils.util import get_tensor_interpolation_method logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class AnimationPipelineOutput(BaseOutput): videos: Union[torch.Tensor, np.ndarray] class AnimationPipeline(DiffusionPipeline): _optional_components = [] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, controlnet: ControlNetModel, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def enable_vae_slicing(self): self.vae.enable_slicing() def disable_vae_slicing(self): self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property def _execution_device(self): if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None text_embeddings = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None uncond_embeddings = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings def decode_latents(self, latents, rank, decoder_consistency=None): video_length = latents.shape[2] latents = 1 / 0.18215 * latents latents = rearrange(latents, "b c f h w -> (b f) c h w") # video = self.vae.decode(latents).sample video = [] for frame_idx in tqdm(range(latents.shape[0]), disable=(rank!=0)): if decoder_consistency is not None: video.append(decoder_consistency(latents[frame_idx:frame_idx+1])) else: video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) video = torch.cat(video) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video = (video / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 video = video.cpu().float().numpy() return video def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, prompt, height, width, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, clip_length=16): shape = (batch_size, num_channels_latents, clip_length, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: rand_device = "cpu" if device.type == "mps" else device if isinstance(generator, list): latents = [ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) ] latents = torch.cat(latents, dim=0).to(device) else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) latents = latents.repeat(1, 1, video_length//clip_length, 1, 1) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def prepare_condition(self, condition, num_videos_per_prompt, device, dtype, do_classifier_free_guidance): # prepare conditions for controlnet condition = torch.from_numpy(condition.copy()).to(device=device, dtype=dtype) / 255.0 condition = torch.stack([condition for _ in range(num_videos_per_prompt)], dim=0) condition = rearrange(condition, 'b f h w c -> (b f) c h w').clone() if do_classifier_free_guidance: condition = torch.cat([condition] * 2) return condition def next_step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0., verbose=False ): """ Inverse sampling for DDIM Inversion """ if verbose: print("timestep: ", timestep) next_step = timestep timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir return x_next, pred_x0 @torch.no_grad() def images2latents(self, images, dtype): """ Convert RGB image to VAE latents """ device = self._execution_device images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1 images = rearrange(images, "f h w c -> f c h w").to(device) latents = [] for frame_idx in range(images.shape[0]): latents.append(self.vae.encode(images[frame_idx:frame_idx+1])['latent_dist'].mean * 0.18215) latents = torch.cat(latents) return latents @torch.no_grad() def invert( self, image: torch.Tensor, prompt, num_inference_steps=20, num_actual_inference_steps=10, eta=0.0, return_intermediates=False, **kwargs): """ Adapted from: https://github.com/Yujun-Shi/DragDiffusion/blob/main/drag_pipeline.py#L440 invert a real image into noise map with determinisc DDIM inversion """ device = self._execution_device batch_size = image.shape[0] if isinstance(prompt, list): if batch_size == 1: image = image.expand(len(prompt), -1, -1, -1) elif isinstance(prompt, str): if batch_size > 1: prompt = [prompt] * batch_size # text embeddings text_input = self.tokenizer( prompt, padding="max_length", max_length=77, return_tensors="pt" ) text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0] print("input text embeddings :", text_embeddings.shape) # define initial latents latents = self.images2latents(image) print("latents shape: ", latents.shape) # interative sampling self.scheduler.set_timesteps(num_inference_steps) print("Valid timesteps: ", reversed(self.scheduler.timesteps)) latents_list = [latents] pred_x0_list = [latents] for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): if num_actual_inference_steps is not None and i >= num_actual_inference_steps: continue model_inputs = latents # predict the noise # NOTE: the u-net here is UNet3D, therefore the model_inputs need to be of shape (b c f h w) model_inputs = rearrange(model_inputs, "f c h w -> 1 c f h w") noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample noise_pred = rearrange(noise_pred, "b c f h w -> (b f) c h w") # compute the previous noise sample x_t-1 -> x_t latents, pred_x0 = self.next_step(noise_pred, t, latents) latents_list.append(latents) pred_x0_list.append(pred_x0) if return_intermediates: # return the intermediate laters during inversion return latents, latents_list return latents def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, device ): if interpolation_factor < 2: return latents new_latents = torch.zeros( (latents.shape[0],latents.shape[1],((latents.shape[2]-1) * interpolation_factor)+1, latents.shape[3],latents.shape[4]), device=latents.device, dtype=latents.dtype, ) org_video_length = latents.shape[2] rate = [i/interpolation_factor for i in range(interpolation_factor)][1:] new_index = 0 v0 = None v1 = None for i0,i1 in zip( range( org_video_length ),range( org_video_length )[1:] ): v0 = latents[:,:,i0,:,:] v1 = latents[:,:,i1,:,:] new_latents[:,:,new_index,:,:] = v0 new_index += 1 for f in rate: v = get_tensor_interpolation_method()(v0.to(device=device),v1.to(device=device),f) new_latents[:,:,new_index,:,:] = v.to(latents.device) new_index += 1 new_latents[:,:,new_index,:,:] = v1 new_index += 1 return new_latents def select_controlnet_res_samples(self, controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f): _down_block_res_samples = [] _mid_block_res_sample = [] for i in np.concatenate(np.array(context)): _down_block_res_samples.append(controlnet_res_samples_cache_dict[i][0]) _mid_block_res_sample.append(controlnet_res_samples_cache_dict[i][1]) down_block_res_samples = [[] for _ in range(len(controlnet_res_samples_cache_dict[i][0]))] for res_t in _down_block_res_samples: for i, res in enumerate(res_t): down_block_res_samples[i].append(res) down_block_res_samples = [torch.cat(res) for res in down_block_res_samples] mid_block_res_sample = torch.cat(_mid_block_res_sample) # reshape controlnet output to match the unet3d inputs b = b // 2 if do_classifier_free_guidance else b _down_block_res_samples = [] for sample in down_block_res_samples: sample = rearrange(sample, '(b f) c h w -> b c f h w', b=b, f=f) if do_classifier_free_guidance: sample = sample.repeat(2, 1, 1, 1, 1) _down_block_res_samples.append(sample) down_block_res_samples = _down_block_res_samples mid_block_res_sample = rearrange(mid_block_res_sample, '(b f) c h w -> b c f h w', b=b, f=f) if do_classifier_free_guidance: mid_block_res_sample = mid_block_res_sample.repeat(2, 1, 1, 1, 1) return down_block_res_samples, mid_block_res_sample @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], video_length: Optional[int], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "tensor", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, controlnet_condition: list = None, controlnet_conditioning_scale: float = 1.0, context_frames: int = 16, context_stride: int = 1, context_overlap: int = 4, context_batch_size: int = 1, context_schedule: str = "uniform", init_latents: Optional[torch.FloatTensor] = None, num_actual_inference_steps: Optional[int] = None, appearance_encoder = None, reference_control_writer = None, reference_control_reader = None, source_image: str = None, decoder_consistency = None, **kwargs, ): """ New args: - controlnet_condition : condition map (e.g., depth, canny, keypoints) for controlnet - controlnet_conditioning_scale : conditioning scale for controlnet - init_latents : initial latents to begin with (used along with invert()) - num_actual_inference_steps : number of actual inference steps (while total steps is num_inference_steps) """ controlnet = self.controlnet # Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # Define call parameters # batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if latents is not None: batch_size = latents.shape[0] if isinstance(prompt, list): batch_size = len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # Encode input prompt prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size if negative_prompt is not None: negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size text_embeddings = self._encode_prompt( prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt ) text_embeddings = torch.cat([text_embeddings] * context_batch_size) reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', batch_size=context_batch_size) reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', batch_size=context_batch_size) is_dist_initialized = kwargs.get("dist", False) rank = kwargs.get("rank", 0) world_size = kwargs.get("world_size", 1) # Prepare video assert num_videos_per_prompt == 1 # FIXME: verify if num_videos_per_prompt > 1 works assert batch_size == 1 # FIXME: verify if batch_size > 1 works control = self.prepare_condition( condition=controlnet_condition, device=device, dtype=controlnet.dtype, num_videos_per_prompt=num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, ) controlnet_uncond_images, controlnet_cond_images = control.chunk(2) # Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # Prepare latent variables if init_latents is not None: latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length) else: num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, video_length, height, width, text_embeddings.dtype, device, generator, latents, ) latents_dtype = latents.dtype # Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Prepare text embeddings for controlnet controlnet_text_embeddings = text_embeddings.repeat_interleave(video_length, 0) _, controlnet_text_embeddings_c = controlnet_text_embeddings.chunk(2) controlnet_res_samples_cache_dict = {i:None for i in range(video_length)} # For img2img setting if num_actual_inference_steps is None: num_actual_inference_steps = num_inference_steps if isinstance(source_image, str): ref_image_latents = self.images2latents(np.array(Image.open(source_image).resize((width, height)))[None, :], latents_dtype).cuda() elif isinstance(source_image, np.ndarray): ref_image_latents = self.images2latents(source_image[None, :], latents_dtype).cuda() context_scheduler = get_context_scheduler(context_schedule) # Denoising loop for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank!=0)): if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps: continue noise_pred = torch.zeros( (latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:]), device=latents.device, dtype=latents.dtype, ) counter = torch.zeros( (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype ) appearance_encoder( ref_image_latents.repeat(context_batch_size * (2 if do_classifier_free_guidance else 1), 1, 1, 1), t, encoder_hidden_states=text_embeddings, return_dict=False, ) context_queue = list(context_scheduler( 0, num_inference_steps, latents.shape[2], context_frames, context_stride, 0 )) num_context_batches = math.ceil(len(context_queue) / context_batch_size) for i in range(num_context_batches): context = context_queue[i*context_batch_size: (i+1)*context_batch_size] # expand the latents if we are doing classifier free guidance controlnet_latent_input = ( torch.cat([latents[:, :, c] for c in context]) .to(device) ) controlnet_latent_input = self.scheduler.scale_model_input(controlnet_latent_input, t) # prepare inputs for controlnet b, c, f, h, w = controlnet_latent_input.shape controlnet_latent_input = rearrange(controlnet_latent_input, "b c f h w -> (b f) c h w") # controlnet inference down_block_res_samples, mid_block_res_sample = self.controlnet( controlnet_latent_input, t, encoder_hidden_states=torch.cat([controlnet_text_embeddings_c[c] for c in context]), controlnet_cond=torch.cat([controlnet_cond_images[c] for c in context]), conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) for j, k in enumerate(np.concatenate(np.array(context))): controlnet_res_samples_cache_dict[k] = ([sample[j:j+1] for sample in down_block_res_samples], mid_block_res_sample[j:j+1]) context_queue = list(context_scheduler( 0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap )) num_context_batches = math.ceil(len(context_queue) / context_batch_size) global_context = [] for i in range(num_context_batches): global_context.append(context_queue[i*context_batch_size: (i+1)*context_batch_size]) for context in global_context[rank::world_size]: # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents[:, :, c] for c in context]) .to(device) .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1) ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) b, c, f, h, w = latent_model_input.shape down_block_res_samples, mid_block_res_sample = self.select_controlnet_res_samples( controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f ) reference_control_reader.update(reference_control_writer) # predict the noise residual pred = self.unet( latent_model_input, t, encoder_hidden_states=text_embeddings[:b], down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] reference_control_reader.clear() pred_uc, pred_c = pred.chunk(2) pred = torch.cat([pred_uc.unsqueeze(0), pred_c.unsqueeze(0)]) for j, c in enumerate(context): noise_pred[:, :, c] = noise_pred[:, :, c] + pred[:, j] counter[:, :, c] = counter[:, :, c] + 1 if is_dist_initialized: noise_pred_gathered = [torch.zeros_like(noise_pred) for _ in range(world_size)] if rank == 0: dist.gather(tensor=noise_pred, gather_list=noise_pred_gathered, dst=0) else: dist.gather(tensor=noise_pred, gather_list=[], dst=0) dist.barrier() if rank == 0: for k in range(1, world_size): for context in global_context[k::world_size]: for j, c in enumerate(context): noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_gathered[k][:, :, c] counter[:, :, c] = counter[:, :, c] + 1 # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample if is_dist_initialized: dist.broadcast(latents, 0) dist.barrier() reference_control_writer.clear() interpolation_factor = 1 latents = self.interpolate_latents(latents, interpolation_factor, device) # Post-processing video = self.decode_latents(latents, rank, decoder_consistency=decoder_consistency) if is_dist_initialized: dist.barrier() # Convert to tensor if output_type == "tensor": video = torch.from_numpy(video) if not return_dict: return video return AnimationPipelineOutput(videos=video) ================================================ FILE: magicanimate/utils/dist_tools.py ================================================ # Copyright 2023 ByteDance and/or its affiliates. # # Copyright (2023) MagicAnimate Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import os import socket import warnings import torch from torch import distributed as dist def distributed_init(args): if dist.is_initialized(): warnings.warn("Distributed is already initialized, cannot initialize twice!") args.rank = dist.get_rank() else: print( f"Distributed Init (Rank {args.rank}): " f"{args.init_method}" ) dist.init_process_group( backend='nccl', init_method=args.init_method, world_size=args.world_size, rank=args.rank, ) print( f"Initialized Host {socket.gethostname()} as Rank " f"{args.rank}" ) if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ: # Set for onboxdataloader support split = args.init_method.split("//") assert len(split) == 2, ( "host url for distributed should be split by '//' " + "into exactly two elements" ) split = split[1].split(":") assert ( len(split) == 2 ), "host url should be of the form :" os.environ["MASTER_ADDR"] = split[0] os.environ["MASTER_PORT"] = split[1] # perform a dummy all-reduce to initialize the NCCL communicator dist.all_reduce(torch.zeros(1).cuda()) suppress_output(is_master()) args.rank = dist.get_rank() return args.rank def get_rank(): if not dist.is_available(): return 0 if not dist.is_nccl_available(): return 0 if not dist.is_initialized(): return 0 return dist.get_rank() def is_master(): return get_rank() == 0 def synchronize(): if dist.is_initialized(): dist.barrier() def suppress_output(is_master): """Suppress printing on the current device. Force printing with `force=True`.""" import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) __builtin__.print = print import warnings builtin_warn = warnings.warn def warn(*args, **kwargs): force = kwargs.pop("force", False) if is_master or force: builtin_warn(*args, **kwargs) # Log warnings only once warnings.warn = warn warnings.simplefilter("once", UserWarning) ================================================ FILE: magicanimate/utils/util.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Adapted from https://github.com/guoyww/AnimateDiff import os import imageio import numpy as np import torch import torchvision from PIL import Image from typing import Union from tqdm import tqdm from einops import rearrange def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25): videos = rearrange(videos, "b c t h w -> t b c h w") outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x + 1.0) / 2.0 # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) outputs.append(x) os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, fps=fps) def save_images_grid(images: torch.Tensor, path: str): assert images.shape[2] == 1 # no time dimension images = images.squeeze(2) grid = torchvision.utils.make_grid(images) grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8) os.makedirs(os.path.dirname(path), exist_ok=True) Image.fromarray(grid).save(path) # DDIM Inversion @torch.no_grad() def init_prompt(prompt, pipeline): uncond_input = pipeline.tokenizer( [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt" ) uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] text_input = pipeline.tokenizer( [prompt], padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] context = torch.cat([uncond_embeddings, text_embeddings]) return context def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): timestep, next_timestep = min( timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] beta_prod_t = 1 - alpha_prod_t next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction return next_sample def get_noise_pred_single(latents, t, context, unet): noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] return noise_pred @torch.no_grad() def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): context = init_prompt(prompt, pipeline) uncond_embeddings, cond_embeddings = context.chunk(2) all_latent = [latent] latent = latent.clone().detach() for i in tqdm(range(num_inv_steps)): t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) latent = next_step(noise_pred, t, latent, ddim_scheduler) all_latent.append(latent) return all_latent @torch.no_grad() def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) return ddim_latents def video2images(path, step=4, length=16, start=0): reader = imageio.get_reader(path) frames = [] for frame in reader: frames.append(np.array(frame)) frames = frames[start::step][:length] return frames def images2video(video, path, fps=8): imageio.mimsave(path, video, fps=fps) return tensor_interpolation = None def get_tensor_interpolation_method(): return tensor_interpolation def set_tensor_interpolation_method(is_slerp): global tensor_interpolation tensor_interpolation = slerp if is_slerp else linear def linear(v1, v2, t): return (1.0 - t) * v1 + t * v2 def slerp( v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 ) -> torch.Tensor: u0 = v0 / v0.norm() u1 = v1 / v1.norm() dot = (u0 * u1).sum() if dot.abs() > DOT_THRESHOLD: #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') return (1.0 - t) * v0 + t * v1 omega = dot.acos() return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() ================================================ FILE: magicanimate/utils/videoreader.py ================================================ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- # ytedance Inc.. # ************************************************************************* # Copyright 2022 ByteDance and/or its affiliates. # # Copyright (2022) PV3D Authors # # ByteDance, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. import av, gc import torch import warnings import numpy as np _CALLED_TIMES = 0 _GC_COLLECTION_INTERVAL = 20 # remove warnings av.logging.set_level(av.logging.ERROR) class VideoReader(): """ Simple wrapper around PyAV that exposes a few useful functions for dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries. Acknowledgement: Codes are borrowed from Bruno Korbar """ def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False): """ Arguments: video_path (str): path or byte of the video to be loaded """ self.container = av.open(video) self.num_frames = num_frames self.bi_frame = bi_frame self.resampler = None if audio_resample_rate is not None: self.resampler = av.AudioResampler(rate=audio_resample_rate) if self.container.streams.video: # enable multi-threaded video decoding if decode_lossy: warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning) self.container.streams.video[0].thread_type = 'AUTO' self.video_stream = self.container.streams.video[0] else: self.video_stream = None self.fps = self._get_video_frame_rate() def seek(self, pts, backward=True, any_frame=False): stream = self.video_stream self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream) def _occasional_gc(self): # there are a lot of reference cycles in PyAV, so need to manually call # the garbage collector from time to time global _CALLED_TIMES, _GC_COLLECTION_INTERVAL _CALLED_TIMES += 1 if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: gc.collect() def _read_video(self, offset): self._occasional_gc() pts = self.container.duration * offset time_ = pts / float(av.time_base) self.container.seek(int(pts)) video_frames = [] count = 0 for _, frame in enumerate(self._iter_frames()): if frame.pts * frame.time_base >= time_: video_frames.append(frame) if count >= self.num_frames - 1: break count += 1 return video_frames def _iter_frames(self): for packet in self.container.demux(self.video_stream): for frame in packet.decode(): yield frame def _compute_video_stats(self): if self.video_stream is None or self.container is None: return 0 num_of_frames = self.container.streams.video[0].frames if num_of_frames == 0: num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base) self.seek(0, backward=False) count = 0 time_base = 512 for p in self.container.decode(video=0): count = count + 1 if count == 1: start_pts = p.pts elif count == 2: time_base = p.pts - start_pts break return start_pts, time_base, num_of_frames def _get_video_frame_rate(self): return float(self.container.streams.video[0].guessed_rate) def sample(self, debug=False): if self.container is None: raise RuntimeError('video stream not found') sample = dict() _, _, total_num_frames = self._compute_video_stats() offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item() video_frames = self._read_video(offset/total_num_frames) video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) sample["frames"] = video_frames sample["frame_idx"] = [offset] if self.bi_frame: frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)] frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)] frames.sort() video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]]) Ts= [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)] sample["frames"] = video_frames sample["real_t"] = torch.tensor(Ts, dtype=torch.float32) sample["frame_idx"] = [offset+min(frames), offset+max(frames)] return sample return sample def read_frames(self, frame_indices): self.num_frames = frame_indices[1] - frame_indices[0] video_frames = self._read_video(frame_indices[0]/self.get_num_frames()) video_frames = np.array([ np.uint8(video_frames[0].to_rgb().to_ndarray()), np.uint8(video_frames[-1].to_rgb().to_ndarray()) ]) return video_frames def read(self): video_frames = self._read_video(0) video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) return video_frames def get_num_frames(self): _, _, total_num_frames = self._compute_video_stats() return total_num_frames ================================================ FILE: requirements.txt ================================================ absl-py==1.4.0 accelerate==0.22.0 aiofiles==23.2.1 aiohttp==3.8.5 aiosignal==1.3.1 altair==5.0.1 annotated-types==0.5.0 antlr4-python3-runtime==4.9.3 anyio==3.7.1 async-timeout==4.0.3 attrs==23.1.0 cachetools==5.3.1 certifi==2023.7.22 charset-normalizer==3.2.0 click==8.1.7 cmake==3.27.2 contourpy==1.1.0 cycler==0.11.0 datasets==2.14.4 dill==0.3.7 einops==0.6.1 exceptiongroup==1.1.3 fastapi==0.103.0 ffmpy==0.3.1 filelock==3.12.2 fonttools==4.42.1 frozenlist==1.4.0 fsspec==2023.6.0 google-auth==2.22.0 google-auth-oauthlib==1.0.0 gradio==3.41.2 gradio-client==0.5.0 grpcio==1.57.0 h11==0.14.0 httpcore==0.17.3 httpx==0.24.1 huggingface-hub==0.16.4 idna==3.4 importlib-metadata==6.8.0 importlib-resources==6.0.1 jinja2==3.1.2 joblib==1.3.2 jsonschema==4.19.0 jsonschema-specifications==2023.7.1 kiwisolver==1.4.5 lightning-utilities==0.9.0 lit==16.0.6 markdown==3.4.4 markupsafe==2.1.3 matplotlib==3.7.2 mpmath==1.3.0 multidict==6.0.4 multiprocess==0.70.15 networkx==3.1 numpy==1.24.4 nvidia-cublas-cu11==11.10.3.66 nvidia-cuda-cupti-cu11==11.7.101 nvidia-cuda-nvrtc-cu11==11.7.99 nvidia-cuda-runtime-cu11==11.7.99 nvidia-cudnn-cu11==8.5.0.96 nvidia-cufft-cu11==10.9.0.58 nvidia-curand-cu11==10.2.10.91 nvidia-cusolver-cu11==11.4.0.1 nvidia-cusparse-cu11==11.7.4.91 nvidia-nccl-cu11==2.14.3 nvidia-nvtx-cu11==11.7.91 oauthlib==3.2.2 omegaconf==2.3.0 opencv-python==4.8.0.76 orjson==3.9.5 pandas==2.0.3 pillow==9.5.0 pkgutil-resolve-name==1.3.10 protobuf==4.24.2 psutil==5.9.5 pyarrow==13.0.0 pyasn1==0.5.0 pyasn1-modules==0.3.0 pydantic==2.3.0 pydantic-core==2.6.3 pydub==0.25.1 pyparsing==3.0.9 python-multipart==0.0.6 pytorch-lightning==2.0.7 pytz==2023.3 pyyaml==6.0.1 referencing==0.30.2 regex==2023.8.8 requests==2.31.0 requests-oauthlib==1.3.1 rpds-py==0.9.2 rsa==4.9 safetensors==0.3.3 semantic-version==2.10.0 sniffio==1.3.0 starlette==0.27.0 sympy==1.12 tensorboard==2.14.0 tensorboard-data-server==0.7.1 tokenizers==0.13.3 toolz==0.12.0 torchmetrics==1.1.0 tqdm==4.66.1 transformers==4.32.0 triton==2.0.0 tzdata==2023.3 urllib3==1.26.16 uvicorn==0.23.2 websockets==11.0.3 werkzeug==2.3.7 xxhash==3.3.0 yarl==1.9.2 zipp==3.16.2 decord imageio==2.9.0 imageio-ffmpeg==0.4.3 timm scipy scikit-image av imgaug lpips ffmpeg-python torch==2.0.1 torchvision==0.15.2 xformers==0.0.22 diffusers==0.21.4 ================================================ FILE: scripts/animate.sh ================================================ python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml ================================================ FILE: scripts/animate_dist.sh ================================================ python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml --dist