Repository: Bujiazi/MotionClone Branch: main Commit: 8b3d937f005a Files: 36 Total size: 289.7 KB Directory structure: gitextract_e50318u3/ ├── README.md ├── configs/ │ ├── i2v_rgb.jsonl │ ├── i2v_rgb.yaml │ ├── i2v_sketch.jsonl │ ├── i2v_sketch.yaml │ ├── model_config/ │ │ ├── inference-v1.yaml │ │ ├── inference-v2.yaml │ │ ├── inference-v3.yaml │ │ ├── model_config copy.yaml │ │ ├── model_config.yaml │ │ └── model_config_public.yaml │ ├── sparsectrl/ │ │ ├── image_condition.yaml │ │ └── latent_condition.yaml │ ├── t2v_camera.jsonl │ ├── t2v_camera.yaml │ ├── t2v_object.jsonl │ └── t2v_object.yaml ├── environment.yaml ├── generated_videos/ │ └── inference_config.json ├── i2v_video_sample.py ├── models/ │ └── Motion_Module/ │ └── Put motion module checkpoints here.txt ├── motionclone/ │ ├── models/ │ │ ├── attention.py │ │ ├── motion_module.py │ │ ├── resnet.py │ │ ├── scheduler.py │ │ ├── sparse_controlnet.py │ │ ├── unet.py │ │ └── unet_blocks.py │ ├── pipelines/ │ │ └── pipeline_animation.py │ └── utils/ │ ├── conv_layer.py │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ ├── motionclone_functions.py │ ├── util.py │ └── xformer_attention.py └── t2v_video_sample.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # MotionClone This repository is the official implementation of [MotionClone](https://arxiv.org/abs/2406.05338). It is a **training-free framework** that enables motion cloning from a reference video for controllable video generation, **without cumbersome video inversion processes**.
Abstract of MotionClone:

> Motion-based controllable video generation offers the potential for creating captivating visual content. Existing methods typically necessitate model training to encode particular motion cues or incorporate fine-tuning to inject certain motion patterns, resulting in limited flexibility and generalization. In this work, we propose **MotionClone**, a training-free framework that enables motion cloning from reference videos for versatile motion-controlled video generation, including text-to-video and image-to-video. Based on the observation that the dominant components in temporal-attention maps drive motion synthesis, while the rest mainly capture noisy or very subtle motions, MotionClone utilizes sparse temporal attention weights as motion representations for motion guidance, facilitating diverse motion transfer across varying scenarios. Meanwhile, MotionClone allows for the direct extraction of motion representation through a single denoising step, bypassing the cumbersome inversion processes and thus promoting both efficiency and flexibility. Extensive experiments demonstrate that MotionClone exhibits proficiency in both global camera motion and local object motion, with notable superiority in terms of motion fidelity, textual alignment, and temporal consistency.

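
The idea of using only the dominant temporal-attention entries as a motion signal can be sketched in a few lines of framework-free Python. This is an illustration under stated assumptions, not the repository's implementation: picking the top-k entries of each attention row is one plausible reading of "dominant components", the squared-error form of the loss is assumed, and the names `topk_mask` and `sparse_motion_loss` are hypothetical. The actual guidance code is exposed through functions such as `compute_temp_loss` and `obtain_motion_representation` in `motionclone/utils/motionclone_functions.py`.

```python
def topk_mask(attn_row, k=1):
    """Return a 0/1 mask marking the k largest entries of one attention row."""
    order = sorted(range(len(attn_row)), key=lambda i: attn_row[i], reverse=True)
    keep = set(order[:k])
    return [1 if i in keep else 0 for i in range(len(attn_row))]


def sparse_motion_loss(ref_attn, gen_attn, k=1):
    """Squared error between two (frames x frames) attention maps,
    counted only where the reference map has its k dominant entries per row.
    Assumption: the mask is derived from the reference video alone."""
    total = 0.0
    for ref_row, gen_row in zip(ref_attn, gen_attn):
        mask = topk_mask(ref_row, k)
        total += sum(m * (r - g) ** 2 for m, r, g in zip(mask, ref_row, gen_row))
    return total


if __name__ == "__main__":
    # Toy 3-frame temporal attention maps for a single spatial location.
    reference = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]
    generated = [[0.5, 0.3, 0.2], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
    print(sparse_motion_loss(reference, generated, k=1))
```

Only the masked entries contribute, so small attention values that mostly reflect noise or subtle motion do not influence the guidance signal in this sketch.
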
**[MotionClone: Training-Free Motion Cloning for Controllable Video Generation](https://arxiv.org/abs/2406.05338)**
[Pengyang Ling*](https://github.com/LPengYang/), [Jiazi Bu*](https://github.com/Bujiazi/), [Pan Zhang](https://panzhang0212.github.io/), [Xiaoyi Dong](https://scholar.google.com/citations?user=FscToE0AAAAJ&hl=en/), [Yuhang Zang](https://yuhangzang.github.io/), [Tong Wu](https://wutong16.github.io/), [Huaian Chen](https://scholar.google.com.hk/citations?hl=zh-CN&user=D6ol9XkAAAAJ), [Jiaqi Wang](https://myownskyw7.github.io/), [Yi Jin](https://scholar.google.ca/citations?hl=en&user=mAJ1dCYAAAAJ)

(*Equal Contribution)(Corresponding Author)

[![arXiv](https://img.shields.io/badge/arXiv-2406.05338-b31b1b.svg)](https://arxiv.org/abs/2406.05338) [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://bujiazi.github.io/motionclone.github.io/) ![](https://img.shields.io/github/stars/LPengYang/MotionClone?style=social)

## Demo
[![]](https://github.com/user-attachments/assets/d1f1c753-f192-455b-9779-94c925e51aaa)

## 🖋 News
- The latest version of our paper (**v4**) is available on arXiv! (10.08)
- The latest version of our paper (**v3**) is available on arXiv! (7.2)
- Code released! (6.29)

## 🏗️ Todo
- [x] We have updated MotionClone to the latest version, which performs motion transfer **without video inversion** and supports **image-to-video and sketch-to-video**.
- [x] Release the MotionClone code (We have released **the first version** of our code and will continue to optimize it. We welcome any questions or issues you may have and will address them promptly.)
- [x] Release paper

## 📚 Gallery
More results are shown on the [Project Page](https://bujiazi.github.io/motionclone.github.io/).

## 🚀 Method Overview
### Feature visualization
### Pipeline
MotionClone utilizes sparse temporal attention weights as motion representations for motion guidance, facilitating diverse motion transfer across varying scenarios. Meanwhile, MotionClone allows for the direct extraction of motion representation through a single denoising step, bypassing the cumbersome inversion processes and thus promoting both efficiency and flexibility.

## 🔧 Installations (python==3.11.3 recommended)
### Setup repository and conda environment
```
git clone https://github.com/Bujiazi/MotionClone.git
cd MotionClone
conda env create -f environment.yaml
conda activate motionclone
```

## 🔑 Pretrained Model Preparations
### Download Stable Diffusion V1.5
```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/
```
After downloading Stable Diffusion, make sure it is saved to `models/StableDiffusion`.

### Prepare Community Models
Manually download the community `.safetensors` models from [RealisticVision V5.1](https://civitai.com/models/4201?modelVersionId=130072) and save them to `models/DreamBooth_LoRA`.

### Prepare AnimateDiff Motion Modules
Manually download the AnimateDiff modules from [AnimateDiff](https://github.com/guoyww/AnimateDiff). We recommend [`v3_sd15_adapter.ckpt`](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt) and [`v3_sd15_mm.ckpt`](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt). Save the modules to `models/Motion_Module`.

### Prepare SparseCtrl for image-to-video and sketch-to-video
Manually download `v3_sd15_sparsectrl_rgb.ckpt` and `v3_sd15_sparsectrl_scribble.ckpt` from [AnimateDiff](https://huggingface.co/guoyww/animatediff/tree/main). Save the modules to `models/SparseCtrl`.
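
Once the downloads are done, a quick check of the expected layout can catch a misplaced checkpoint before a long run. The sketch below is not part of the repository: the helper name `missing_model_paths` is hypothetical, the file paths are the ones referenced by the YAML configs in `configs/`, and the Stable Diffusion subfolders are the ones the sampling scripts load (`tokenizer`, `text_encoder`, `vae`, `unet`). Adjust the lists if your layout differs.

```python
from pathlib import Path

# Checkpoint paths as referenced by the configs in this repository.
EXPECTED_FILES = [
    "models/Motion_Module/v3_sd15_mm.ckpt",
    "models/Motion_Module/v3_sd15_adapter.ckpt",
    "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt",
    "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt",
    "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors",
]

# Subfolders loaded from the Stable Diffusion directory by the sampling scripts.
EXPECTED_SD_SUBFOLDERS = ["tokenizer", "text_encoder", "vae", "unet"]


def missing_model_paths(root="."):
    """Return the expected files and folders that are absent under `root`."""
    root = Path(root)
    missing = [p for p in EXPECTED_FILES if not (root / p).is_file()]
    for sub in EXPECTED_SD_SUBFOLDERS:
        rel = f"models/StableDiffusion/{sub}"
        if not (root / rel).is_dir():
            missing.append(rel)
    return missing


if __name__ == "__main__":
    for path in missing_model_paths():
        print("missing:", path)
```

Run it from the repository root. An empty output means every path in the two lists was found. The SparseCtrl checkpoints are only needed for the image-to-video and sketch-to-video configs.
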
## 🎈 Quick Start
### Perform Text-to-video generation with customized camera motion
```
python t2v_video_sample.py --inference_config "configs/t2v_camera.yaml" --examples "configs/t2v_camera.jsonl"
```

### Perform Text-to-video generation with customized object motion
```
python t2v_video_sample.py --inference_config "configs/t2v_object.yaml" --examples "configs/t2v_object.jsonl"
```

### Combine motion cloning with sketch-to-video
```
python i2v_video_sample.py --inference_config "configs/i2v_sketch.yaml" --examples "configs/i2v_sketch.jsonl"
```

### Combine motion cloning with image-to-video
```
python i2v_video_sample.py --inference_config "configs/i2v_rgb.yaml" --examples "configs/i2v_rgb.jsonl"
```

## 📎 Citation
If you find this work helpful, please cite the following paper:
```
@article{ling2024motionclone,
  title={MotionClone: Training-Free Motion Cloning for Controllable Video Generation},
  author={Ling, Pengyang and Bu, Jiazi and Zhang, Pan and Dong, Xiaoyi and Zang, Yuhang and Wu, Tong and Chen, Huaian and Wang, Jiaqi and Jin, Yi},
  journal={arXiv preprint arXiv:2406.05338},
  year={2024}
}
```

## 📣 Disclaimer
This is the official code of MotionClone. The copyrights of the demo images and audio belong to community users. Feel free to contact us if you would like them removed.

## 💞 Acknowledgements
The code is built upon the repositories below; we thank all the contributors for open-sourcing.
* [AnimateDiff](https://github.com/guoyww/AnimateDiff) * [FreeControl](https://github.com/genforce/freecontrol) ================================================ FILE: configs/i2v_rgb.jsonl ================================================ {"video_path":"reference_videos/camera_zoom_out.mp4", "condition_image_paths":["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass"} ================================================ FILE: configs/i2v_rgb.yaml ================================================ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" model_config: "configs/model_config/model_config.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" controlnet_config: "configs/sparsectrl/latent_condition.yaml" adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" cfg_scale: 7.5 # in default realistic classifier-free guidance negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers" inference_steps: 100 # the total denoising steps for inference guidance_scale: 0.3 # which scale of time step to end guidance guidance_steps: 40 # the steps for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) are performed without guidance warm_up_steps: 10 cool_up_steps: 10 motion_guidance_weight: 2000 motion_guidance_blocks: ['up_blocks.1'] add_noise_step: 400 ================================================ FILE: configs/i2v_sketch.jsonl ================================================ {"video_path":"reference_videos/sample_white_tiger.mp4", "condition_image_paths":["condition_images/scribble/lion_forest.png"], "new_prompt": "Lion, walks in the forest"} ================================================ FILE: configs/i2v_sketch.yaml ================================================ motion_module: 
"models/Motion_Module/v3_sd15_mm.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" model_config: "configs/model_config/model_config.yaml" controlnet_config: "configs/sparsectrl/image_condition.yaml" controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt" adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" cfg_scale: 7.5 # in default realistic classifier-free guidance negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers" inference_steps: 200 # the total denoising steps for inference guidance_scale: 0.4 # which scale of time step to end guidance guidance_steps: 120 # the steps for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) are performed without guidance warm_up_steps: 10 cool_up_steps: 10 motion_guidance_weight: 2000 motion_guidance_blocks: ['up_blocks.1'] add_noise_step: 400 ================================================ FILE: configs/model_config/inference-v1.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true # from config v3 use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_decoder_only: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 zero_initialize: true # from config v3 noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: False ================================================ FILE: configs/model_config/inference-v2.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true unet_use_cross_frame_attention: 
false unet_use_temporal_attention: false use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: true motion_module_decoder_only: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: False ================================================ FILE: configs/model_config/inference-v3.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: Vanilla motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 zero_initialize: true noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: False ================================================ FILE: configs/model_config/model_config copy.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true # from config v3 use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 zero_initialize: true # from config v3 noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: False 
================================================ FILE: configs/model_config/model_config.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_attention_dim_div: 1 zero_initialize: true noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: false ================================================ FILE: configs/model_config/model_config_public.yaml ================================================ unet_additional_kwargs: use_inflated_groupnorm: true # from config v3 unet_use_cross_frame_attention: false unet_use_temporal_attention: false use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_decoder_only: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self", "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 zero_initialize: true # from config v3 noise_scheduler_kwargs: beta_start: 0.00085 beta_end: 0.012 beta_schedule: "linear" steps_offset: 1 clip_sample: False ================================================ FILE: configs/sparsectrl/image_condition.yaml ================================================ controlnet_additional_kwargs: set_noisy_sample_input_to_zero: true use_simplified_condition_embedding: false conditioning_channels: 3 use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ 
"Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 ================================================ FILE: configs/sparsectrl/latent_condition.yaml ================================================ controlnet_additional_kwargs: set_noisy_sample_input_to_zero: true use_simplified_condition_embedding: true conditioning_channels: 4 use_motion_module: true motion_module_resolutions: [1,2,4,8] motion_module_mid_block: false motion_module_type: "Vanilla" motion_module_kwargs: num_attention_heads: 8 num_transformer_block: 1 attention_block_types: [ "Temporal_Self" ] temporal_position_encoding: true temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 ================================================ FILE: configs/t2v_camera.jsonl ================================================ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Relics on the seabed", "seed": 42} {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "A road in the mountain", "seed": 42} {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Caves, a path for exploration", "seed": 2026} {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Railway for train"} {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Tree, in the mountain", "seed": 2026} {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Red car on the track", "seed": 2026} {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Man, standing in his garden.", "seed": 2026} {"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A island, on the ocean, sunny day"} {"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A tower, with fireworks"} {"video_path":"reference_videos/camera_pan_up.mp4", "new_prompt": "Beautiful house, around with flowers", "seed": 42} {"video_path":"reference_videos/camera_translation_2.mp4", "new_prompt": "Forest, in winter", 
"seed": 2028} {"video_path":"reference_videos/camera_pan_down.mp4", "new_prompt": "Eagle, standing in the tree", "seed": 2026} ================================================ FILE: configs/t2v_camera.yaml ================================================ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" model_config: "configs/model_config/model_config.yaml" cfg_scale: 7.5 # in default realistic classifier-free guidance negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers" postive_prompt: " 8k, high detailed, best quality, film grain, Fujifilm XT3" inference_steps: 100 # the total denoising steps for inference guidance_scale: 0.3 # which scale of time step to end guidance 0.2/40 guidance_steps: 50 # the steps for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) are performed without guidance warm_up_steps: 10 cool_up_steps: 10 motion_guidance_weight: 2000 motion_guidance_blocks: ['up_blocks.1'] add_noise_step: 400 ================================================ FILE: configs/t2v_object.jsonl ================================================ {"video_path":"reference_videos/sample_astronaut.mp4", "new_prompt": "Robot, walks in the street.","seed":59} {"video_path":"reference_videos/sample_cat.mp4", "new_prompt": "Tiger, raises its head.", "seed": 2025} {"video_path":"reference_videos/sample_leaves.mp4", "new_prompt": "Petals falling in the wind.","seed":3407} {"video_path":"reference_videos/sample_fox.mp4", "new_prompt": "Cat, turns its head in the living room."} {"video_path":"reference_videos/sample_blackswan.mp4", "new_prompt": "Duck, swims in the river.","seed":3407} {"video_path":"reference_videos/sample_cow.mp4", "new_prompt": "Pig, drinks water on beach.","seed":3407} ================================================ FILE: 
configs/t2v_object.yaml ================================================ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" model_config: "configs/model_config/model_config.yaml" cfg_scale: 7.5 # in default realistic classifier-free guidance negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers" postive_prompt: "8k, high detailed, best quality, film grain, Fujifilm XT3" inference_steps: 300 # the total denoising steps for inference guidance_scale: 0.4 # which scale of time step to end guidance guidance_steps: 180 # the steps for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) are performed without guidance warm_up_steps: 10 cool_up_steps: 10 motion_guidance_weight: 2000 motion_guidance_blocks: ['up_blocks.1',] add_noise_step: 400 ================================================ FILE: environment.yaml ================================================ name: motionclone channels: - pytorch - nvidia dependencies: - python=3.11.3 - pytorch=2.0.1 - torchvision=0.15.2 - pytorch-cuda=11.8 - pip - pip: - accelerate - diffusers==0.16.0 - transformers==4.28.1 - xformers==0.0.20 - imageio[ffmpeg] - decord==0.6.0 - gdown - einops - omegaconf - safetensors - gradio - wandb - triton - opencv-python ================================================ FILE: generated_videos/inference_config.json ================================================ motion_module: models/Motion_Module/v3_sd15_mm.ckpt dreambooth_path: models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors model_config: configs/model_config/model_config.yaml controlnet_config: configs/sparsectrl/image_condition.yaml controlnet_path: models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt adapter_lora_path: models/Motion_Module/v3_sd15_adapter.ckpt cfg_scale: 7.5 negative_prompt: ugly, 
deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers inference_steps: 200 guidance_scale: 0.4 guidance_steps: 120 warm_up_steps: 10 cool_up_steps: 10 motion_guidance_weight: 2000 motion_guidance_blocks: - up_blocks.1 add_noise_step: 400 width: 512 height: 512 video_length: 16 ================================================ FILE: i2v_video_sample.py ================================================ import argparse from omegaconf import OmegaConf import torch from diffusers import AutoencoderKL, DDIMScheduler from transformers import CLIPTextModel, CLIPTokenizer from motionclone.models.unet import UNet3DConditionModel from motionclone.models.sparse_controlnet import SparseControlNetModel from motionclone.pipelines.pipeline_animation import AnimationPipeline from motionclone.utils.util import load_weights, auto_download from diffusers.utils.import_utils import is_xformers_available from motionclone.utils.motionclone_functions import * import json from motionclone.utils.xformer_attention import * def main(args): os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0)) config = OmegaConf.load(args.inference_config) adopted_dtype = torch.float16 device = "cuda" set_all_seed(42) tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype) vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype) config.width = config.get("W", args.W) config.height = config.get("H", args.H) config.video_length = config.get("L", args.L) if not os.path.exists(args.generated_videos_save_dir): os.makedirs(args.generated_videos_save_dir) OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json")) model_config = 
OmegaConf.load(config.get("model_config", "")) unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype) # load controlnet model controlnet = None if config.get("controlnet_path", "") != "": # assert model_config.get("controlnet_images", "") != "" assert config.get("controlnet_config", "") != "" unet.config.num_attention_heads = 8 unet.config.projection_class_embeddings_input_dim = None controlnet_config = OmegaConf.load(config.controlnet_config) controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype) auto_download(config.controlnet_path, is_dreambooth_lora=False) print(f"loading controlnet checkpoint from {config.controlnet_path} ...") controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu") controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name} controlnet_state_dict.pop("animatediff_config", "") controlnet.load_state_dict(controlnet_state_dict) del controlnet_state_dict # set xformers if is_xformers_available() and (not args.without_xformers): unet.enable_xformers_memory_efficient_attention() pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)), ).to(device) pipeline = load_weights( pipeline, # motion module motion_module_path = config.get("motion_module", ""), # domain adapter adapter_lora_path = config.get("adapter_lora_path", ""), adapter_lora_scale = config.get("adapter_lora_scale", 1.0), # image layer dreambooth_model_path = 
config.get("dreambooth_path", ""), ).to(device) pipeline.text_encoder.to(dtype=adopted_dtype) # customized functions in motionclone_functions pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler) pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler) pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet) pipeline.sample_video = sample_video.__get__(pipeline) pipeline.single_step_video = single_step_video.__get__(pipeline) pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline) pipeline.add_noise = add_noise.__get__(pipeline) pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline) pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline) for param in pipeline.unet.parameters(): param.requires_grad = False for param in pipeline.controlnet.parameters(): param.requires_grad = False pipeline.input_config, pipeline.unet.input_config = config, config pipeline.unet = prep_unet_attention(pipeline.unet,pipeline.input_config.motion_guidance_blocks) pipeline.unet = prep_unet_conv(pipeline.unet) pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps,config.guidance_scale,device=device,timestep_spacing_type = "uneven") with open(args.examples, 'r') as files: for line in files: # prepare infor of each case example_infor = json.loads(line) config.video_path = example_infor["video_path"] config.condition_image_path_list = example_infor["condition_image_paths"] config.image_index = example_infor.get("image_index",[0]) assert len(config.image_index) == len(config.condition_image_path_list) config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "") config.controlnet_scale = example_infor.get("controlnet_scale", 1.0) pipeline.input_config, pipeline.unet.input_config = config, config # update config # perform motion representation extraction seed_motion = seed_motion = 
example_infor.get("seed", args.default_seed) generator = torch.Generator(device=pipeline.device) generator.manual_seed(seed_motion) if not os.path.exists(args.motion_representation_save_dir): os.makedirs(args.motion_representation_save_dir) motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt') pipeline.obtain_motion_representation(generator= generator, motion_representation_path = motion_representation_path, use_controlnet=True,) # perform video generation seed = seed_motion # can assign other seed here generator = torch.Generator(device=pipeline.device) generator.manual_seed(seed) pipeline.input_config.seed = seed videos = pipeline.sample_video(generator = generator, add_controlnet=True,) videos = rearrange(videos, "b c f h w -> b f h w c") save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" +str(seed)+'.mp4') videos_uint8 = (videos[0] * 255).astype(np.uint8) imageio.mimwrite(save_path, videos_uint8, fps=8) print(save_path,"is done") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",) parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml") parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl") parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/") parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/") parser.add_argument("--visible_gpu", type=str, default=None) parser.add_argument("--default-seed", type=int, default=76739) parser.add_argument("--L", type=int, default=16) parser.add_argument("--W", type=int, default=512) parser.add_argument("--H", type=int, default=512) 
parser.add_argument("--without-xformers", action="store_true") args = parser.parse_args() main(args) ================================================ FILE: models/Motion_Module/Put motion module checkpoints here.txt ================================================ ================================================ FILE: motionclone/models/attention.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py from dataclasses import dataclass from typing import Optional import torch import torch.nn.functional as F from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available from diffusers.models.attention import FeedForward, AdaLayerNorm from einops import rearrange, repeat import pdb @dataclass class Transformer3DModelOutput(BaseOutput): sample: torch.FloatTensor if is_xformers_available(): import xformers import xformers.ops else: xformers = None class Transformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, unet_use_cross_frame_attention=None, unet_use_temporal_attention=None, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # Define input layers self.in_channels = in_channels self.norm 
= torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) # Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) for d in range(num_layers) ] ) # 4. Define output layers if use_linear_projection: self.proj_out = nn.Linear(in_channels, inner_dim) else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): # Input assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
video_length = hidden_states.shape[2] hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) # Blocks for block in self.transformer_blocks: hidden_states = block( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, video_length=video_length ) # Output if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() ) output = hidden_states + residual output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) if not return_dict: return (output,) return Transformer3DModelOutput(sample=output) class BasicTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, unet_use_cross_frame_attention = None, unet_use_temporal_attention = None, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm = num_embeds_ada_norm is not None 
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention self.unet_use_temporal_attention = unet_use_temporal_attention # SC-Attn assert unet_use_cross_frame_attention is not None if unet_use_cross_frame_attention: self.attn1 = SparseCausalAttention2D( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) else: self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) # Cross-Attn if cross_attention_dim is not None: self.attn2 = CrossAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) else: self.attn2 = None if cross_attention_dim is not None: self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) else: self.norm2 = None # Feed-forward self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.norm3 = nn.LayerNorm(dim) # Temp-Attn assert unet_use_temporal_attention is not None if unet_use_temporal_attention: self.attn_temp = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) nn.init.zeros_(self.attn_temp.to_out[0].weight.data) self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None): if not is_xformers_available(): print("Here is how to install it") raise ModuleNotFoundError( "Refer to 
https://github.com/facebookresearch/xformers for more information on how to install" " xformers", name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" " available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers if self.attn2 is not None: self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): # SparseCausal-Attention norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) # if self.only_cross_attention: # hidden_states = ( # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states # ) # else: # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states # pdb.set_trace() if self.unet_use_cross_frame_attention: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states else: hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states if self.attn2 is not None: # Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) hidden_states = ( self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) + 
hidden_states ) # Feed-forward hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states # Temporal-Attention if self.unet_use_temporal_attention: d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) norm_hidden_states = ( self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) ) hidden_states = self.attn_temp(norm_hidden_states) + hidden_states hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states class CrossAttention(nn.Module): r""" A cross attention layer. Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
""" def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, ): super().__init__() inner_dim = dim_head * heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.scale = dim_head**-0.5 self.heads = heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads self._slice_size = None self._use_memory_efficient_attention_xformers = False self.added_kv_proj_dim = added_kv_proj_dim #### add processer self.processor = None if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) else: self.group_norm = None self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim)) self.to_out.append(nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size // 
head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def set_attention_slice(self, slice_size): if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") self._slice_size = slice_size def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape encoder_hidden_states = encoder_hidden_states if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) dim = query.shape[-1] # query = self.reshape_heads_to_batch_dim(query) # move backwards if self.added_kv_proj_dim is not None: key = self.to_k(hidden_states) value = self.to_v(hidden_states) encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) ######record###### record before reshape heads to batch dim if self.processor is not None: self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask) ################## key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) else: encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) ######record###### if self.processor is not None: self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask) 
################## key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) query = self.reshape_heads_to_batch_dim(query) # reshape query if attention_mask is not None: if attention_mask.shape[-1] != query.shape[1]: target_length = query.shape[1] attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) ######record###### if self.processor is not None: self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask) ################## # attention, what we cannot get enough of if self._use_memory_efficient_attention_xformers: hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) # Some versions of xformers return output in fp32, cast it back to the dtype of the input hidden_states = hidden_states.to(query.dtype) else: if self._slice_size is None or query.shape[0] // self._slice_size == 1: hidden_states = self._attention(query, key, value, attention_mask) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) return hidden_states def _attention(self, query, key, value, attention_mask=None): if self.upcast_attention: query = query.float() key = key.float() attention_scores = torch.baddbmm( torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), query, key.transpose(-1, -2), beta=0, alpha=self.scale, ) if attention_mask is not None: attention_scores = attention_scores + attention_mask if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) # cast back to the original dtype attention_probs = attention_probs.to(value.dtype) # compute attention output hidden_states = torch.bmm(attention_probs, value) # reshape hidden_states 
hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): batch_size_attention = query.shape[0] hidden_states = torch.zeros( (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype ) slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] if self.upcast_attention: query_slice = query_slice.float() key_slice = key_slice.float() attn_slice = torch.baddbmm( torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), query_slice, key_slice.transpose(-1, -2), beta=0, alpha=self.scale, ) if attention_mask is not None: attn_slice = attn_slice + attention_mask[start_idx:end_idx] if self.upcast_softmax: attn_slice = attn_slice.float() attn_slice = attn_slice.softmax(dim=-1) # cast back to the original dtype attn_slice = attn_slice.to(value.dtype) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): # TODO attention_mask query = query.contiguous() key = key.contiguous() value = value.contiguous() hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def set_processor(self, processor: "AttnProcessor") -> None: r""" Set the attention processor to use. Args: processor (`AttnProcessor`): The attention processor to use. 
""" # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( hasattr(self, "processor") and isinstance(self.processor, torch.nn.Module) and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") self.processor = processor def get_attention_scores( self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None ) -> torch.Tensor: r""" Compute the attention scores. Args: query (`torch.Tensor`): The query tensor. key (`torch.Tensor`): The key tensor. attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. Returns: `torch.Tensor`: The attention probabilities/scores. """ dtype = query.dtype if self.upcast_attention: query = query.float() key = key.float() if attention_mask is None: baddbmm_input = torch.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device ) beta = 0 else: baddbmm_input = attention_mask beta = 1 attention_scores = torch.baddbmm( baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale, ) del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) del attention_scores attention_probs = attention_probs.to(dtype) return attention_probs ================================================ FILE: motionclone/models/motion_module.py ================================================ from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch import numpy as np import torch.nn.functional as F from torch import nn import torchvision from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import 
is_xformers_available
from diffusers.models.attention import FeedForward
from .attention import CrossAttention

from einops import rearrange, repeat
import math


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def get_motion_module(  # can only return a VanillaTemporalModule
    in_channels,
    motion_module_type: str,
    motion_module_kwargs: dict
):
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
    else:
        raise ValueError(f"Unsupported motion_module_type: {motion_module_type!r}")


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads = 8,
        num_transformer_block = 2,
        attention_block_types = ("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 32,
        temporal_attention_dim_div = 1,
        zero_initialize = True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)

        output = hidden_states
        return output


class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types = ("Temporal_Self", "Temporal_Self",),  # two temporal self-attention blocks
        dropout = 0.0,
        norm_num_groups = 32,
        cross_attention_dim = 768,
        activation_fn = "geglu",
        attention_bias = False,
        upcast_attention = False,
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 24,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2] hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") batch, channel, height, weight = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) hidden_states = self.proj_in(hidden_states) # Transformer Blocks for block in self.transformer_blocks: hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) # output hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) return output class TemporalTransformerBlock(nn.Module): def __init__( self, dim, num_attention_heads, attention_head_dim, attention_block_types = ( "Temporal_Self", "Temporal_Self", ), dropout = 0.0, norm_num_groups = 32, cross_attention_dim = 768, activation_fn = "geglu", attention_bias = False, upcast_attention = False, cross_frame_attention_mode = None, temporal_position_encoding = False, temporal_position_encoding_max_len = 24, ): super().__init__() attention_blocks = [] norms = [] for block_name in attention_block_types: attention_blocks.append( VersatileAttention( attention_mode=block_name.split("_")[0], cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, cross_frame_attention_mode=cross_frame_attention_mode, temporal_position_encoding=temporal_position_encoding, temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) norms.append(nn.LayerNorm(dim)) self.attention_blocks = nn.ModuleList(attention_blocks) self.norms = nn.ModuleList(norms) self.ff = 
FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = attention_block(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
                video_length=video_length,
            ) + hidden_states

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model,
        dropout = 0.,
        max_len = 24
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        # self.register_buffer('pe', pe)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class VersatileAttention(CrossAttention):  # inherits from CrossAttention, so set_processor does not need to be reimplemented
    def __init__(
        self,
        attention_mode = None,
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 24,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.,
            max_len=temporal_position_encoding_max_len
        ) if (temporal_position_encoding and attention_mode == "Temporal") else None

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None,
attention_mask=None, video_length=None): batch_size, sequence_length, _ = hidden_states.shape if self.attention_mode == "Temporal": d = hidden_states.shape[1] hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) if self.pos_encoder is not None: hidden_states = self.pos_encoder(hidden_states) encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states else: raise NotImplementedError encoder_hidden_states = encoder_hidden_states if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) dim = query.shape[-1] # query = self.reshape_heads_to_batch_dim(query) # move backwards if self.added_kv_proj_dim is not None: raise NotImplementedError encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) ######record###### record before reshape heads to batch dim if self.processor is not None: self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask) ################## key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) query = self.reshape_heads_to_batch_dim(query) # reshape query here if attention_mask is not None: if attention_mask.shape[-1] != query.shape[1]: target_length = query.shape[1] attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) ######record###### # if self.processor is not None: # self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask) ################## # attention, what we cannot get enough of if self._use_memory_efficient_attention_xformers: hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) # Some versions of xformers 
return output in fp32, cast it back to the dtype of the input hidden_states = hidden_states.to(query.dtype) else: if self._slice_size is None or query.shape[0] // self._slice_size == 1: hidden_states = self._attention(query, key, value, attention_mask) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) if self.attention_mode == "Temporal": hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) return hidden_states ================================================ FILE: motionclone/models/resnet.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange class InflatedConv3d(nn.Conv2d): def forward(self, x): video_length = x.shape[2] x = rearrange(x, "b c f h w -> (b f) c h w") x = super().forward(x) x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) return x class InflatedGroupNorm(nn.GroupNorm): def forward(self, x): video_length = x.shape[2] x = rearrange(x, "b c f h w -> (b f) c h w") x = super().forward(x) x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) return x class Upsample3D(nn.Module): def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name conv = None if use_conv_transpose: raise NotImplementedError elif use_conv: self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) def forward(self, hidden_states, output_size=None): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: raise NotImplementedError # Cast to 
        # float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # if self.use_conv:
        #     if self.name == "conv":
        #         hidden_states = self.conv(hidden_states)
        #     else:
        #         hidden_states = self.Conv2d_0(hidden_states)
        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        # The channel check only needs to run once per call
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
use_inflated_groupnorm=False, ): super().__init__() self.pre_norm = pre_norm self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.time_embedding_norm = time_embedding_norm self.output_scale_factor = output_scale_factor self.upsample = self.downsample = None if groups_out is None: groups_out = groups assert use_inflated_groupnorm != None if use_inflated_groupnorm: self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": time_emb_proj_out_channels = out_channels elif self.time_embedding_norm == "scale_shift": time_emb_proj_out_channels = out_channels * 2 else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) else: self.time_emb_proj = None if use_inflated_groupnorm: self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) else: self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) elif non_linearity == "mish": self.nonlinearity = Mish() elif non_linearity == "silu": self.nonlinearity = nn.SiLU() self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = InflatedConv3d(in_channels, out_channels, 
kernel_size=1, stride=1, padding=0) def forward(self, input_tensor, temb): hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor class Mish(torch.nn.Module): def forward(self, hidden_states): return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) ================================================ FILE: motionclone/models/scheduler.py ================================================ from typing import Optional, Tuple, Union import torch from diffusers import DDIMScheduler from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput from diffusers.utils.torch_utils import randn_tensor class CustomDDIMScheduler(DDIMScheduler): @torch.no_grad() def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, # Guidance parameters score=None, guidance_scale=0.0, indices=None, # [0] ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. 
        Core function to propagate the diffusion process from the learned model outputs (most often the predicted
        noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
            score: guidance term subtracted from the predicted noise; its shape must match
                `pred_epsilon[indices]` (or `pred_epsilon` when `indices` is `None`).
            guidance_scale (`float`): guidance weight; guidance is applied only when it is greater than 0.
            indices: batch positions of the predicted noise that receive guidance.

        Returns:
            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper for a detailed understanding

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # Support IF models
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
            pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)  # [2, 4, 64, 64]

        # 6. apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf
        if score is not None and guidance_scale > 0.0:
            # `indices` specifies the batch positions where guidance is applied; here indices = [0]
            if indices is not None:
                # import pdb; pdb.set_trace()
                assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
                # only the first [1, 4, 64, 64] slice is modified
                pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
            else:
                assert pred_epsilon.shape == score.shape
                pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score

        # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon  # [2, 4, 64, 64]

        # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction  # [2, 4, 64, 64]

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise  # finally, the random noise term is added
            prev_sample = prev_sample + variance  # [2, 4, 64, 64]

        self.pred_epsilon = pred_epsilon
        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

================================================
FILE: motionclone/models/sparse_controlnet.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Changes were made to this source code by Yuwei Guo.
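Editorial note on `CustomDDIMScheduler.step` in `motionclone/models/scheduler.py`: the guidance branch (step 6) changes only the predicted noise, after the predicted x_0 has already been computed from the unguided noise. The sketch below reproduces that ordering on plain floats. It is a minimal illustration, not repository code: the function name `ddim_step_scalar` is hypothetical, and it assumes epsilon-prediction, `eta = 0`, and no clipping or thresholding.

```python
import math


def ddim_step_scalar(sample, eps, alpha_prod_t, alpha_prod_t_prev,
                     score=None, guidance_scale=0.0):
    """Scalar sketch of a deterministic (eta = 0) DDIM update with guidance.

    Hypothetical helper for illustration only; all arguments are floats.
    """
    beta_prod_t = 1.0 - alpha_prod_t

    # Predicted x_0 (epsilon-prediction), computed from the unguided noise.
    pred_x0 = (sample - math.sqrt(beta_prod_t) * eps) / math.sqrt(alpha_prod_t)

    # Guidance: shift the predicted noise by -scale * sqrt(1 - alpha_prod_t) * score.
    if score is not None and guidance_scale > 0.0:
        eps = eps - guidance_scale * math.sqrt(beta_prod_t) * score

    # "Direction pointing to x_t" uses the (possibly guided) noise; std_dev_t is 0 here.
    direction = math.sqrt(1.0 - alpha_prod_t_prev) * eps

    return math.sqrt(alpha_prod_t_prev) * pred_x0 + direction


unguided = ddim_step_scalar(0.3, 0.7, 0.5, 0.75)
guided = ddim_step_scalar(0.3, 0.7, 0.5, 0.75, score=1.0, guidance_scale=2.0)
print(guided - unguided)
```

Under these assumptions the guided and unguided results differ by exactly `-guidance_scale * sqrt(1 - alpha_prod_t) * sqrt(1 - alpha_prod_t_prev) * score`, and with no score and equal alphas the step returns the input sample unchanged.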
from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.modeling_utils import ModelMixin from .unet_blocks import ( CrossAttnDownBlock3D, DownBlock3D, UNetMidBlock3DCrossAttn, get_down_block, ) from einops import repeat, rearrange from .resnet import InflatedConv3d from diffusers.models.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class SparseControlNetOutput(BaseOutput): down_block_res_samples: Tuple[torch.Tensor] mid_block_res_sample: torch.Tensor class SparseControlNetConditioningEmbedding(nn.Module): def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int] = (16, 32, 96, 256), ): super().__init__() self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class SparseControlNetModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def 
__init__( self, in_channels: int = 4, conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, use_motion_module = True, motion_module_resolutions = ( 1,2,4,8 ), motion_module_mid_block = False, motion_module_type = "Vanilla", motion_module_kwargs = { "num_attention_heads": 8, "num_transformer_block": 1, "attention_block_types": ["Temporal_Self"], "temporal_position_encoding": True, "temporal_position_encoding_max_len": 32, "temporal_attention_dim_div": 1, "causal_temporal_attention": False, }, concate_conditioning_mask: bool = True, use_simplified_condition_embedding: bool = False, set_noisy_sample_input_to_zero: bool = False, ): super().__init__() # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. 
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) # input self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = InflatedConv3d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) if concate_conditioning_mask: conditioning_channels = conditioning_channels + 1 self.concate_conditioning_mask = concate_conditioning_mask # control net conditioning embedding if use_simplified_condition_embedding: self.controlnet_cond_embedding = zero_module( InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) ).to(torch.float16) else: self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ).to(torch.float16) self.use_simplified_condition_embedding = use_simplified_condition_embedding # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. 
it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=True, use_motion_module=use_motion_module and (res in 
motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) # mid mid_block_channel = block_out_channels[-1] controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock3DCrossAttn( in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=num_attention_heads[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, use_inflated_groupnorm=True, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) @classmethod def from_unet( cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, controlnet_additional_kwargs: dict = {}, ): controlnet = cls( in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, down_block_types=unet.config.down_block_types, only_cross_attention=unet.config.only_cross_attention, block_out_channels=unet.config.block_out_channels, 
layers_per_block=unet.config.layers_per_block, downsample_padding=unet.config.downsample_padding, mid_block_scale_factor=unet.config.mid_block_scale_factor, act_fn=unet.config.act_fn, norm_num_groups=unet.config.norm_num_groups, norm_eps=unet.config.norm_eps, cross_attention_dim=unet.config.cross_attention_dim, attention_head_dim=unet.config.attention_head_dim, num_attention_heads=unet.config.num_attention_heads, use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, upcast_attention=unet.config.upcast_attention, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, conditioning_embedding_out_channels=conditioning_embedding_out_channels, **controlnet_additional_kwargs, ) if load_weights_from_unet: m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False) assert len(u) == 0 if controlnet.class_embedding: m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False) assert len(u) == 0 m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False) assert len(u) == 0 return controlnet @staticmethod def image_layer_filter(state_dict): new_state_dict = {} for name, param in state_dict.items(): if "motion_modules." 
in name or "lora" in name: continue new_state_dict[name] = param return new_state_dict # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        # The 2D block classes are not imported in this file; the 3D variants
        # imported from .unet_blocks are the ones this model actually builds.
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,

        controlnet_cond: torch.FloatTensor,
        conditioning_mask: Optional[torch.FloatTensor] = None,

        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[SparseControlNetOutput, Tuple]:

        # set input noise to zero
        # if self.set_noisy_sample_input_to_zero:
        #     sample = torch.zeros_like(sample).to(sample.device)

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU.
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # 2. pre-process # equal to set input noise to zero if self.set_noisy_sample_input_to_zero: shape = sample.shape sample = self.conv_in.bias.reshape(1,-1,1,1,1).expand(shape[0],-1,shape[2],shape[3],shape[4]) else: sample = self.conv_in(sample) if self.concate_conditioning_mask: controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1).to(torch.float16) # import pdb; pdb.set_trace() controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample = sample + controlnet_cond # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, # cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, # cross_attention_kwargs=cross_attention_kwargs, ) # 5. controlnet blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. 
scaling if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: return (down_block_res_samples, mid_block_res_sample) return SparseControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module ================================================ FILE: motionclone/models/unet.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py from dataclasses import dataclass from typing import List, Optional, Tuple, Union import os import json import pdb import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from .unet_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, DownBlock3D, UNetMidBlock3DCrossAttn, UpBlock3D, get_down_block, get_up_block, ) from .resnet import InflatedConv3d, InflatedGroupNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass 
class UNet3DConditionOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        # The first up block has no cross-attention; the remaining three do.
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        use_inflated_groupnorm=False,

        # Additional
        use_motion_module = False,
        motion_module_resolutions = ( 1,2,4,8 ),
        motion_module_mid_block = False,
        motion_module_decoder_only = False,
        motion_module_type = None,
        motion_module_kwargs = {},
        unet_use_cross_frame_attention = False,
        unet_use_temporal_attention = False,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class
embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.down_blocks.append(down_block) # mid if mid_block_type == 
"UNetMidBlock3DCrossAttn": self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the videos self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): res = 2 ** (3 - i) is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, 
attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if use_inflated_groupnorm: self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) else: self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module will split the input tensor in slices, to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_slicable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_slicable_dims(module) num_slicable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_slicable_layers * [1] slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any child that exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, # support controlnet down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple. Returns: [`UNet3DConditionOutput`] or `tuple`: [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 
# However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # pre-process sample = self.conv_in(sample) # down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) down_block_res_samples += res_samples # support controlnet down_block_res_samples = list(down_block_res_samples) if down_block_additional_residuals is not None: for i, down_block_additional_residual in enumerate(down_block_additional_residuals): if down_block_additional_residual.dim() == 4: # broadcast down_block_additional_residual = down_block_additional_residual.unsqueeze(2) down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual # mid sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) # support controlnet if mid_block_additional_residual is not None: if mid_block_additional_residual.dim() == 4: # broadcast mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) sample = sample + mid_block_additional_residual # up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the 
final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, ) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) @classmethod def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) config["_class_name"] = cls.__name__ config["down_block_types"] = [ "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D" ] config["up_block_types"] = [ "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ] from diffusers.utils import WEIGHTS_NAME # unet_additional_kwargs defaults to None, which cannot be unpacked with ** model = cls.from_config(config, **(unet_additional_kwargs or {})) model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") state_dict = torch.load(model_file, map_location="cpu") m, u = model.load_state_dict(state_dict, strict=False) print(f"### motion keys will be loaded: {len(m)}; 
\n### unexpected keys: {len(u)};") params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()] print(f"### Motion Module Parameters: {sum(params) / 1e6} M") return model ================================================ FILE: motionclone/models/unet_blocks.py ================================================ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py import torch from torch import nn from .attention import Transformer3DModel from .resnet import Downsample3D, ResnetBlock3D, Upsample3D from .motion_module import get_motion_module import pdb def get_down_block( down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock3D": return DownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") return 
CrossAttnDownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_eps, resnet_act_fn, attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, 
motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") return CrossAttnUpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, use_inflated_groupnorm=use_inflated_groupnorm, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) raise ValueError(f"{up_block_type} does not exist.") class UNetMidBlock3DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # 
there is always at least one resnet resnets = [ ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ] attentions = [] motion_modules = [] for _ in range(num_layers): if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, in_channels // attn_num_head_channels, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=in_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample hidden_states = motion_module(hidden_states, temb, 
encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states hidden_states = resnet(hidden_states, temb) return hidden_states class CrossAttnDownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) if dual_cross_attention: raise NotImplementedError attentions.append( Transformer3DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) 
motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): output_states = () for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, )[0] if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class DownBlock3D(nn.Module): def 
__init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] motion_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample3D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () for resnet, motion_module in zip(self.resnets, self.motion_modules): if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), 
hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class CrossAttnUpBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attn_num_head_channels=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, unet_use_cross_frame_attention=False, unet_use_temporal_attention=False, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.attn_num_head_channels = attn_num_head_channels for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) if dual_cross_attention: raise NotImplementedError attentions.append( 
Transformer3DModel( attn_num_head_channels, out_channels // attn_num_head_channels, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None, attention_mask=None, ): for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, )[0] if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), 
hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample # add motion module hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class UpBlock3D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, use_inflated_groupnorm=False, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): super().__init__() resnets = [] motion_modules = [] for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock3D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( get_motion_module( in_channels=out_channels, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) else: 
self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): for resnet, motion_module in zip(self.resnets, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) if motion_module is not None: hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) else: hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states ================================================ FILE: motionclone/pipelines/pipeline_animation.py ================================================ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py import inspect from typing import Callable, List, Optional, Union, Any, Dict from dataclasses import dataclass from diffusers import StableDiffusionPipeline, DDIMInverseScheduler import os import pickle import numpy as np import torch from tqdm import tqdm import omegaconf from omegaconf import OmegaConf import yaml from diffusers.utils import is_accelerate_available from packaging import version from transformers import CLIPTextModel, CLIPTokenizer from diffusers.configuration_utils import FrozenDict from 
diffusers.models import AutoencoderKL from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from diffusers.utils import deprecate, logging, BaseOutput from einops import rearrange from ..models.unet import UNet3DConditionModel from ..models.sparse_controlnet import SparseControlNetModel from ..utils.xformer_attention import * from ..utils.conv_layer import * from ..utils.util import * logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class AnimationPipelineOutput(BaseOutput): videos: Union[torch.Tensor, np.ndarray] class AnimationPipeline(DiffusionPipeline): _optional_components = [] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], controlnet: Union[SparseControlNetModel, None] = None, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might lead to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, controlnet=controlnet, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def enable_vae_slicing(self): self.vae.enable_slicing() def disable_vae_slicing(self): self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property def _execution_device(self): if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, 
num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None text_embeddings = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None uncond_embeddings = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings @torch.no_grad() def decode_latents(self, latents): video_length = latents.shape[2] latents = 1 / 0.18215 * latents latents = rearrange(latents, "b c f h w -> (b f) c h w") # video = self.vae.decode(latents).sample video = [] for frame_idx in tqdm(range(latents.shape[0])): video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) video = torch.cat(video) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video = (video / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.cpu().float().numpy() return video def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, prompt, height, width, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: rand_device = "cpu" if device.type == "mps" else device if isinstance(generator, list): shape = shape # shape = (1,) + shape[1:] latents = [ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size) ] latents = torch.cat(latents, dim=0).to(device) else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], video_length: Optional[int], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "tensor", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, save_probs = False, # support controlnet controlnet_images: torch.FloatTensor = None, controlnet_image_index: list = [0], controlnet_conditioning_scale: Union[float, List[float]] = 1.0, **kwargs, ): # to record temp attention probs self.unet = prep_unet_attention(self.unet) self.unet = prep_unet_conv(self.unet) self.guidance_config = guidance_scale self.temp_rec = {} # Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # Check inputs. 
Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # Define call parameters # batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if latents is not None: batch_size = latents.shape[0] if isinstance(prompt, list): batch_size = len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # Encode input prompt prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size if negative_prompt is not None: negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size text_embeddings = self._encode_prompt( prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt ) # import pdb; pdb.set_trace() # Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # Prepare latent variables num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, video_length, height, width, text_embeddings.dtype, device, generator, latents, ) latents_dtype = latents.dtype # Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) down_block_additional_residuals = mid_block_additional_residual = None # import pdb; pdb.set_trace() if (getattr(self, "controlnet", None) is not None) and (controlnet_images is not None): assert controlnet_images.dim() == 5 controlnet_noisy_latents = latent_model_input controlnet_prompt_embeds = text_embeddings controlnet_images = controlnet_images.to(latents.device) controlnet_cond_shape = list(controlnet_images.shape) controlnet_cond_shape[2] = video_length controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) controlnet_conditioning_mask_shape = list(controlnet_cond.shape) controlnet_conditioning_mask_shape[1] = 1 controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) assert controlnet_images.shape[2] >= len(controlnet_image_index) controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 down_block_additional_residuals, mid_block_additional_residual = self.controlnet( controlnet_noisy_latents, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=controlnet_cond, conditioning_mask=controlnet_conditioning_mask, conditioning_scale=controlnet_conditioning_scale, guess_mode=False, return_dict=False, ) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=text_embeddings, down_block_additional_residuals = down_block_additional_residuals, mid_block_additional_residual = 
mid_block_additional_residual, ).sample.to(dtype=latents_dtype) # get temp attn probs if save_probs: temp_attn_prob = self.get_temp_attn_prob() for key in temp_attn_prob.keys(): temp_attn_prob[key] = temp_attn_prob[key].chunk(2, dim = 0)[0].detach().clone().cpu() self.temp_rec[i] = temp_attn_prob # import pdb; pdb.set_trace() # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # pickle temp attn prob if save_probs: with open('temp_dic.pkl', 'wb') as f: pickle.dump(self.temp_rec, f) # Post-processing video = self.decode_latents(latents) # Convert to tensor if output_type == "tensor": video = torch.from_numpy(video) if not return_dict: return video return AnimationPipelineOutput(videos=video) ================================================ FILE: motionclone/utils/conv_layer.py ================================================ import torch def conv_forward(self): def forward(input_tensor, temb, scale=1.0): hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) # import pdb; pdb.set_trace() if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) if temb is not None: temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None].repeat(1, 1, hidden_states.shape[2], 1, 1) if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) # record hidden state self.record_hidden_state = hidden_states if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor return forward def get_conv_feat(unet): hidden_state_dict = dict() for i in range(len(unet.up_blocks)): for j in range(len(unet.up_blocks[i].resnets)): module = unet.up_blocks[i].resnets[j] module_name = f"up_blocks.{i}.resnets.{j}" # print(module_name) hidden_state_dict[module_name] = module.record_hidden_state return hidden_state_dict def prep_unet_conv(unet): for i in range(len(unet.up_blocks)): for j in range(len(unet.up_blocks[i].resnets)): module = unet.up_blocks[i].resnets[j] module.forward = conv_forward(module) return unet ================================================ FILE: motionclone/utils/convert_from_ckpt.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Conversion script for the Stable Diffusion checkpoints.""" import re from io import BytesIO from typing import Optional import requests import torch from transformers import ( AutoFeatureExtractor, BertTokenizerFast, CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionConfig, CLIPVisionModelWithProjection, ) from diffusers.models import ( AutoencoderKL, PriorTransformer, UNet2DConditionModel, ) from diffusers.schedulers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UnCLIPScheduler, ) from diffusers.utils.import_utils import BACKENDS_MAPPING def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. 
""" if n_shave_prefix_segments >= 0: return ".".join(path.split(".")[n_shave_prefix_segments:]) else: return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item.replace("in_layers.0", "norm1") new_item = new_item.replace("in_layers.2", "conv1") new_item = new_item.replace("out_layers.0", "norm2") new_item = new_item.replace("out_layers.3", "conv2") new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item # new_item = new_item.replace('norm.weight', 'group_norm.weight') # new_item = new_item.replace('norm.bias', 'group_norm.bias') # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local 
renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("q.weight", "query.weight") new_item = new_item.replace("q.bias", "query.bias") new_item = new_item.replace("k.weight", "key.weight") new_item = new_item.replace("k.bias", "key.bias") new_item = new_item.replace("v.weight", "value.weight") new_item = new_item.replace("v.bias", "value.bias") new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.bias", "proj_attn.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def assign_to_checkpoint( paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new checkpoint. """ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. 
if attention_paths_to_split is not None: for path, path_map in attention_paths_to_split.items(): old_tensor = old_checkpoint[path] channels = old_tensor.shape[0] // 3 target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) checkpoint[path_map["key"]] = key.reshape(target_shape) checkpoint[path_map["value"]] = value.reshape(target_shape) for path in paths: new_path = path["new"] # These have already been assigned if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear if "proj_attn.weight" in new_path: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): keys = list(checkpoint.keys()) attn_keys = ["query.weight", "key.weight", "value.weight"] for key in keys: if ".".join(key.split(".")[-2:]) in attn_keys: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0, 0] elif "proj_attn.weight" in key: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0] def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): """ Creates a config for the diffusers based on the config of the LDM model. 
""" if controlnet: unet_params = original_config.model.params.control_stage_config.params else: unet_params = original_config.model.params.unet_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: head_dim = [5, 10, 20, 20] class_embed_type = None projection_class_embeddings_input_dim = None if "num_classes" in unet_params: if unet_params.num_classes == "sequential": class_embed_type = "projection" assert "adm_in_channels" in unet_params projection_class_embeddings_input_dim = unet_params.adm_in_channels else: raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") config = { "sample_size": image_size // vae_scale_factor, "in_channels": unet_params.in_channels, "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), "layers_per_block": unet_params.num_res_blocks, "cross_attention_dim": unet_params.context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, 
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config


def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config


def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


def create_ldm_bert_config(original_config):
    # Fixed typo: the attribute is `params`, not `parms`.
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config


def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: print(f"Checkpoint {path} has both EMA and non-EMA weights.") print( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: print( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port ... 
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) 
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) resnet_1_paths = renew_resnet_paths(resnet_1) assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) layer_in_block_id = i % (config["layers_per_block"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: layer_id, layer_name = 
layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ] new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] # Clear attentions as they have been attributed above. 
if len(attentions) == 2: attentions = [] if len(attentions): paths = renew_attention_paths(attentions) meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] if controlnet: # conditioning embedding orig_index = 0 new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) orig_index += 2 diffusers_index = 0 while diffusers_index < 6: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) diffusers_index += 1 orig_index += 2 new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) # down blocks for i in range(num_input_blocks): new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") # mid block new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = 
unet_state_dict.pop("middle_block_out.0.bias") return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE vae_state_dict = {} vae_key = "first_stage_model." keys = list(checkpoint.keys()) for key in keys: if key.startswith(vae_key): vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) new_checkpoint = {} new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] # Retrieves the keys for the encoder down blocks only num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) down_blocks = { layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) } # Retrieves the 
keys for the decoder up blocks only num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) up_blocks = { layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.weight" ) new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.bias" ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.weight" ] new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.bias" ] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) return new_checkpoint def convert_ldm_bert_checkpoint(checkpoint, config): def _copy_attn_layer(hf_attn_layer, pt_attn_layer): hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias def _copy_linear(hf_linear, pt_linear): hf_linear.weight = pt_linear.weight hf_linear.bias = pt_linear.bias def _copy_layer(hf_layer, pt_layer): # copy layer norms _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) # copy attn 
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) # copy MLP pt_mlp = pt_layer[1][1] _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) _copy_linear(hf_layer.fc2, pt_mlp.net[2]) def _copy_layers(hf_layers, pt_layers): for i, hf_layer in enumerate(hf_layers): if i != 0: i += i pt_layer = pt_layers[i : i + 2] _copy_layer(hf_layer, pt_layer) hf_model = LDMBertModel(config).eval() # copy embeds hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight # copy layer norm _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) # copy hidden layers _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) return hf_model def convert_ldm_clip_checkpoint_concise(checkpoint): keys = list(checkpoint.keys()) text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] return text_model_dict def convert_ldm_clip_checkpoint(checkpoint): text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") keys = list(checkpoint.keys()) text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] text_model.load_state_dict(text_model_dict,strict=True) return text_model textenc_conversion_lst = [ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), ] textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} textenc_transformer_conversion_lst = [ # 
(stable-diffusion, HF Diffusers) ("resblocks.", "text_model.encoder.layers."), ("ln_1", "layer_norm1"), ("ln_2", "layer_norm2"), (".c_fc.", ".fc1."), (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) def convert_paint_by_example_checkpoint(checkpoint): config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") model = PaintByExampleImageEncoder(config) keys = list(checkpoint.keys()) text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] # load clip vision model.model.load_state_dict(text_model_dict) # load mapper keys_mapper = { k[len("cond_stage_model.mapper.res") :]: v for k, v in checkpoint.items() if k.startswith("cond_stage_model.mapper") } MAPPING = { "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], "attn.c_proj": ["attn1.to_out.0"], "ln_1": ["norm1"], "ln_2": ["norm3"], "mlp.c_fc": ["ff.net.0.proj"], "mlp.c_proj": ["ff.net.2"], } mapped_weights = {} for key, value in keys_mapper.items(): prefix = key[: len("blocks.i")] suffix = key.split(prefix)[-1].split(".")[-1] name = key.split(prefix)[-1].split(suffix)[0][1:-1] mapped_names = MAPPING[name] num_splits = len(mapped_names) for i, mapped_name in enumerate(mapped_names): new_name = ".".join([prefix, mapped_name, suffix]) shape = value.shape[0] // num_splits mapped_weights[new_name] = value[i * shape : (i + 1) * shape] model.mapper.load_state_dict(mapped_weights) # load final layer norm model.final_layer_norm.load_state_dict( { "bias": checkpoint["cond_stage_model.final_ln.bias"], "weight": 
checkpoint["cond_stage_model.final_ln.weight"], } ) # load final proj model.proj_out.load_state_dict( { "bias": checkpoint["proj_out.bias"], "weight": checkpoint["proj_out.weight"], } ) # load uncond vector model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) return model def convert_open_clip_checkpoint(checkpoint): text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") keys = list(checkpoint.keys()) text_model_dict = {} if "cond_stage_model.model.text_projection" in checkpoint: d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) else: d_model = 1024 text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") for key in keys: if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer continue if key in textenc_conversion_map: text_model_dict[textenc_conversion_map[key]] = checkpoint[key] if key.startswith("cond_stage_model.model.transformer."): new_key = key[len("cond_stage_model.model.transformer.") :] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] elif new_key.endswith(".in_proj_bias"): new_key = new_key[: -len(".in_proj_bias")] new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] else: new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], 
new_key) text_model_dict[new_key] = checkpoint[key] text_model.load_state_dict(text_model_dict) return text_model def stable_unclip_image_encoder(original_config): """ Returns the image processor and clip image encoder for the img2img unclip pipeline. We currently know of two types of stable unclip models which separately use the clip and the openclip image encoders. """ image_embedder_config = original_config.model.params.embedder_config sd_clip_image_embedder_class = image_embedder_config.target sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] if sd_clip_image_embedder_class == "ClipImageEmbedder": clip_model_name = image_embedder_config.params.model if clip_model_name == "ViT-L/14": feature_extractor = CLIPImageProcessor() image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") else: raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": feature_extractor = CLIPImageProcessor() image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") else: raise NotImplementedError( f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" ) return feature_extractor, image_encoder def stable_unclip_image_noising_components( original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None ): """ Returns the noising components for the img2img and txt2img unclip pipelines. Converts the stability noise augmentor into 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats 2. a `DDPMScheduler` for holding the noise schedule If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
""" noise_aug_config = original_config.model.params.noise_aug_config noise_aug_class = noise_aug_config.target noise_aug_class = noise_aug_class.split(".")[-1] if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": noise_aug_config = noise_aug_config.params embedding_dim = noise_aug_config.timestep_dim max_noise_level = noise_aug_config.noise_schedule_config.timesteps beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) if "clip_stats_path" in noise_aug_config: if clip_stats_path is None: raise ValueError("This stable unclip config requires a `clip_stats_path`") clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) clip_mean = clip_mean[None, :] clip_std = clip_std[None, :] clip_stats_state_dict = { "mean": clip_mean, "std": clip_std, } image_normalizer.load_state_dict(clip_stats_state_dict) else: raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") return image_normalizer, image_noising_scheduler def convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") controlnet_model = ControlNetModel(**ctrlnet_config) converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True ) controlnet_model.load_state_dict(converted_ctrl_checkpoint) return controlnet_model ================================================ FILE: motionclone/utils/convert_lora_safetensor_to_diffusers.py ================================================ # coding=utf-8 # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Changes were made to this source code by Yuwei Guo. """ Conversion script for the LoRA's safetensors checkpoints. """ import argparse import torch from safetensors.torch import load_file from diffusers import StableDiffusionPipeline def load_diffusers_lora(pipeline, state_dict, alpha=1.0): # directly update weight in diffusers model for key in state_dict: # only process lora down key if "up." in key: continue up_key = key.replace(".down.", ".up.") model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") model_key = model_key.replace("to_out.", "to_out.0.") layer_infos = model_key.split(".")[:-1] curr_layer = pipeline.unet while len(layer_infos) > 0: temp_name = layer_infos.pop(0) curr_layer = curr_layer.__getattr__(temp_name) weight_down = state_dict[key] weight_up = state_dict[up_key] curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) return pipeline def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): # load base model # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) # load LoRA weight from .safetensors # state_dict = load_file(checkpoint_path) visited = [] # directly update weight in diffusers model for key in state_dict: # it is suggested to print out the key, it usually will be something like below # 
"lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" # as we have set the alpha beforehand, so just skip if ".alpha" in key or key in visited: continue if "text" in key: layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") curr_layer = pipeline.text_encoder else: layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") curr_layer = pipeline.unet # find the target layer temp_name = layer_infos.pop(0) while len(layer_infos) > -1: try: curr_layer = curr_layer.__getattr__(temp_name) if len(layer_infos) > 0: temp_name = layer_infos.pop(0) elif len(layer_infos) == 0: break except Exception: if len(temp_name) > 0: temp_name += "_" + layer_infos.pop(0) else: temp_name = layer_infos.pop(0) pair_keys = [] if "lora_down" in key: pair_keys.append(key.replace("lora_down", "lora_up")) pair_keys.append(key) else: pair_keys.append(key) pair_keys.append(key.replace("lora_up", "lora_down")) # update weight if len(state_dict[pair_keys[0]].shape) == 4: weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) else: weight_up = state_dict[pair_keys[0]].to(torch.float32) weight_down = state_dict[pair_keys[1]].to(torch.float32) curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) # update visited list for item in pair_keys: visited.append(item) return pipeline if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." ) parser.add_argument( "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
    )
    parser.add_argument(
        "--lora_prefix_text_encoder",
        default="lora_te",
        type=str,
        help="The prefix of text encoder weight in safetensors",
    )
    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
    parser.add_argument(
        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
    )
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")

    args = parser.parse_args()

    base_model_path = args.base_model_path
    checkpoint_path = args.checkpoint_path
    dump_path = args.dump_path
    lora_prefix_unet = args.lora_prefix_unet
    lora_prefix_text_encoder = args.lora_prefix_text_encoder
    alpha = args.alpha

    # `convert_lora` takes an already-loaded pipeline and LoRA state dict,
    # so load both here before merging (the previous call to an undefined
    # `convert` function would raise a NameError).
    pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
    state_dict = load_file(checkpoint_path)
    pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)

    pipe = pipe.to(args.device)
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

================================================
FILE: motionclone/utils/motionclone_functions.py
================================================
from dataclasses import dataclass
import os
import numpy as np
import torch
# `F.mse_loss` is used in `compute_temp_loss`; import it explicitly
# instead of relying on the wildcard imports below.
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import Callable, List, Optional, Tuple, Union
from diffusers.utils import deprecate, logging, BaseOutput
from .xformer_attention import *
from .conv_layer import *
from .util import *
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
from motionclone.utils.util import video_preprocess
import einops
import torchvision.transforms as transforms


def add_noise(self, timestep, x_0, noise_pred):
    # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    latents_noise = alpha_prod_t ** 0.5 * x_0 + beta_prod_t ** 0.5 * noise_pred
    return latents_noise


@torch.no_grad()
def obtain_motion_representation(self, generator=None, motion_representation_path: Optional[str] = None, duration=None, use_controlnet=False):
    # Encode the reference video into VAE latents with shape (b, c, f, h, w)
    video_data = video_preprocess(self.input_config.video_path, self.input_config.height, self.input_config.width, self.input_config.video_length, duration=duration)
    video_latents = self.vae.encode(video_data.to(self.vae.dtype).to(self.vae.device)).latent_dist.mode()
    video_latents = self.vae.config.scaling_factor * video_latents
    video_latents = video_latents.unsqueeze(0)
    video_latents = einops.rearrange(video_latents, "b f c h w -> b c f h w")

    uncond_input = self.tokenizer(
        [""], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt"
    )
    step_t = int(self.input_config.add_noise_step)
    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
    # Noise the latents directly to timestep `step_t` (no inversion)
    noise_sampled = randn_tensor(video_latents.shape, generator=generator, device=video_latents.device, dtype=video_latents.dtype)
    noisy_latents = self.add_noise(step_t, video_latents, noise_sampled)

    down_block_additional_residuals = mid_block_additional_residual = None
    if use_controlnet:
        controlnet_image_index = self.input_config.image_index
        if self.controlnet.use_simplified_condition_embedding:
            controlnet_images = video_latents[:, :, controlnet_image_index, :, :]
        else:
            controlnet_images = (einops.rearrange(video_data.unsqueeze(0).to(self.vae.dtype).to(self.vae.device), "b f c h w -> b c f h w") + 1) / 2
            controlnet_images = controlnet_images[:, :, controlnet_image_index, :, :]

        controlnet_cond_shape = list(controlnet_images.shape)
        controlnet_cond_shape[2] = noisy_latents.shape[2]
        controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype)

        controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
        controlnet_conditioning_mask_shape[1] = 1
controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype) controlnet_cond[:,:,controlnet_image_index] = controlnet_images controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 down_block_additional_residuals, mid_block_additional_residual = self.controlnet( noisy_latents, step_t, encoder_hidden_states=uncond_embeddings, controlnet_cond=controlnet_cond, conditioning_mask=controlnet_conditioning_mask, conditioning_scale=self.input_config.controlnet_scale, guess_mode=False, return_dict=False, ) _ = self.unet(noisy_latents, step_t, encoder_hidden_states=uncond_embeddings, return_dict=False, only_motion_feature=True, down_block_additional_residuals = down_block_additional_residuals, mid_block_additional_residual = mid_block_additional_residual,) temp_attn_prob_control = self.get_temp_attn_prob() motion_representation = { key: [max_value, max_index.to(torch.uint8)] for key, tensor in temp_attn_prob_control.items() for max_value, max_index in [torch.topk(tensor, k=1, dim=-1)]} torch.save(motion_representation, motion_representation_path) self.motion_representation_path = motion_representation_path def compute_temp_loss(self, temp_attn_prob_control_dict): temp_attn_prob_loss = [] for name in temp_attn_prob_control_dict.keys(): current_temp_attn_prob = temp_attn_prob_control_dict[name] reference_representation_dict = self.motion_representation_dict[name] max_index = reference_representation_dict[1].to(torch.int64).to(current_temp_attn_prob.device) current_motion_representation = torch.gather(current_temp_attn_prob, index = max_index, dim=-1) reference_motion_representation = reference_representation_dict[0].to(dtype = current_motion_representation.dtype, device = current_motion_representation.device) module_attn_loss = F.mse_loss(current_motion_representation, reference_motion_representation.detach()) temp_attn_prob_loss.append(module_attn_loss) loss_temp = torch.stack(temp_attn_prob_loss) return 
loss_temp.sum() def sample_video( self, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, noisy_latents: Optional[torch.FloatTensor] = None, add_controlnet: bool = False, ): # Determine if use controlnet, i.e., conditional image2video self.add_controlnet = add_controlnet if self.add_controlnet: image_transforms = transforms.Compose([ transforms.Resize((self.input_config.height, self.input_config.width)), transforms.ToTensor(), ]) controlnet_images = [image_transforms(Image.open(path).convert("RGB")) for path in self.input_config.condition_image_path_list] controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(dtype=self.vae.dtype,device=self.vae.device) controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") with torch.no_grad(): if self.controlnet.use_simplified_condition_embedding: num_controlnet_images = controlnet_images.shape[2] controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") controlnet_images = self.vae.encode(controlnet_images * 2. 
- 1.).latent_dist.sample() * self.vae.config.scaling_factor self.controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) else: self.controlnet_images = controlnet_images # Define call parameters # perform classifier_free_guidance in default batch_size = 1 do_classifier_free_guidance = True device = self._execution_device # Encode input prompt self.text_embeddings = self._encode_prompt(self.input_config.new_prompt, device, 1, do_classifier_free_guidance, self.input_config.negative_prompt) # [uncond_embeddings, text_embeddings] [2, 77, 768] # Prepare latent variables noisy_latents = self.prepare_latents( batch_size, self.unet.config.in_channels, self.input_config.video_length, self.input_config.height, self.input_config.width, self.text_embeddings.dtype, device, generator, noisy_latents, ) self.motion_representation_dict = torch.load(self.motion_representation_path) self.motion_scale = self.input_config.motion_guidance_weight extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # save GPU memory # self.vae.to(device = "cpu") # self.text_encoder.to(device = "cpu") # torch.cuda.empty_cache() with self.progress_bar(total=self.input_config.inference_steps) as progress_bar: for step_index, step_t in enumerate(self.scheduler.timesteps): noisy_latents = self.single_step_video(noisy_latents, step_index, step_t, extra_step_kwargs) progress_bar.update() # decode latents for videos video = self.decode_latents(noisy_latents) return video def single_step_video(self, noisy_latents, step_index, step_t, extra_step_kwargs): down_block_additional_residuals = mid_block_additional_residual = None if self.add_controlnet: with torch.no_grad(): controlnet_cond_shape = list(self.controlnet_images.shape) controlnet_cond_shape[2] = noisy_latents.shape[2] controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype) controlnet_conditioning_mask_shape = list(controlnet_cond.shape) 
controlnet_conditioning_mask_shape[1] = 1 controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype) controlnet_image_index = self.input_config.image_index controlnet_cond[:,:,controlnet_image_index] = self.controlnet_images controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 down_block_additional_residuals, mid_block_additional_residual = self.controlnet( noisy_latents.expand(2,-1,-1,-1,-1), step_t, encoder_hidden_states=self.text_embeddings, controlnet_cond=controlnet_cond, conditioning_mask=controlnet_conditioning_mask, conditioning_scale=self.input_config.controlnet_scale, guess_mode=False, return_dict=False, ) # Only require grad when need to compute the gradient for guidance if step_index < self.input_config.guidance_steps: down_block_additional_residuals_uncond = down_block_additional_residuals_cond = None mid_block_additional_residual_uncond = mid_block_additional_residual_cond = None if self.add_controlnet: down_block_additional_residuals_uncond = [tensor[[0],...].detach() for tensor in down_block_additional_residuals] down_block_additional_residuals_cond = [tensor[[1],...].detach() for tensor in down_block_additional_residuals] mid_block_additional_residual_uncond = mid_block_additional_residual[[0],...].detach() mid_block_additional_residual_cond = mid_block_additional_residual[[1],...].detach() control_latents = noisy_latents.clone().detach() control_latents.requires_grad = True control_latents = self.scheduler.scale_model_input(control_latents, step_t) noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t) with torch.no_grad(): noise_pred_uncondition = self.unet(noisy_latents, step_t, encoder_hidden_states=self.text_embeddings[[0]], down_block_additional_residuals = down_block_additional_residuals_uncond, mid_block_additional_residual = mid_block_additional_residual_uncond,).sample.to(dtype=noisy_latents.dtype) noise_pred_condition = 
self.unet(control_latents, step_t, encoder_hidden_states=self.text_embeddings[[1]], down_block_additional_residuals = down_block_additional_residuals_cond, mid_block_additional_residual = mid_block_additional_residual_cond,).sample.to(dtype=noisy_latents.dtype) temp_attn_prob_control = self.get_temp_attn_prob() loss_motion = self.motion_scale * self.compute_temp_loss(temp_attn_prob_control,) if step_index < self.input_config.warm_up_steps: scale = (step_index+1)/self.input_config.warm_up_steps loss_motion = scale*loss_motion if step_index > self.input_config.guidance_steps-self.input_config.cool_up_steps: scale = (self.input_config.guidance_steps-step_index)/self.input_config.cool_up_steps loss_motion = scale*loss_motion gradient = torch.autograd.grad(loss_motion, control_latents, allow_unused=True)[0] # [1, 4, 16, 64, 64], assert gradient is not None, f"Step {step_index}: grad is None" noise_pred = noise_pred_condition + self.input_config.cfg_scale * (noise_pred_condition - noise_pred_uncondition) control_latents = self.scheduler.customized_step(noise_pred, step_index, control_latents, score=gradient.detach(), **extra_step_kwargs, return_dict=False)[0] # [1, 4, 16, 64, 64] return control_latents.detach() else: with torch.no_grad(): noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t) noise_pred_group = self.unet( noisy_latents.expand(2,-1,-1,-1,-1), step_t, encoder_hidden_states=self.text_embeddings, down_block_additional_residuals = down_block_additional_residuals, mid_block_additional_residual = mid_block_additional_residual, ).sample.to(dtype=noisy_latents.dtype) noise_pred = noise_pred_group[[1]] + self.input_config.cfg_scale * (noise_pred_group[[1]] - noise_pred_group[[0]]) noisy_latents = self.scheduler.customized_step(noise_pred, step_index, noisy_latents, score=None, **extra_step_kwargs, return_dict=False)[0] # [1, 4, 16, 64, 64] return noisy_latents.detach() def get_temp_attn_prob(self,index_select=None): attn_prob_dic = {} for name, 
module in self.unet.named_modules(): module_name = type(module).__name__ if "VersatileAttention" in module_name and classify_blocks(self.input_config.motion_guidance_blocks, name): key = module.processor.key if index_select is not None: get_index = torch.repeat_interleave(torch.tensor(index_select), repeats=key.shape[0]//len(index_select)) index_all = torch.arange(key.shape[0]) index_picked = index_all[get_index.bool()] key = key[index_picked] key = module.reshape_heads_to_batch_dim(key).contiguous() query = module.processor.query if index_select is not None: query = query[index_picked] query = module.reshape_heads_to_batch_dim(query).contiguous() attention_probs = module.get_attention_scores(query, key, None) attention_probs = attention_probs.reshape(-1, module.heads,attention_probs.shape[1], attention_probs.shape[2]) attn_prob_dic[name] = attention_probs return attn_prob_dic @torch.no_grad() def schedule_customized_step( self, model_output: torch.FloatTensor, step_index: int, sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, # Guidance parameters score=None, guidance_scale=1.0, indices=None, # [0] return_middle = False, ): if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding # Notation ( -> # - pred_noise_t -> e_theta(x_t, t) # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # Support IF models if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) 
else: predicted_variance = None timestep = self.timesteps[step_index] # 1. get previous step value (=t-1) # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps prev_timestep = self.timesteps[step_index+1] if step_index + 1 < len(self.timesteps) else -1 # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: # the pred_epsilon is always re-derived from the clipped x_0 in Glide pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64] if score is not None and return_middle: return pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample # 6.
apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf if score is not None and guidance_scale > 0.0: if indices is not None: # import pdb; pdb.set_trace() assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape" pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score else: assert pred_epsilon.shape == score.shape pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score # # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if eta > 0: if variance_noise is not None and generator is not None: raise ValueError( "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" " `variance_noise` stays `None`." 
) if variance_noise is None: variance_noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) variance = std_dev_t * variance_noise prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) return prev_sample, pred_original_sample, alpha_prod_t_prev # equally to the function of schedule_customized_step (details refer to the discussion in https://github.com/LPengYang/MotionClone/issues/22) @torch.no_grad() def schedule_customized_step_candidate( self, model_output: torch.FloatTensor, step_index: int, sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, # Guidance parameters score=None, guidance_scale=1.0, indices=None, # [0] return_middle = False, ): if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # Support IF models if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None timestep = self.timesteps[step_index] # 1. 
get previous step value (=t-1) # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps prev_timestep = self.timesteps[step_index+1] if step_index + 1 < len(self.timesteps) else -1 alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t if self.config.prediction_type == "epsilon": if score is not None: alpha_t = alpha_prod_t/alpha_prod_t_prev model_output = model_output + guidance_scale/((1/alpha_t)**0.5-1)*score pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output # print("validation") else: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5.
compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: # the pred_epsilon is always re-derived from the clipped x_0 in Glide pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64] if score is not None and return_middle: return pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if eta > 0: if variance_noise is not None and generator is not None: raise ValueError( "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" " `variance_noise` stays `None`." ) if variance_noise is None: variance_noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) variance = std_dev_t * variance_noise prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) return prev_sample, pred_original_sample, alpha_prod_t_prev def schedule_set_timesteps(self, num_inference_steps: int, guidance_steps: int = 0, guiduance_scale: float = 0.0, device: Union[str, torch.device] = None,timestep_spacing_type= "uneven"): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. """ if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." 
) self.num_inference_steps = num_inference_steps # assign more steps in early denoising stage for motion guidance if timestep_spacing_type == "uneven": timesteps_guidance = ( np.linspace(int((1-guiduance_scale)*self.config.num_train_timesteps), self.config.num_train_timesteps - 1, guidance_steps) .round()[::-1] .copy() .astype(np.int64) ) timesteps_vanilla = ( np.linspace(0, int((1-guiduance_scale)*self.config.num_train_timesteps) - 1, num_inference_steps-guidance_steps) .round()[::-1] .copy() .astype(np.int64) ) timesteps = np.concatenate((timesteps_guidance, timesteps_vanilla)) # "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891 elif timestep_spacing_type == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) .round()[::-1] .copy() .astype(np.int64) ) elif timestep_spacing_type == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif timestep_spacing_type == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 else: raise ValueError( f"{timestep_spacing_type} is not supported. Please make sure to choose one of 'uneven', 'linspace', 'leading' or 'trailing'."
) self.timesteps = torch.from_numpy(timesteps).to(device) @dataclass class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor def unet_customized_forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, # support controlnet down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, only_motion_feature: bool = False, ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # pre-process sample = self.conv_in(sample) # down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) down_block_res_samples += res_samples # support controlnet down_block_res_samples = list(down_block_res_samples) if down_block_additional_residuals is not None: for i, down_block_additional_residual in enumerate(down_block_additional_residuals): if down_block_additional_residual.dim() == 4: # broadcast down_block_additional_residual = down_block_additional_residual.unsqueeze(2) down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual # mid sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask ) # support controlnet if mid_block_additional_residual is not None: if mid_block_additional_residual.dim() == 4: # broadcast mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) sample = sample + mid_block_additional_residual # up for i, upsample_block in enumerate(self.up_blocks): if i<= int(self.input_config.motion_guidance_blocks[-1].split(".")[-1]): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples =
down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, ) else: if only_motion_feature: return 0 with torch.no_grad(): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, ) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) ================================================ FILE: motionclone/utils/util.py ================================================ import hashlib import io import re 
import os import imageio import numpy as np from typing import Union import cv2 import numpy as np import requests import random import torch import PIL.Image import PIL.ImageOps from PIL import Image from typing import Callable, Union import torch import torchvision import torch.distributed as dist import torch.nn.functional as F import decord decord.bridge.set_bridge('torch') from PIL import Image, ImageOps from safetensors import safe_open # from tqdm import tqdm from einops import rearrange from motionclone.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint,convert_ldm_clip_checkpoint_concise from motionclone.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora from huggingface_hub import snapshot_download # from transformers import ( # AutoFeatureExtractor, # BertTokenizerFast, # CLIPImageProcessor, # CLIPTextConfig, # CLIPTextModel, # CLIPTextModelWithProjection, # CLIPTokenizer, # CLIPVisionConfig, # CLIPVisionModelWithProjection, # ) MOTION_MODULES = [ "mm_sd_v14.ckpt", "mm_sd_v15.ckpt", "mm_sd_v15_v2.ckpt", "v3_sd15_mm.ckpt", ] ADAPTERS = [ # "mm_sd_v14.ckpt", # "mm_sd_v15.ckpt", # "mm_sd_v15_v2.ckpt", # "mm_sdxl_v10_beta.ckpt", "v2_lora_PanLeft.ckpt", "v2_lora_PanRight.ckpt", "v2_lora_RollingAnticlockwise.ckpt", "v2_lora_RollingClockwise.ckpt", "v2_lora_TiltDown.ckpt", "v2_lora_TiltUp.ckpt", "v2_lora_ZoomIn.ckpt", "v2_lora_ZoomOut.ckpt", "v3_sd15_adapter.ckpt", # "v3_sd15_mm.ckpt", "v3_sd15_sparsectrl_rgb.ckpt", "v3_sd15_sparsectrl_scribble.ckpt", ] BACKUP_DREAMBOOTH_MODELS = [ "realisticVisionV60B1_v51VAE.safetensors", "majicmixRealistic_v4.safetensors", "leosamsFilmgirlUltra_velvia20Lora.safetensors", "toonyou_beta3.safetensors", "majicmixRealistic_v5Preview.safetensors", "rcnzCartoon3d_v10.safetensors", "lyriel_v16.safetensors", "leosamsHelloworldXL_filmGrain20.safetensors", "TUSUN.safetensors", ] def zero_rank_print(s): if (not dist.is_initialized()) or
(dist.is_initialized() and dist.get_rank() == 0): print("### " + s) def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): videos = rearrange(videos, "b c t h w -> t b c h w") outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x + 1.0) / 2.0 # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) outputs.append(x) os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, fps=fps) def auto_download(local_path, is_dreambooth_lora=False): hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff" folder, filename = os.path.split(local_path) if not os.path.exists(local_path): print(f"local file {local_path} does not exist. trying to download from {hf_repo}") if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} does not exist in {hf_repo}" else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} does not exist in {hf_repo}" folder = "." if folder == "" else folder os.makedirs(folder, exist_ok=True) snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename]) def load_weights( animation_pipeline, # motion module motion_module_path = "", motion_module_lora_configs = [], # domain adapter adapter_lora_path = "", adapter_lora_scale = 1.0, # image layers dreambooth_model_path = "", lora_model_path = "", lora_alpha = 0.8, ): # motion module unet_state_dict = {} if motion_module_path != "": print(f"load motion module from {motion_module_path}") motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules."
in name})
        unet_state_dict.pop("animatediff_config", "")

    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
    # assert len(unexpected) == 0
    del unet_state_dict

    # base model
    if dreambooth_model_path != "":
        print(f"load dreambooth model from {dreambooth_model_path}")
        if dreambooth_model_path.endswith(".safetensors"):
            # import pdb; pdb.set_trace()
            dreambooth_state_dict = {}
            # import safetensors
            # dreambooth_state_dict = safetensors.torch.load_file(dreambooth_model_path)
            # import pdb; pdb.set_trace()
            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
            # import pdb; pdb.set_trace()
        elif dreambooth_model_path.endswith(".ckpt"):
            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
        else:
            # Fail early instead of hitting a NameError on dreambooth_state_dict below.
            raise ValueError(f"unsupported dreambooth checkpoint format: {dreambooth_model_path}")

        # 1. vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
        # 2. unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        # 3.
        # text_model
        # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_concise(dreambooth_state_dict)
        animation_pipeline.text_encoder.load_state_dict(converted_text_encoder_checkpoint, strict=True)

        del dreambooth_state_dict, converted_vae_checkpoint, converted_unet_checkpoint, converted_text_encoder_checkpoint

        # clip_config_name = "models/clip-vit-large-patch14"
        # clip_config = CLIPTextConfig.from_pretrained(clip_config_name, local_files_only=True)
        # text_model = CLIPTextModel(clip_config)
        # keys = list(dreambooth_state_dict.keys())
        # text_model_dict = {}
        # for key in keys:
        #     if key.startswith("cond_stage_model.transformer"):
        #         text_model_dict[key[len("cond_stage_model.transformer.") :]] = dreambooth_state_dict[key]
        # text_model.load_state_dict(text_model_dict)
        # animation_pipeline.text_encoder = text_model.to(dtype=animation_pipeline.unet.dtype)
        # # import pdb; pdb.set_trace()
        # # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        # del dreambooth_state_dict

    # lora layers
    if lora_model_path != "":
        print(f"load lora model from {lora_model_path}")
        assert lora_model_path.endswith(".safetensors")
        lora_state_dict = {}
        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                lora_state_dict[key] = f.get_tensor(key)

        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
        del lora_state_dict

    # domain adapter lora
    if adapter_lora_path != "":
        print(f"load domain lora from {adapter_lora_path}")
        domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
        domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
        domain_lora_state_dict.pop("animatediff_config", "")

        animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)

    # motion module lora
    for
motion_module_lora_config in motion_module_lora_configs:
        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
        print(f"load motion LoRA from {path}")
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
        motion_lora_state_dict.pop("animatediff_config", "")

        animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)

    return animation_pipeline


def video_preprocess(video_path, height, width, video_length, duration=None, sample_start_idx=0,):
    video_name = video_path.split('/')[-1].split('.')[0]
    vr = decord.VideoReader(video_path)
    fps = vr.get_avg_fps()
    if duration is None:
        # Read the whole video
        total_frames = len(vr)
    else:
        # Compute the number of frames from the given duration (in seconds)
        total_frames = int(fps * duration)
        total_frames = min(total_frames, len(vr))  # Make sure it does not exceed the video length

    sample_index = np.linspace(0, total_frames - 1, video_length, dtype=int)
    print(total_frames,sample_index)
    video = vr.get_batch(sample_index)
    video = rearrange(video, "f h w c -> f c h w")

    video = F.interpolate(video, size=(height, width), mode="bilinear", align_corners=True)

    # video_sample = rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
    # imageio.mimwrite(f"processed_videos/sample_{video_name}.mp4", video_sample[0], fps=8, quality=9)

    video = video / 127.5 - 1.0

    return video


def set_nested_item(dataDict, mapList, value):
    """Set item in nested dictionary

    Example: the mapList contains the name of each key ['injection','self-attn']
    this method will change the content in dataDict['injection']['self-attn'] with value
    """
    for k in mapList[:-1]:
        dataDict = dataDict[k]

    dataDict[mapList[-1]] = value


def merge_sweep_config(base_config, update):
    """Merge the updated parameters into the base config"""
    if base_config is None:
        raise ValueError("Base config is None")
    if update is None:
        raise ValueError("Update config is None")
    for key in update.keys():
        map_list = key.split("--")
        set_nested_item(base_config, map_list, update[key])
    return base_config


# Adapted from https://github.com/castorini/daam
def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0):
    merge_idxs = []
    tokens = tokenizer.tokenize(prompt.lower())
    if word_idx is None:
        word = word.lower()
        search_tokens = tokenizer.tokenize(word)
        start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens]
        for indice in start_indices:
            merge_idxs += [i + indice for i in range(0, len(search_tokens))]
        if not merge_idxs:
            raise Exception(f'Search word {word} not found in prompt!')
    else:
        merge_idxs.append(word_idx)

    return [x + 1 for x in merge_idxs], word_idx  # Offset by 1.


def extract_data(input_string: str) -> list:
    """
    Extract data from a string pattern where contents in () are separated by ;
    The first item in each () is considered as 'ref' and the rest as 'gen'.

    Args:
    - input_string (str): The input string pattern.

    Returns:
    - list: A list of dictionaries containing 'ref' and 'gen'.
    """
    print("input_string:", input_string)
    pattern = r'\(([^)]+)\)'
    matches = re.findall(pattern, input_string)

    data = []
    for match in matches:
        parts = [x.strip() for x in match.split(';')]
        ref = parts[0].strip()
        gen = parts[1].strip()
        data.append({'ref': ref, 'gen': gen})

    return data


def generate_hash_key(image, prompt=""):
    """
    Generate a hash key for the given image and prompt.
    """
    byte_array = io.BytesIO()
    image.save(byte_array, format='JPEG')

    # Get byte data
    image_byte_data = byte_array.getvalue()

    # Combine image byte data and prompt byte data
    combined_data = image_byte_data + prompt.encode('utf-8')

    sha256 = hashlib.sha256()
    sha256.update(combined_data)
    return sha256.hexdigest()


def save_data(data, folder_path, key):
    """
    Save data to a file, using key as the file name
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    file_path = os.path.join(folder_path, f"{key}.pt")

    torch.save(data, file_path)


def get_data(folder_path, key):
    """
    Get data from a file, using key as the file name
    :param folder_path:
    :param key:
    :return:
    """
    file_path = os.path.join(folder_path, f"{key}.pt")
    if os.path.exists(file_path):
        return torch.load(file_path)
    else:
        return None


def PILtoTensor(data: Image.Image) -> torch.Tensor:
    return torch.tensor(np.array(data)).permute(2, 0, 1).unsqueeze(0).float()


def TensorToPIL(data: torch.Tensor) -> Image.Image:
    return Image.fromarray(data.squeeze().permute(1, 2, 0).numpy().astype(np.uint8))


# Adapted from https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/utils/loading_utils.py#L9
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
            to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.
    """
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL.
URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif isinstance(image, PIL.Image.Image):
        image = image
    else:
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")
    return image


# Taken from huggingface/diffusers
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


def _in_step(config, step):
    in_step = False
    try:
        start_step = config.start_step
        end_step = config.end_step
        if start_step <= step < end_step:
            in_step = True
    except Exception:
        # config has no usable start_step / end_step
        in_step = False
    return in_step


def classify_blocks(block_list, name):
    is_correct_block = False
    for block in block_list:
        if block in name:
            is_correct_block = True
            break
    return is_correct_block


def set_all_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


================================================
FILE: motionclone/utils/xformer_attention.py
================================================
import math
from typing import Optional, Callable
import xformers
from omegaconf import OmegaConf
import yaml
from .util import classify_blocks


def identify_blocks(block_list, name):
    block_name = None
    for block in block_list:
        if block in name:
            block_name = block
            break
    return block_name


class MySelfAttnProcessor:
    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn, hidden_states, query, key, value, attention_mask):
        # self.attn = attn
        self.key = key
        self.query = query
        # self.value = value
        # self.attention_mask = attention_mask
        # self.hidden_state = hidden_states.detach()
        # return hidden_states

    def record_qkv(self, attn, hidden_states, query, key, value, attention_mask):
        # self.attn = attn
        self.key = key
        self.query = query
        # self.value = value
        # # self.attention_mask = attention_mask
        # self.hidden_state = hidden_states.detach()
        # # import pdb; pdb.set_trace()

    def record_attn_mask(self, attn, hidden_states, query, key, value, attention_mask):
        self.attn = attn
        self.attention_mask = attention_mask


def prep_unet_attention(unet,motion_gudiance_blocks):
    # replace the fwd function
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if "VersatileAttention" in module_name and classify_blocks(motion_gudiance_blocks, name):  # the temporal attention in guidance blocks
            module.set_processor(MySelfAttnProcessor())
            # print(module_name)
    return unet


def get_self_attn_feat(unet, injection_config, config):
    hidden_state_dict = dict()
    query_dict = dict()
    key_dict = dict()
    value_dict = dict()

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if "CrossAttention" in module_name and 'attn1' in name and classify_blocks(injection_config.blocks, name=name):
            res = int(math.sqrt(module.processor.hidden_state.shape[1]))
            # import pdb; pdb.set_trace()
            bs = module.processor.hidden_state.shape[0]  # 20 * 16 = 320
            # block_name = identify_blocks(injection_config.blocks, name=name)
            # block_id = int(block_name.split('.')[-1])
            # h = config.H // (32 * block_id)
            # w = config.W // (32 *
            # block_id)
            hidden_state_dict[name] = module.processor.hidden_state.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
            res = int(math.sqrt(module.processor.query.shape[1]))
            query_dict[name] = module.processor.query.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
            key_dict[name] = module.processor.key.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
            value_dict[name] = module.processor.value.cpu().permute(0, 2, 1).reshape(bs, -1, res, res)
    # import pdb; pdb.set_trace()
    # import pdb; pdb.set_trace()
    return hidden_state_dict, query_dict, key_dict, value_dict


def clean_attn_buffer(unet):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "Attention" and 'attn' in name:
            if 'injection_config' in module.processor.__dict__.keys():
                module.processor.injection_config = None
            if 'injection_mask' in module.processor.__dict__.keys():
                module.processor.injection_mask = None
            if 'obj_index' in module.processor.__dict__.keys():
                module.processor.obj_index = None
            if 'pca_weight' in module.processor.__dict__.keys():
                module.processor.pca_weight = None
            if 'pca_weight_changed' in module.processor.__dict__.keys():
                module.processor.pca_weight_changed = None
            if 'pca_info' in module.processor.__dict__.keys():
                module.processor.pca_info = None
            if 'step' in module.processor.__dict__.keys():
                module.processor.step = None


================================================
FILE: t2v_video_sample.py
================================================
import argparse
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from motionclone.models.unet import UNet3DConditionModel
from motionclone.pipelines.pipeline_animation import AnimationPipeline
from motionclone.utils.util import load_weights
from diffusers.utils.import_utils import is_xformers_available
from motionclone.utils.motionclone_functions import *
import json
from motionclone.utils.xformer_attention import *


def
main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))

    config = OmegaConf.load(args.inference_config)

    adopted_dtype = torch.float16
    device = "cuda"
    set_all_seed(42)

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)

    config.width = config.get("W", args.W)
    config.height = config.get("H", args.H)
    config.video_length = config.get("L", args.L)

    if not os.path.exists(args.generated_videos_save_dir):
        os.makedirs(args.generated_videos_save_dir)
    OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json"))

    model_config = OmegaConf.load(config.get("model_config", ""))
    unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype)

    # set xformers
    if is_xformers_available() and (not args.without_xformers):
        unet.enable_xformers_memory_efficient_attention()

    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=None,
        scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)),
    ).to(device)

    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path = config.get("motion_module", ""),
        dreambooth_model_path = config.get("dreambooth_path", ""),
    ).to(device)
    pipeline.text_encoder.to(dtype=adopted_dtype)

    # load customized functions from motionclone_functions
    pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
    pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
    pipeline.unet.forward =
unet_customized_forward.__get__(pipeline.unet)
    pipeline.sample_video = sample_video.__get__(pipeline)
    pipeline.single_step_video = single_step_video.__get__(pipeline)
    pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
    pipeline.add_noise = add_noise.__get__(pipeline)
    pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
    pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)

    for param in pipeline.unet.parameters():
        param.requires_grad = False

    pipeline.input_config, pipeline.unet.input_config = config, config

    pipeline.unet = prep_unet_attention(pipeline.unet,pipeline.input_config.motion_guidance_blocks)
    pipeline.unet = prep_unet_conv(pipeline.unet)
    pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps,config.guidance_scale,device=device,timestep_spacing_type = "uneven")
    # pipeline.scheduler.customized_set_timesteps(config.inference_steps,device=device,timestep_spacing_type = "linspace")

    with open(args.examples, 'r') as files:
        for line in files:
            # prepare the info for each case
            example_infor = json.loads(line)
            config.video_path = example_infor["video_path"]
            config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
            pipeline.input_config, pipeline.unet.input_config = config, config  # update config

            # perform motion representation extraction
            seed_motion = example_infor.get("seed", args.default_seed)
            generator = torch.Generator(device=pipeline.device)
            generator.manual_seed(seed_motion)
            if not os.path.exists(args.motion_representation_save_dir):
                os.makedirs(args.motion_representation_save_dir)
            motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt')
            pipeline.obtain_motion_representation(generator= generator, motion_representation_path = motion_representation_path)

            # perform video generation
            seed = seed_motion  # can assign other seed here
            generator =
torch.Generator(device=pipeline.device)
            generator.manual_seed(seed)
            pipeline.input_config.seed = seed

            videos = pipeline.sample_video(generator = generator,)
            videos = rearrange(videos, "b c f h w -> b f h w c")
            save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" +str(seed)+'.mp4')
            videos_uint8 = (videos[0] * 255).astype(np.uint8)
            imageio.mimwrite(save_path, videos_uint8, fps=8)
            print(save_path,"is done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",)

    parser.add_argument("--inference_config", type=str, default="configs/t2v_camera.yaml")
    parser.add_argument("--examples", type=str, default="configs/t2v_camera.jsonl")
    parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
    parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos")

    parser.add_argument("--visible_gpu", type=str, default=None)
    parser.add_argument("--default-seed", type=int, default=2025)
    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    parser.add_argument("--without-xformers", action="store_true")

    args = parser.parse_args()
    main(args)
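
# Note on the `function.__get__(obj)` calls used in main() above: they bind a
# plain function to one existing object, so the function can be called like a
# method of that object. The sketch below illustrates the mechanism in
# isolation. `Scheduler` and `customized_step` here are hypothetical stand-ins
# chosen for the example; they are not the classes or functions defined in
# motionclone_functions.

```python
# Hypothetical stand-in class; not the scheduler used by this repository.
class Scheduler:
    def __init__(self, num_steps):
        self.num_steps = num_steps


# A plain module-level function whose first argument plays the role of `self`.
def customized_step(self, step):
    return f"step {step}/{self.num_steps}"


scheduler = Scheduler(num_steps=100)

# function.__get__(obj) returns a bound method whose `self` is `obj`.
scheduler.customized_step = customized_step.__get__(scheduler)

# The bound method is stored on this instance only; the class is untouched,
# so other instances do not gain the attribute.
other = Scheduler(num_steps=10)
result = scheduler.customized_step(3)
has_method_on_other = hasattr(other, "customized_step")
```

# Binding per instance keeps the change local to the object being patched,
# which is presumably why the script patches `pipeline`, `pipeline.unet`, and
# `pipeline.scheduler` individually rather than editing their classes.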