Repository: magic-research/magic-animate
Branch: main
Commit: d2bc3bc3c9cc
Files: 31
Total size: 417.0 KB
Directory structure:
gitextract_ig9mhcve/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│ ├── inference/
│ │ └── inference.yaml
│ └── prompts/
│ └── animation.yaml
├── demo/
│ ├── animate.py
│ ├── animate_dist.py
│ ├── gradio_animate.py
│ └── gradio_animate_dist.py
├── environment.yaml
├── magicanimate/
│ ├── models/
│ │ ├── appearance_encoder.py
│ │ ├── attention.py
│ │ ├── controlnet.py
│ │ ├── embeddings.py
│ │ ├── motion_module.py
│ │ ├── mutual_self_attention.py
│ │ ├── orig_attention.py
│ │ ├── resnet.py
│ │ ├── stable_diffusion_controlnet_reference.py
│ │ ├── unet.py
│ │ ├── unet_3d_blocks.py
│ │ └── unet_controlnet.py
│ ├── pipelines/
│ │ ├── animation.py
│ │ ├── context.py
│ │ └── pipeline_animation.py
│ └── utils/
│ ├── dist_tools.py
│ ├── util.py
│ └── videoreader.py
├── requirements.txt
└── scripts/
├── animate.sh
└── animate_dist.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__
.vscode
samples
xformers
src
third_party
backup
pretrained_models
*.nfs*
./*.png
./*.mp4
demo/tmp
demo/outputs
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright 2023 MagicAnimate Team All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
<!-- # magic-edit.github.io -->
<p align="center">
<h2 align="center">MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h2>
<p align="center">
<a href="https://scholar.google.com/citations?user=-4iADzMAAAAJ&hl=en"><strong>Zhongcong Xu</strong></a>
·
<a href="http://jeff95.me/"><strong>Jianfeng Zhang</strong></a>
·
<a href="https://scholar.google.com.sg/citations?user=8gm-CYYAAAAJ&hl=en"><strong>Jun Hao Liew</strong></a>
·
<a href="https://hanshuyan.github.io/"><strong>Hanshu Yan</strong></a>
·
<a href="https://scholar.google.com/citations?user=stQQf7wAAAAJ&hl=en"><strong>Jia-Wei Liu</strong></a>
·
<a href="https://zhangchenxu528.github.io/"><strong>Chenxu Zhang</strong></a>
·
<a href="https://sites.google.com/site/jshfeng/home"><strong>Jiashi Feng</strong></a>
·
<a href="https://sites.google.com/view/showlab"><strong>Mike Zheng Shou</strong></a>
<br>
<br>
<a href="https://arxiv.org/abs/2311.16498"><img src='https://img.shields.io/badge/arXiv-MagicAnimate-red' alt='Paper PDF'></a>
<a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>
<a href='https://huggingface.co/spaces/zcxu-eric/magicanimate'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
<br>
<b>National University of Singapore | ByteDance</b>
</p>
<table align="center">
<tr>
<td>
<img src="assets/teaser/t4.gif">
</td>
<td>
<img src="assets/teaser/t2.gif">
</td>
</tr>
</table>
## 📢 News
* **[2023.12.4]** Release inference code and Gradio demo. We are working to improve MagicAnimate; stay tuned!
* **[2023.11.23]** Release MagicAnimate paper and project page.
## 🏃‍♂️ Getting Started
Download the pretrained base models for [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [MSE-finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse).
Download our MagicAnimate [checkpoints](https://huggingface.co/zcxu-eric/MagicAnimate).
Please follow the Hugging Face download instructions to download the above models and checkpoints; `git lfs` is recommended.
Place the base models and checkpoints as follows:
```bash
magic-animate
|----pretrained_models
  |----MagicAnimate
    |----appearance_encoder
      |----diffusion_pytorch_model.safetensors
      |----config.json
    |----densepose_controlnet
      |----diffusion_pytorch_model.safetensors
      |----config.json
    |----temporal_attention
      |----temporal_attention.ckpt
  |----sd-vae-ft-mse
    |----config.json
    |----diffusion_pytorch_model.safetensors
  |----stable-diffusion-v1-5
    |----scheduler
      |----scheduler_config.json
    |----text_encoder
      |----config.json
      |----pytorch_model.bin
    |----tokenizer (all)
    |----unet
      |----diffusion_pytorch_model.bin
      |----config.json
    |----v1-5-pruned-emaonly.safetensors
    |----...
```
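Before running `scripts/animate.sh`, it can help to confirm that the files from the tree above are in place. The sketch below is a hypothetical helper, not part of the repository. It assumes you run it from the `magic-animate` root, and it deliberately omits the tokenizer files (the tree only lists them as "(all)") and `v1-5-pruned-emaonly.safetensors`, which the demo scripts shown here do not reference.

```python
from pathlib import Path

# Files transcribed from the README tree; paths are relative to pretrained_models/.
# Tokenizer files are omitted because the tree only lists them as "(all)".
EXPECTED_FILES = [
    "MagicAnimate/appearance_encoder/diffusion_pytorch_model.safetensors",
    "MagicAnimate/appearance_encoder/config.json",
    "MagicAnimate/densepose_controlnet/diffusion_pytorch_model.safetensors",
    "MagicAnimate/densepose_controlnet/config.json",
    "MagicAnimate/temporal_attention/temporal_attention.ckpt",
    "sd-vae-ft-mse/config.json",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "stable-diffusion-v1-5/scheduler/scheduler_config.json",
    "stable-diffusion-v1-5/text_encoder/config.json",
    "stable-diffusion-v1-5/text_encoder/pytorch_model.bin",
    "stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin",
    "stable-diffusion-v1-5/unet/config.json",
]


def missing_checkpoints(root="pretrained_models"):
    """Return the expected files (relative paths) that are absent under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing files:")
        for rel in missing:
            print("  " + rel)
    else:
        print("All expected checkpoint files found.")
```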
## ⚒️ Installation
Prerequisites: `python>=3.8`, `CUDA>=11.3`, and `ffmpeg`.
Install with `conda`:
```bash
conda env create -f environment.yaml
conda activate manimate
```
or `pip`:
```bash
pip3 install -r requirements.txt
```
## 💃 Inference
Run inference on a single GPU:
```bash
bash scripts/animate.sh
```
Run inference with multiple GPUs:
```bash
bash scripts/animate_dist.sh
```
## 🎨 Gradio Demo
#### Online Gradio Demo:
Try our [online Gradio demo](https://huggingface.co/spaces/zcxu-eric/magicanimate).
#### Local Gradio Demo:
Launch the local Gradio demo on a single GPU:
```bash
python3 -m demo.gradio_animate
```
Launch the local Gradio demo if you have multiple GPUs:
```bash
python3 -m demo.gradio_animate_dist
```
Then open the Gradio demo in your local browser.
## 🙏 Acknowledgements
We would like to thank [AK(@_akhaliq)](https://twitter.com/_akhaliq?lang=en) and the Hugging Face team for their help in setting up the online Gradio demo.
## 🎓 Citation
If you find this codebase useful for your research, please cite it using the following entry.
```BibTeX
@inproceedings{xu2023magicanimate,
author = {Xu, Zhongcong and Zhang, Jianfeng and Liew, Jun Hao and Yan, Hanshu and Liu, Jia-Wei and Zhang, Chenxu and Feng, Jiashi and Shou, Mike Zheng},
title = {MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model},
booktitle = {arXiv},
year = {2023}
}
```
================================================
FILE: configs/inference/inference.yaml
================================================
unet_additional_kwargs:
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
  - 1
  - 2
  - 4
  - 8
  motion_module_mid_block: false
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
    - Temporal_Self
    - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 24
    temporal_attention_dim_div: 1
noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
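The demo scripts pass `noise_scheduler_kwargs` straight into diffusers' `DDIMScheduler`. For `beta_schedule: "linear"`, diffusers spaces the betas evenly from `beta_start` to `beta_end` over `num_train_timesteps` (default 1000). Note that this differs from `"scaled_linear"`, which interpolates in square-root space. The pure-Python sketch below reproduces that linear schedule and the cumulative product of `1 - beta` under those assumptions; the function names are illustrative only.

```python
def linear_betas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    """Evenly spaced betas from beta_start to beta_end, both endpoints included."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def alphas_cumprod(betas):
    """Running product of (1 - beta_t), often written alpha-bar_t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


betas = linear_betas()
abar = alphas_cumprod(betas)
print(f"beta_0={betas[0]:.5f}, beta_T={betas[-1]:.5f}, alpha_bar_T={abar[-1]:.3e}")
```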
================================================
FILE: configs/prompts/animation.yaml
================================================
pretrained_model_path: "pretrained_models/stable-diffusion-v1-5"
pretrained_vae_path: "pretrained_models/sd-vae-ft-mse"
pretrained_controlnet_path: "pretrained_models/MagicAnimate/densepose_controlnet"
pretrained_appearance_encoder_path: "pretrained_models/MagicAnimate/appearance_encoder"
pretrained_unet_path: ""
motion_module: "pretrained_models/MagicAnimate/temporal_attention/temporal_attention.ckpt"
savename: null
fusion_blocks: "midup"
seed: [1]
steps: 25
guidance_scale: 7.5
source_image:
- "inputs/applications/source_image/monalisa.png"
- "inputs/applications/source_image/demo4.png"
- "inputs/applications/source_image/dalle2.jpeg"
- "inputs/applications/source_image/dalle8.jpeg"
- "inputs/applications/source_image/multi1_source.png"
video_path:
- "inputs/applications/driving/densepose/running.mp4"
- "inputs/applications/driving/densepose/demo4.mp4"
- "inputs/applications/driving/densepose/running2.mp4"
- "inputs/applications/driving/densepose/dancing2.mp4"
- "inputs/applications/driving/densepose/multi_dancing.mp4"
inference_config: "configs/inference/inference.yaml"
size: 512
L: 16
S: 1
I: 0
clip: 0
offset: 0
max_length: null
video_type: "condition"
invert_video: false
save_individual_videos: false
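`source_image` and `video_path` are parallel lists, and the Gradio demos pair them in the same order (monalisa with running, demo4 with demo4, and so on). Assuming they are meant to be consumed pairwise by index, a hypothetical helper like the one below makes a length mismatch fail loudly instead of silently dropping entries, as plain `zip` would.

```python
def pair_inputs(source_images, video_paths):
    """Pair each reference image with the driving video at the same index."""
    if len(source_images) != len(video_paths):
        raise ValueError(
            f"got {len(source_images)} source images but {len(video_paths)} driving videos"
        )
    return list(zip(source_images, video_paths))


# First two entries of `source_image` / `video_path` from animation.yaml.
pairs = pair_inputs(
    ["inputs/applications/source_image/monalisa.png",
     "inputs/applications/source_image/demo4.png"],
    ["inputs/applications/driving/densepose/running.mp4",
     "inputs/applications/driving/densepose/demo4.mp4"],
)
for image, video in pairs:
    print(f"{image} <- {video}")
```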
================================================
FILE: demo/animate.py
================================================
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import datetime
import inspect
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from collections import OrderedDict
import torch
from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from magicanimate.models.unet_controlnet import UNet3DConditionModel
from magicanimate.models.controlnet import ControlNetModel
from magicanimate.models.appearance_encoder import AppearanceEncoderModel
from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
from magicanimate.pipelines.pipeline_animation import AnimationPipeline
from magicanimate.utils.util import save_videos_grid
from accelerate.utils import set_seed
from magicanimate.utils.videoreader import VideoReader
from einops import rearrange, repeat
import csv, pdb, glob
from safetensors import safe_open
import math
from pathlib import Path
class MagicAnimate():
    def __init__(self, config="configs/prompts/animation.yaml") -> None:
        print("Initializing MagicAnimate Pipeline...")
        *_, func_args = inspect.getargvalues(inspect.currentframe())
        func_args = dict(func_args)

        config = OmegaConf.load(config)
        inference_config = OmegaConf.load(config.inference_config)
        motion_module = config.motion_module

        ### >>> create animation pipeline >>> ###
        tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
        if config.pretrained_unet_path:
            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
        else:
            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
        self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda()
        self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
        self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
        if config.pretrained_vae_path is not None:
            vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
        else:
            vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")

        ### Load controlnet
        controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)

        vae.to(torch.float16)
        unet.to(torch.float16)
        text_encoder.to(torch.float16)
        controlnet.to(torch.float16)
        self.appearance_encoder.to(torch.float16)

        unet.enable_xformers_memory_efficient_attention()
        self.appearance_encoder.enable_xformers_memory_efficient_attention()
        controlnet.enable_xformers_memory_efficient_attention()

        self.pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            # NOTE: UniPCMultistepScheduler
        ).to("cuda")

        # 1. unet ckpt
        # 1.1 motion module
        motion_module_state_dict = torch.load(motion_module, map_location="cpu")
        if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
        motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
        try:
            # extra steps for self-trained models
            state_dict = OrderedDict()
            for key in motion_module_state_dict.keys():
                if key.startswith("module."):
                    _key = key.split("module.")[-1]
                    state_dict[_key] = motion_module_state_dict[key]
                else:
                    state_dict[key] = motion_module_state_dict[key]
            motion_module_state_dict = state_dict
            del state_dict
            missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
        except:
            _tmp_ = OrderedDict()
            for key in motion_module_state_dict.keys():
                if "motion_modules" in key:
                    if key.startswith("unet."):
                        _key = key.split('unet.')[-1]
                        _tmp_[_key] = motion_module_state_dict[key]
                    else:
                        _tmp_[key] = motion_module_state_dict[key]
            missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
            assert len(unexpected) == 0
            del _tmp_
        del motion_module_state_dict

        self.pipeline.to("cuda")
        self.L = config.L

        print("Initialization Done!")

    def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
        prompt = n_prompt = ""
        random_seed = int(random_seed)
        step = int(step)
        guidance_scale = float(guidance_scale)
        samples_per_video = []
        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
            set_seed(random_seed)
        else:
            torch.seed()

        if motion_sequence.endswith('.mp4'):
            control = VideoReader(motion_sequence).read()
            if control[0].shape[0] != size:
                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
            control = np.array(control)

        if source_image.shape[0] != size:
            source_image = np.array(Image.fromarray(source_image).resize((size, size)))
        H, W, C = source_image.shape

        init_latents = None
        original_length = control.shape[0]
        if control.shape[0] % self.L > 0:
            control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
        generator = torch.Generator(device=torch.device("cuda:0"))
        generator.manual_seed(torch.initial_seed())
        sample = self.pipeline(
            prompt,
            negative_prompt = n_prompt,
            num_inference_steps = step,
            guidance_scale = guidance_scale,
            width = W,
            height = H,
            video_length = len(control),
            controlnet_condition = control,
            init_latents = init_latents,
            generator = generator,
            appearance_encoder = self.appearance_encoder,
            reference_control_writer = self.reference_control_writer,
            reference_control_reader = self.reference_control_reader,
            source_image = source_image,
        ).videos

        source_images = np.array([source_image] * original_length)
        source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
        samples_per_video.append(source_images)

        control = control / 255.0
        control = rearrange(control, "t h w c -> 1 c t h w")
        control = torch.from_numpy(control)
        samples_per_video.append(control[:, :, :original_length])

        samples_per_video.append(sample[:, :, :original_length])

        samples_per_video = torch.cat(samples_per_video)

        time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        savedir = f"demo/outputs"
        animation_path = f"{savedir}/{time_str}.mp4"

        os.makedirs(savedir, exist_ok=True)
        save_videos_grid(samples_per_video, animation_path)

        return animation_path
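`__call__` pads the DensePose control frames up to a multiple of `self.L` (16 in `animation.yaml`) by repeating the last frame (`np.pad(..., mode='edge')`), presumably so the pipeline can work on whole windows of `L` frames, and afterwards trims every output back to `original_length`. The arithmetic can be checked with a pure-Python sketch, assuming the frames are held in a list rather than a NumPy array; `pad_to_multiple` is a hypothetical name.

```python
def pad_to_multiple(frames, window):
    """Repeat the last frame until len(frames) is a multiple of `window`.

    Equivalent to np.pad(..., mode='edge') along the time axis.
    Returns (padded_frames, original_length).
    """
    original_length = len(frames)
    remainder = original_length % window
    if remainder:
        frames = list(frames) + [frames[-1]] * (window - remainder)
    return frames, original_length


frames = [f"frame{i}" for i in range(37)]
padded, n = pad_to_multiple(frames, 16)
print(len(padded), n)            # 48 37
print(padded[-1] == frames[-1])  # True
trimmed = padded[:n]             # analogous to sample[:, :, :original_length]
```

A 37-frame clip needs 11 copies of its last frame to reach 48 = 3 × 16; clips whose length is already a multiple of `L` are left untouched.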
================================================
FILE: demo/animate_dist.py
================================================
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import datetime
import inspect
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from collections import OrderedDict
import torch
import random
from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from magicanimate.models.unet_controlnet import UNet3DConditionModel
from magicanimate.models.controlnet import ControlNetModel
from magicanimate.models.appearance_encoder import AppearanceEncoderModel
from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
from magicanimate.pipelines.pipeline_animation import AnimationPipeline
from magicanimate.utils.util import save_videos_grid
from magicanimate.utils.dist_tools import distributed_init
from accelerate.utils import set_seed
from magicanimate.utils.videoreader import VideoReader
from einops import rearrange
animator = None
class MagicAnimate():
    def __init__(self, args) -> None:
        config=args.config
        device = torch.device(f"cuda:{args.rank}")
        print("Initializing MagicAnimate Pipeline...")
        *_, func_args = inspect.getargvalues(inspect.currentframe())
        func_args = dict(func_args)

        config = OmegaConf.load(config)
        inference_config = OmegaConf.load(config.inference_config)
        motion_module = config.motion_module

        ### >>> create animation pipeline >>> ###
        tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
        if config.pretrained_unet_path:
            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
        else:
            unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
        self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
        self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
        self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
        if config.pretrained_vae_path is not None:
            vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
        else:
            vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")

        ### Load controlnet
        controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)

        vae.to(torch.float16)
        unet.to(torch.float16)
        text_encoder.to(torch.float16)
        controlnet.to(torch.float16)
        self.appearance_encoder.to(torch.float16)

        unet.enable_xformers_memory_efficient_attention()
        self.appearance_encoder.enable_xformers_memory_efficient_attention()
        controlnet.enable_xformers_memory_efficient_attention()

        self.pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            # NOTE: UniPCMultistepScheduler
        )

        # 1. unet ckpt
        # 1.1 motion module
        motion_module_state_dict = torch.load(motion_module, map_location="cpu")
        if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
        motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
        try:
            # extra steps for self-trained models
            state_dict = OrderedDict()
            for key in motion_module_state_dict.keys():
                if key.startswith("module."):
                    _key = key.split("module.")[-1]
                    state_dict[_key] = motion_module_state_dict[key]
                else:
                    state_dict[key] = motion_module_state_dict[key]
            motion_module_state_dict = state_dict
            del state_dict
            missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
        except:
            _tmp_ = OrderedDict()
            for key in motion_module_state_dict.keys():
                if "motion_modules" in key:
                    if key.startswith("unet."):
                        _key = key.split('unet.')[-1]
                        _tmp_[_key] = motion_module_state_dict[key]
                    else:
                        _tmp_[key] = motion_module_state_dict[key]
            missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
            assert len(unexpected) == 0
            del _tmp_
        del motion_module_state_dict

        self.pipeline.to(device)
        self.L = config.L

        print("Initialization Done!")

        dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist}
        self.predict(args.reference_image, args.motion_sequence, args.random_seed, args.step, args.guidance_scale, args.save_path, dist_kwargs)

    def predict(self, source_image, motion_sequence, random_seed, step, guidance_scale, save_path, dist_kwargs, size=512):
        prompt = n_prompt = ""
        samples_per_video = []
        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
            set_seed(random_seed)
        else:
            torch.seed()

        if motion_sequence.endswith('.mp4'):
            control = VideoReader(motion_sequence).read()
            if control[0].shape[0] != size:
                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
            control = np.array(control)

        if not isinstance(source_image, np.ndarray):
            source_image = np.array(Image.open(source_image))
        if source_image.shape[0] != size:
            source_image = np.array(Image.fromarray(source_image).resize((size, size)))
        H, W, C = source_image.shape

        init_latents = None
        original_length = control.shape[0]
        if control.shape[0] % self.L > 0:
            control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
        generator = torch.Generator(device=torch.device("cuda:0"))
        generator.manual_seed(torch.initial_seed())
        sample = self.pipeline(
            prompt,
            negative_prompt = n_prompt,
            num_inference_steps = step,
            guidance_scale = guidance_scale,
            width = W,
            height = H,
            video_length = len(control),
            controlnet_condition = control,
            init_latents = init_latents,
            generator = generator,
            appearance_encoder = self.appearance_encoder,
            reference_control_writer = self.reference_control_writer,
            reference_control_reader = self.reference_control_reader,
            source_image = source_image,
            **dist_kwargs,
        ).videos

        if dist_kwargs.get('rank', 0) == 0:
            source_images = np.array([source_image] * original_length)
            source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
            samples_per_video.append(source_images)

            control = control / 255.0
            control = rearrange(control, "t h w c -> 1 c t h w")
            control = torch.from_numpy(control)
            samples_per_video.append(control[:, :, :original_length])

            samples_per_video.append(sample[:, :, :original_length])

            samples_per_video = torch.cat(samples_per_video)

            save_videos_grid(samples_per_video, save_path)


def distributed_main(device_id, args):
    args.rank = device_id
    args.device_id = device_id
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
        torch.cuda.init()
    distributed_init(args)
    MagicAnimate(args)


def run(args):
    if args.dist:
        args.world_size = max(1, torch.cuda.device_count())
        assert args.world_size <= torch.cuda.device_count()

        if args.world_size > 0 and torch.cuda.device_count() > 1:
            port = random.randint(10000, 20000)
            args.init_method = f"tcp://localhost:{port}"
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args,),
                nprocs=args.world_size,
            )
    else:
        MagicAnimate(args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/prompts/animation.yaml", required=False)
    # `type=bool` would turn any non-empty string (including "False") into True,
    # so interpret the flag's text explicitly.
    parser.add_argument("--dist", type=lambda s: s.lower() in ("1", "true", "yes"), default=True, required=False)
    parser.add_argument("--rank", type=int, default=0, required=False)
    parser.add_argument("--world_size", type=int, default=1, required=False)
    parser.add_argument("--reference_image", type=str, default=None, required=True)
    parser.add_argument("--motion_sequence", type=str, default=None, required=True)
    parser.add_argument("--random_seed", type=int, default=1, required=False)
    parser.add_argument("--step", type=int, default=25, required=False)
    parser.add_argument("--guidance_scale", type=float, default=7.5, required=False)
    parser.add_argument("--save_path", type=str, default=None, required=True)
    args = parser.parse_args()
    run(args)
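Both demo classes normalize the motion-module checkpoint before `load_state_dict(..., strict=False)`: they unwrap an optional `state_dict` entry, strip a `module.` prefix, and in the fallback path keep only `motion_modules` keys with any `unet.` prefix removed. The hypothetical helper below sketches that cleanup on a plain dict (no PyTorch needed). One design difference is labeled here: it slices off a *leading* prefix instead of using `key.split("module.")[-1]`, which would also cut at a later occurrence of the substring.

```python
from collections import OrderedDict


def normalize_motion_module_keys(checkpoint, prefixes=("module.", "unet."), keep=None):
    """Unwrap and clean checkpoint keys (hypothetical helper, not in the repo).

    - Use checkpoint["state_dict"] when present, else the checkpoint itself.
    - Strip each prefix in `prefixes`, in order, when a key starts with it.
    - If `keep` is given, drop keys that do not contain that substring.
    """
    state_dict = checkpoint.get("state_dict", checkpoint)
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if keep is not None and keep not in key:
            continue
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


ckpt = {"global_step": 10, "state_dict": {
    "module.down_blocks.0.motion_modules.0.w": 1,
    "unet.up_blocks.1.motion_modules.0.b": 2,
    "unet.conv_in.weight": 3,
}}
print(list(normalize_motion_module_keys(ckpt, keep="motion_modules")))
```

The cleaned mapping would then be passed to `unet.load_state_dict(cleaned, strict=False)`, checking that `unexpected` is empty as the demo does.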
================================================
FILE: demo/gradio_animate.py
================================================
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import imageio
import numpy as np
import gradio as gr
from PIL import Image
from demo.animate import MagicAnimate
animator = MagicAnimate()
def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale):
    return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale)


with gr.Blocks() as demo:

    gr.HTML(
"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/magic-research/magic-animate" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
</a>
<div>
<h1 >MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h1>
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2311.16498"><img src="https://img.shields.io/badge/Arxiv-2311.16498-red"></a>
<a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>
<a href='https://github.com/magic-research/magic-animate'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
</div>
</div>
</div>
""")
    animation = gr.Video(format="mp4", label="Animation Results", autoplay=True)

    with gr.Row():
        reference_image = gr.Image(label="Reference Image")
        motion_sequence = gr.Video(format="mp4", label="Motion Sequence")

        with gr.Column():
            random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1")
            sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25")
            guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5")
            submit = gr.Button("Animate")

    def read_video(video):
        reader = imageio.get_reader(video)
        fps = reader.get_meta_data()['fps']
        return video

    def read_image(image, size=512):
        return np.array(Image.fromarray(image).resize((size, size)))

    # when the user uploads a new video
    motion_sequence.upload(
        read_video,
        motion_sequence,
        motion_sequence
    )
    # when the user uploads a new reference image
    reference_image.upload(
        read_image,
        reference_image,
        reference_image
    )
    # when the `submit` button is clicked
    submit.click(
        animate,
        [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],
        animation
    )

    # Examples
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"],
            ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"],
            ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"],
            ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"],
            ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"],
        ],
        inputs=[reference_image, motion_sequence],
        outputs=animation,
    )


demo.launch()
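The Gradio `Textbox` widgets deliver strings, which `MagicAnimate.__call__` casts with `int(...)` and `float(...)`; a seed of `-1` makes it call `torch.seed()`, so that run's seed is never reported back. A hypothetical alternative, sketched below, validates the inputs up front and draws a concrete seed for `-1` so it could be shown to the user and reused. This is an illustrative design choice, not what the demo currently does.

```python
import random


def parse_ui_inputs(seed, steps, guidance_scale):
    """Convert Textbox values into typed sampling parameters (hypothetical helper).

    A seed of -1 is replaced by a freshly drawn 32-bit seed so the run
    can be reproduced later from the returned value.
    """
    seed = int(seed)
    steps = int(steps)
    guidance_scale = float(guidance_scale)
    if steps <= 0:
        raise ValueError("sampling steps must be a positive integer")
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    return seed, steps, guidance_scale


print(parse_ui_inputs("1", "25", "7.5"))
```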
================================================
FILE: demo/gradio_animate_dist.py
================================================
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import argparse
import imageio
import os, datetime
import numpy as np
import gradio as gr
from PIL import Image
from subprocess import PIPE, run
os.makedirs("./demo/tmp", exist_ok=True)
savedir = f"demo/outputs"
os.makedirs(savedir, exist_ok=True)
def animate(reference_image, motion_sequence, seed, steps, guidance_scale):
    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    animation_path = f"{savedir}/{time_str}.mp4"
    save_path = "./demo/tmp/input_reference_image.png"
    Image.fromarray(reference_image).save(save_path)
    command = "python -m demo.animate_dist --reference_image {} --motion_sequence {} --random_seed {} --step {} --guidance_scale {} --save_path {}".format(
        save_path,
        motion_sequence,
        seed,
        steps,
        guidance_scale,
        animation_path
    )
    run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True)
    return animation_path


with gr.Blocks() as demo:

    gr.HTML(
"""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/magic-research/magic-animate" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
</a>
<div>
<h1 >MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h1>
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2311.16498"><img src="https://img.shields.io/badge/Arxiv-2311.16498-red"></a>
<a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>
<a href='https://github.com/magic-research/magic-animate'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
</div>
</div>
</div>
""")
animation = gr.Video(format="mp4", label="Animation Results", autoplay=True)
with gr.Row():
reference_image = gr.Image(label="Reference Image")
motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
with gr.Column():
random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1")
sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25")
guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5")
submit = gr.Button("Animate")
def read_video(video, size=512):
size = int(size)
reader = imageio.get_reader(video)
# fps = reader.get_meta_data()['fps']
frames = []
for img in reader:
frames.append(np.array(Image.fromarray(img).resize((size, size))))
save_path = "./demo/tmp/input_motion_sequence.mp4"
imageio.mimwrite(save_path, frames, fps=25)
return save_path
def read_image(image, size=512):
img = np.array(Image.fromarray(image).resize((size, size)))
return img
# when user uploads a new video
motion_sequence.upload(
read_video,
motion_sequence,
motion_sequence
)
# when `first_frame` is updated
reference_image.upload(
read_image,
reference_image,
reference_image
)
# when the `submit` button is clicked
submit.click(
animate,
[reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],
animation
)
# Examples
gr.Markdown("## Examples")
gr.Examples(
examples=[
["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"],
["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"],
["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"],
["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"],
["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"],
],
inputs=[reference_image, motion_sequence],
outputs=animation,
)
# demo.queue(max_size=10)
demo.launch()
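In `demo/gradio_animate_dist.py` above, `animate` discards the result of `run(...)`. If `demo.animate_dist` fails, the UI therefore receives a path to a video that was never written. The following is a minimal sketch of surfacing such failures, assuming you want an exception rather than a silent empty result. `run_and_check` is a hypothetical helper and is not part of MagicAnimate.

```python
import os
import subprocess
import sys


def run_and_check(command, expected_output):
    """Run `command` (an argv list) and confirm it produced `expected_output`.

    Hypothetical helper: raises RuntimeError carrying the child's stderr when the
    process exits non-zero, or when it exits cleanly but the file is missing.
    """
    result = subprocess.run(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"{command[0]} exited with code {result.returncode}:\n{result.stderr}"
        )
    if not os.path.exists(expected_output):
        raise RuntimeError(f"process succeeded but {expected_output} was not created")
    return expected_output


if __name__ == "__main__":
    # A child process that fails: the raised message includes what it wrote to stderr.
    failing = [sys.executable, "-c", "import sys; sys.stderr.write('boom'); sys.exit(3)"]
    try:
        run_and_check(failing, "missing.mp4")
    except RuntimeError as err:
        print(err)
```

Inside the Gradio callback, the `RuntimeError` could be re-raised as `gr.Error(str(err))` so that the message is shown in the browser instead of an empty video player.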
================================================
FILE: environment.yaml
================================================
name: manimate
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- asttokens=2.2.1=pyhd8ed1ab_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
- ca-certificates=2023.7.22=hbcca054_0
- comm=0.1.4=pyhd8ed1ab_0
- debugpy=1.6.7=py38h6a678d5_0
- decorator=5.1.1=pyhd8ed1ab_0
- entrypoints=0.4=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- ipykernel=6.25.1=pyh71e2992_0
- ipython=8.12.0=pyh41d4057_0
- jedi=0.19.0=pyhd8ed1ab_0
- jupyter_client=7.3.4=pyhd8ed1ab_0
- jupyter_core=4.12.0=py38h578d9bd_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- openssl=1.1.1l=h7f98852_0
- packaging=23.1=pyhd8ed1ab_0
- parso=0.8.3=pyhd8ed1ab_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pip=23.2.1=py38h06a4308_0
- prompt-toolkit=3.0.39=pyha770c72_0
- prompt_toolkit=3.0.39=hd8ed1ab_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.16.1=pyhd8ed1ab_0
- python=3.8.5=h7579374_1
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pyzmq=25.1.0=py38h6a678d5_0
- readline=8.2=h5eee18b_0
- setuptools=68.0.0=py38h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.41.2=h5eee18b_0
- stack_data=0.6.2=pyhd8ed1ab_0
- tk=8.6.12=h1ccaba5_0
- tornado=6.1=py38h0a891b7_3
- traitlets=5.9.0=pyhd8ed1ab_0
- typing_extensions=4.7.1=pyha770c72_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- wheel=0.38.4=py38h06a4308_0
- xz=5.4.2=h5eee18b_0
- zeromq=4.3.4=h9c3ff4c_1
- zlib=1.2.13=h5eee18b_0
- pip:
    - absl-py==1.4.0
    - accelerate==0.22.0
    - aiofiles==23.2.1
    - aiohttp==3.8.5
    - aiosignal==1.3.1
    - altair==5.0.1
    - annotated-types==0.5.0
    - antlr4-python3-runtime==4.9.3
    - anyio==3.7.1
    - async-timeout==4.0.3
    - attrs==23.1.0
    - cachetools==5.3.1
    - certifi==2023.7.22
    - charset-normalizer==3.2.0
    - click==8.1.7
    - cmake==3.27.2
    - contourpy==1.1.0
    - cycler==0.11.0
    - datasets==2.14.4
    - dill==0.3.7
    - einops==0.6.1
    - exceptiongroup==1.1.3
    - fastapi==0.103.0
    - ffmpy==0.3.1
    - filelock==3.12.2
    - fonttools==4.42.1
    - frozenlist==1.4.0
    - fsspec==2023.6.0
    - google-auth==2.22.0
    - google-auth-oauthlib==1.0.0
    - gradio==3.41.2
    - gradio-client==0.5.0
    - grpcio==1.57.0
    - h11==0.14.0
    - httpcore==0.17.3
    - httpx==0.24.1
    - huggingface-hub==0.16.4
    - idna==3.4
    - importlib-metadata==6.8.0
    - importlib-resources==6.0.1
    - jinja2==3.1.2
    - joblib==1.3.2
    - jsonschema==4.19.0
    - jsonschema-specifications==2023.7.1
    - kiwisolver==1.4.5
    - lightning-utilities==0.9.0
    - lit==16.0.6
    - markdown==3.4.4
    - markupsafe==2.1.3
    - matplotlib==3.7.2
    - mpmath==1.3.0
    - multidict==6.0.4
    - multiprocess==0.70.15
    - networkx==3.1
    - numpy==1.24.4
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-cupti-cu11==11.7.101
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - nvidia-cufft-cu11==10.9.0.58
    - nvidia-curand-cu11==10.2.10.91
    - nvidia-cusolver-cu11==11.4.0.1
    - nvidia-cusparse-cu11==11.7.4.91
    - nvidia-nccl-cu11==2.14.3
    - nvidia-nvtx-cu11==11.7.91
    - oauthlib==3.2.2
    - omegaconf==2.3.0
    - opencv-python==4.8.0.76
    - orjson==3.9.5
    - pandas==2.0.3
    - pillow==9.5.0
    - pkgutil-resolve-name==1.3.10
    - protobuf==4.24.2
    - psutil==5.9.5
    - pyarrow==13.0.0
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pydantic==2.3.0
    - pydantic-core==2.6.3
    - pydub==0.25.1
    - pyparsing==3.0.9
    - python-multipart==0.0.6
    - pytorch-lightning==2.0.7
    - pytz==2023.3
    - pyyaml==6.0.1
    - referencing==0.30.2
    - regex==2023.8.8
    - requests==2.31.0
    - requests-oauthlib==1.3.1
    - rpds-py==0.9.2
    - rsa==4.9
    - safetensors==0.3.3
    - semantic-version==2.10.0
    - sniffio==1.3.0
    - starlette==0.27.0
    - sympy==1.12
    - tensorboard==2.14.0
    - tensorboard-data-server==0.7.1
    - tokenizers==0.13.3
    - toolz==0.12.0
    - torchmetrics==1.1.0
    - tqdm==4.66.1
    - transformers==4.32.0
    - triton==2.0.0
    - tzdata==2023.3
    - urllib3==1.26.16
    - uvicorn==0.23.2
    - websockets==11.0.3
    - werkzeug==2.3.7
    - xxhash==3.3.0
    - yarl==1.9.2
    - zipp==3.16.2
    - decord
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.3
    - timm
    - scipy
    - scikit-image
    - av
    - imgaug
    - lpips
    - ffmpeg-python
    - torch==2.0.1
    - torchvision==0.15.2
    - xformers==0.0.22
    - diffusers==0.21.4
prefix: /home/tiger/miniconda3/envs/manimate
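The `pip:` section pins exact versions, such as `torch==2.0.1`, `diffusers==0.21.4` and `xformers==0.0.22`. When reusing an existing environment, it can help to compare those pins against what is actually installed. The following is a minimal sketch that uses only the standard library (`importlib.metadata`, Python 3.8+). The `PINS` subset and the `find_mismatches` helper are illustrative and are not part of the repository. Local version labels such as `+cu117`, which PyTorch wheels often carry, are ignored during comparison.

```python
from importlib import metadata

# A few pins copied from the `pip:` section of environment.yaml (illustrative subset).
PINS = {
    "torch": "2.0.1",
    "torchvision": "0.15.2",
    "xformers": "0.0.22",
    "diffusers": "0.21.4",
    "transformers": "4.32.0",
    "accelerate": "0.22.0",
}


def find_mismatches(pins, get_version=metadata.version):
    """Return {package: (pinned, installed_or_None)} for every pin that differs.

    The local version label (e.g. "+cu117") is stripped before comparing.
    """
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        public = installed.split("+", 1)[0] if installed is not None else None
        if public != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    for name, (pinned, installed) in find_mismatches(PINS).items():
        print(f"{name}: pinned {pinned}, installed {installed or 'not installed'}")
```

Passing `get_version` as a parameter keeps the check testable without depending on the packages present on the machine.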
================================================
FILE: magicanimate/models/appearance_encoder.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.embeddings import (
GaussianFourierProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
get_down_block,
get_up_block,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Identity(torch.nn.Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])
    """

    def __init__(self, scale=None, *args, **kwargs) -> None:
        super(Identity, self).__init__()

    def forward(self, input, *args, **kwargs):
        return input


class _LoRACompatibleLinear(nn.Module):
    """
    A pass-through stand-in for a LoRA-compatible Linear layer: it accepts the same
    `lora_layer` and `scale` arguments but returns `hidden_states` unchanged.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        pass

    def _unfuse_lora(self):
        pass

    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
        return hidden_states


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None
class AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for the middle of the UNet; it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to `None`):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
            otherwise.
    """
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
            )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                attention_type=attention_type,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # Reduce the last transformer block of the final up block to pass-through stubs: its q/k/v and
        # output projections, cross-attention, feed-forward and `proj_out` no longer compute anything.
        # Presumably only the features entering this block's self-attention are needed as appearance
        # references, so the rest of its computation (and its weights) can be dropped.
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
        self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
        self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
        self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
        self.up_blocks[3].attentions[2].proj_out = Identity()

        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                positive_len = cross_attention_dim[0]

            feature_type = "text-only" if attention_type == "gated" else "text-image"
            self.position_net = PositionNet(
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# halving the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The timestep at which the input is being denoised.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
are passed along to the UNet blocks.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default, the spatial dimensions of `sample` have to be a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image":
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb, hint = self.add_embedding(image_embs, hint)
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_block_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
down_block_res_samples += res_samples
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_block_additional_residuals) > 0
and sample.shape == down_block_additional_residuals[0].shape
):
sample += down_block_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)
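The `forward` and `set_attention_slice` methods above normalize a few inputs before any block runs. The following is a minimal pure-Python sketch of three of those steps, assuming plain lists in place of tensors: the `2 ** num_upsamplers` divisibility check that decides whether upsample sizes must be forwarded (with four up blocks, all but the final one upsample, so the factor is 8), the conversion of a keep (1) / discard (0) mask into an additive attention bias, and the `"auto"` / `"max"` / integer slice-size resolution. The function names are hypothetical helpers for illustration only; they are not part of this repository or of diffusers, and the real code additionally unsqueezes the mask bias to shape `(batch, 1, seq_len)`.

```python
from typing import List, Optional, Sequence, Union


def must_forward_upsample_size(height: int, width: int, num_upsamplers: int) -> bool:
    """Mirrors the check in `forward`: True when either spatial size is not a
    multiple of the overall upsampling factor 2 ** num_upsamplers."""
    factor = 2 ** num_upsamplers
    return height % factor != 0 or width % factor != 0


def mask_to_bias(mask: Sequence[int]) -> List[float]:
    """Turns a keep (1) / discard (0) mask into an additive bias:
    kept positions get a bias of zero, discarded ones get -10000.0."""
    return [(1 - m) * -10000.0 for m in mask]


def resolve_slice_sizes(
    slice_size: Union[str, int, List[Optional[int]]],
    sliceable_head_dims: List[int],
) -> List[Optional[int]]:
    """Mirrors the slice-size resolution in `set_attention_slice`."""
    if slice_size == "auto":
        # halve every sliceable head dimension
        sizes = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # smallest possible slices, i.e. maximum memory savings
        sizes = [1] * len(sliceable_head_dims)
    elif isinstance(slice_size, list):
        sizes = list(slice_size)
    else:
        # a single int is broadcast to every sliceable layer
        sizes = [slice_size] * len(sliceable_head_dims)
    if len(sizes) != len(sliceable_head_dims):
        raise ValueError("one slice size is required per sliceable attention layer")
    for size, dim in zip(sizes, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
    return sizes
```

For a 64 x 64 latent and three upsamplers, no size needs forwarding, while a 60 x 64 latent triggers the interpolation fallback.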
================================================
FILE: magicanimate/models/attention.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward, AdaLayerNorm
from diffusers.models.attention import Attention as CrossAttention
from einops import rearrange, repeat
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# Input
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
# JH: no need to repeat when a list of per-frame prompts is given
if encoder_hidden_states is not None and encoder_hidden_states.shape[0] != hidden_states.shape[0]:
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
video_length=video_length
)
# Output
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention = None,
unet_use_temporal_attention = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
# SC-Attn
assert unet_use_cross_frame_attention is not None
if unet_use_cross_frame_attention:
self.attn1 = SparseCausalAttention2D(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
else:
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
self.attn_temp = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
if not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU."
)
else:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
if self.attn2 is not None:
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
# if self.only_cross_attention:
# hidden_states = (
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
# )
# else:
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
if self.unet_use_cross_frame_attention:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
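`Transformer3DModel` folds the frame axis into the batch (`b c f h w -> (b f) c h w`) so that spatial attention runs per frame, repeats each prompt embedding once per frame (`b n c -> (b f) n c`), and `BasicTransformerBlock` regroups tokens with `(b f) d c -> (b d) f c` so that temporal attention runs across frames at each spatial position. The sketch below reproduces only that index bookkeeping in pure Python, assuming nested lists in place of tensors and treating each token's channel vector as an opaque value. The helper names are hypothetical and belong neither to this repository nor to einops; the real code performs these steps with `einops.rearrange` and `einops.repeat`.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def frames_to_temporal(x: Sequence[Sequence[T]], video_length: int) -> List[List[T]]:
    """'(b f) d c -> (b d) f c' on nested lists.

    `x` holds b * f rows (one per frame, batch-major), each with d tokens.
    The result holds b * d rows (one per spatial position), each with the
    f tokens of that position across frames.
    """
    if not x:
        return []
    if len(x) % video_length != 0:
        raise ValueError("number of rows must be a multiple of video_length")
    batch = len(x) // video_length
    num_tokens = len(x[0])
    return [
        [x[b * video_length + f][j] for f in range(video_length)]
        for b in range(batch)
        for j in range(num_tokens)
    ]


def temporal_to_frames(x: Sequence[Sequence[T]], num_tokens: int) -> List[List[T]]:
    """Inverse regrouping, '(b d) f c -> (b f) d c'."""
    if not x:
        return []
    if len(x) % num_tokens != 0:
        raise ValueError("number of rows must be a multiple of num_tokens")
    batch = len(x) // num_tokens
    video_length = len(x[0])
    return [
        [x[b * num_tokens + j][f] for j in range(num_tokens)]
        for b in range(batch)
        for f in range(video_length)
    ]


def repeat_per_frame(prompts: Sequence[T], video_length: int) -> List[T]:
    """'b n c -> (b f) n c': each prompt is repeated consecutively, so row
    b * video_length + t pairs with frame t of sample b."""
    return [p for p in prompts for _ in range(video_length)]
```

Because `(b f)` flattens with the batch index outermost, the per-frame prompt repetition lines up row-for-row with the folded frames, and the temporal regrouping is exactly invertible.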
================================================
FILE: magicanimate/models/controlnet.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from .embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetOutput(BaseOutput):
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
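# Illustrative sketch (hypothetical helpers, not used by the model): they only reproduce the
# PyTorch Conv2d output-size formula, out = (in + 2 * padding - kernel) // stride + 1 (dilation 1),
# to show how ControlNetConditioningEmbedding maps a 512 x 512 conditioning image to a 64 x 64
# feature map. Assumes the default block_out_channels=(16, 32, 96, 256) and 3x3 kernels with padding 1.
def _conv2d_out_size(size: int, kernel_size: int = 3, padding: int = 1, stride: int = 1) -> int:
    # Spatial output size of an nn.Conv2d with dilation=1.
    return (size + 2 * padding - kernel_size) // stride + 1
def _conditioning_embedding_out_size(size: int, block_out_channels=(16, 32, 96, 256)) -> int:
    size = _conv2d_out_size(size)  # conv_in: 3x3, padding 1, stride 1 keeps the size
    for _ in range(len(block_out_channels) - 1):
        size = _conv2d_out_size(size)  # channel-preserving 3x3 conv keeps the size
        size = _conv2d_out_size(size, stride=2)  # stride-2 conv computes ceil(size / 2)
    return _conv2d_out_size(size)  # zero-initialized conv_out keeps the size
# Example: _conditioning_embedding_out_size(512) -> 64; _conditioning_embedding_out_size(100) -> 13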
class ControlNetModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
):
super().__init__()
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=attention_head_dim[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=attention_head_dim[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
):
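        r"""
        Instantiate a `ControlNetModel` from a `UNet2DConditionModel`.
        Parameters:
            unet (`UNet2DConditionModel`):
                UNet model whose weights are copied to the ControlNet. Note that all configuration options are also
                copied where applicable.
            controlnet_conditioning_channel_order (`str`, *optional*, defaults to `"rgb"`):
                Channel order of the conditioning image; `"bgr"` inputs are flipped along the channel axis in `forward`.
            conditioning_embedding_out_channels (`Tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
                Output channels of each block in the conditioning embedding network.
            load_weights_from_unet (`bool`, *optional*, defaults to `True`):
                If `True`, copy the weights of `conv_in`, `time_proj`, `time_embedding`, `class_embedding` (if any),
                `down_blocks`, and `mid_block` from `unet`.
        """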
controlnet = cls(
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
# @property
# # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# def attn_processors(self) -> Dict[str, AttentionProcessor]:
# r"""
# Returns:
# `dict` of attention processors: A dictionary containing all attention processors used in the model with
# indexed by its weight name.
# """
# # set recursively
# processors = {}
# def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# if hasattr(module, "set_processor"):
# processors[f"{name}.processor"] = module.processor
# for sub_name, child in module.named_children():
# fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# return processors
# for name, module in self.named_children():
# fn_recursive_add_processors(name, module, processors)
# return processors
# # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
# r"""
# Parameters:
# `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
# The instantiated processor class or a dictionary of processor classes that will be set as the processor
# of **all** `Attention` layers.
# In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
# """
# count = len(self.attn_processors.keys())
# if isinstance(processor, dict) and len(processor) != count:
# raise ValueError(
# f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
# f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
# )
# def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# if hasattr(module, "set_processor"):
# if not isinstance(processor, dict):
# module.set_processor(processor)
# else:
# module.set_processor(processor.pop(f"{name}.processor"))
# for sub_name, child in module.named_children():
# fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# for name, module in self.named_children():
# fn_recursive_attn_processor(name, module, processor)
# # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# def set_default_attn_processor(self):
# """
# Disables custom attention processors and sets the default attention implementation.
# """
# self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample += controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# cross_attention_kwargs=cross_attention_kwargs,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
        down_block_res_samples = [res * conditioning_scale for res in down_block_res_samples]
mid_block_res_sample *= conditioning_scale
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
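# Illustrative sketch (hypothetical helper, not used by the model): ControlNetModel.forward pairs
# down_block_res_samples with self.controlnet_down_blocks via zip(), which silently truncates on a
# length mismatch. This counts the residuals the down path produces under the same rules __init__
# uses: one for conv_in, `layers_per_block` per down block, plus one per downsampler (every block
# except the last). Assumes each diffusers down block emits one residual per layer and per downsampler.
def _num_controlnet_down_residuals(num_down_blocks: int = 4, layers_per_block: int = 2) -> int:
    count = 1  # output of conv_in (after adding the conditioning embedding)
    for i in range(num_down_blocks):
        count += layers_per_block  # one residual per resnet/attention layer
        if i != num_down_blocks - 1:
            count += 1  # downsampler output of every non-final block
    return count
# Example: four down blocks with layers_per_block=2 give 12 residuals, which should equal
# len(controlnet.controlnet_down_blocks).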
================================================
FILE: magicanimate/models/embeddings.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import numpy as np
import torch
from torch import nn
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
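    """
    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: if True, return the cosine half before the sine half.
    :param downscale_freq_shift: subtracted from `embedding_dim // 2` in the denominator of the frequency exponent.
    :param scale: multiplier applied to the sinusoid arguments.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    """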
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
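# Illustrative sketch (hypothetical helper, not used by the model): a pure-Python reference of
# get_timestep_embedding for a single timestep, useful for checking by hand the frequency schedule
# exp(-ln(max_period) * i / (half_dim - downscale_freq_shift)) and the sin/cos ordering. It computes
# in float64, so results agree with the float32 tensor version only up to rounding.
def _timestep_embedding_reference(
    t: float,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> list:
    import math
    half_dim = embedding_dim // 2
    freqs = [math.exp(-math.log(max_period) * i / (half_dim - downscale_freq_shift)) for i in range(half_dim)]
    args = [scale * t * f for f in freqs]
    sin_part = [math.sin(a) for a in args]
    cos_part = [math.cos(a) for a in args]
    # flip_sin_to_cos puts the cosine half first, as in the tensor version
    emb = cos_part + sin_part if flip_sin_to_cos else sin_part + cos_part
    if embedding_dim % 2 == 1:
        emb.append(0.0)  # zero pad odd dimensions
    return emb
# Example: _timestep_embedding_reference(0.0, 4) -> [0.0, 0.0, 1.0, 1.0]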
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width.
    return: pos_embed of shape [grid_size*grid_size, embed_dim], or [extra_tokens + grid_size*grid_size, embed_dim]
    when `cls_token` is True and `extra_tokens` > 0 (zero rows are prepended for the extra tokens).
    """
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position.
    pos: a list of positions to be encoded, of size (M,).
    out: (M, D), with the sine terms in the first D/2 columns and the cosine terms in the last D/2.
    """
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
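# Illustrative sketch (hypothetical helpers, not used by the model): a pure-Python mirror of
# get_2d_sincos_pos_embed (without extra tokens) showing the row order and column layout.
# np.meshgrid(grid_w, grid_h) uses "xy" indexing by default, so grid[0] holds the column (w)
# coordinate and grid[1] the row (h) coordinate; despite its name, `emb_h` in
# get_2d_sincos_pos_embed_from_grid therefore encodes the w coordinate.
def _sincos_1d_reference(embed_dim: int, positions) -> list:
    import math
    half = embed_dim // 2
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    # Sine terms first, then cosine terms, as in get_1d_sincos_pos_embed_from_grid.
    return [[math.sin(p * w) for w in omega] + [math.cos(p * w) for w in omega] for p in positions]
def _sincos_2d_reference(embed_dim: int, grid_size: int) -> list:
    # Row-major flattening: index k corresponds to row k // grid_size and column k % grid_size.
    cols = [float(c) for r in range(grid_size) for c in range(grid_size)]  # grid[0]
    rows = [float(r) for r in range(grid_size) for c in range(grid_size)]  # grid[1]
    first = _sincos_1d_reference(embed_dim // 2, cols)
    second = _sincos_1d_reference(embed_dim // 2, rows)
    return [a + b for a, b in zip(first, second)]  # (grid_size * grid_size, embed_dim)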
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent + self.pos_embed
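# Illustrative sketch (hypothetical helper, not used by the model): PatchEmbed builds its
# positional table for a square grid of side int(num_patches ** 0.5). For an input of the
# configured height and width, this returns the number of patch tokens the projection produces
# and the number of positional rows. They agree for square patch grids; for a non-square grid
# such as 14 x 20 they differ, and `latent + self.pos_embed` in forward cannot broadcast.
def _patch_embed_sizes(height: int, width: int, patch_size: int) -> tuple:
    num_patches = (height // patch_size) * (width // patch_size)
    side = int(num_patches ** 0.5)
    return num_patches, side * side
# Example: _patch_embed_sizes(224, 224, 16) -> (196, 196); _patch_embed_sizes(224, 320, 16) -> (280, 256)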
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "gelu":
self.act = nn.GELU()
else:
raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
elif post_act_fn == "silu":
self.post_act = nn.SiLU()
elif post_act_fn == "mish":
self.post_act = nn.Mish()
elif post_act_fn == "gelu":
self.post_act = nn.GELU()
else:
raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos
if set_W_to_weight:
# to delete later
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
def forward(self, x):
if self.log:
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
For VQ-diffusion:
Output vector embeddings are used as input for the transformer.
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
Args:
num_embed (`int`):
Number of embeddings for the latent pixels embeddings.
height (`int`):
Height of the latent image i.e. the number of height embeddings.
width (`int`):
Width of the latent image i.e. the number of width embeddings.
embed_dim (`int`):
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
"""
def __init__(
self,
num_embed: int,
height: int,
width: int,
embed_dim: int,
):
super().__init__()
self.height = height
self.width = width
self.num_embed = num_embed
self.embed_dim = embed_dim
self.emb = nn.Embedding(self.num_embed, embed_dim)
self.height_emb = nn.Embedding(self.height, embed_dim)
self.width_emb = nn.Embedding(self.width, embed_dim)
def forward(self, index):
emb = self.emb(index)
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
# 1 x H x D -> 1 x H x 1 x D
height_emb = height_emb.unsqueeze(2)
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
# 1 x W x D -> 1 x 1 x W x D
width_emb = width_emb.unsqueeze(1)
pos_emb = height_emb + width_emb
        # 1 x H x W x D -> 1 x L x D
pos_emb = pos_emb.view(1, self.height * self.width, -1)
emb = emb + pos_emb[:, : emb.shape[1], :]
return emb
class LabelEmbedding(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
Args:
num_classes (`int`): The number of classes.
hidden_size (`int`): The size of the vector embeddings.
dropout_prob (`float`): The probability of dropping a label.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
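            # Build the mask on the labels' device so torch.where does not mix CPU and GPU tensors.
            drop_ids = torch.as_tensor(force_drop_ids == 1, device=labels.device)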
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
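# Illustrative sketch (hypothetical helper, not used by the model): LabelEmbedding implements
# classifier-free-guidance label dropout by replacing dropped labels with the extra index
# `num_classes`, which addresses the additional "null" row the embedding table gets when
# dropout_prob > 0. With dropout_prob == 0 the table has only `num_classes` rows, so forcing drops
# would index out of range. This pure-Python mirror of token_drop takes explicit drop flags.
def _token_drop_reference(labels, num_classes: int, drop_flags) -> list:
    # A flag equal to 1 maps the label to the null class; any other flag keeps the label.
    return [num_classes if flag == 1 else label for label, flag in zip(labels, drop_flags)]
# Example: _token_drop_reference([3, 1, 4], 10, [0, 1, 0]) -> [3, 10, 4]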
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
def forward(self, timestep, class_labels, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
class_labels = self.class_embedder(class_labels) # (N, D)
conditioning = timesteps_emb + class_labels # (N, D)
return conditioning
================================================
FILE: magicanimate/models/motion_module.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Adapted from https://github.com/guoyww/AnimateDiff
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward
from magicanimate.models.orig_attention import CrossAttention
from einops import rearrange, repeat
import math
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def get_motion_module(
in_channels,
motion_module_type: str,
motion_module_kwargs: dict
):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
else:
        raise ValueError(f"Unknown motion_module_type: {motion_module_type!r}; only 'Vanilla' is supported.")
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
# channel count of the input (== in_channels); reused below because proj_out maps back to in_channels
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout = 0.,
max_len = 24
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class VersatileAttention(CrossAttention):
def __init__(
self,
attention_mode = None,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
*args, **kwargs
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
dropout=0.,
max_len=temporal_position_encoding_max_len
) if (temporal_position_encoding and attention_mode == "Temporal") else None
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
if self.attention_mode == "Temporal":
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
else:
raise NotImplementedError
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
# after the temporal rearrange the query length is video_length, not the per-frame token count
hidden_states = self._sliced_attention(query, key, value, query.shape[1], dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
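In `Temporal` mode, `VersatileAttention` leaves the attention math unchanged; it regroups tokens with `rearrange(..., "(b f) d c -> (b d) f c")` so that each spatial location attends only across frames, optionally after adding the sinusoidal table built by `PositionalEncoding`. The NumPy sketch below illustrates that data flow under simplifying assumptions: a single head and identity query/key/value projections (the module itself uses learned `to_q`/`to_k`/`to_v`, multi-head reshaping and `to_out`). The helper names `sinusoidal_pe` and `temporal_self_attention` are hypothetical.

```python
import numpy as np


def sinusoidal_pe(max_len, d_model):
    # Same formula as PositionalEncoding: sin on even channels, cos on odd ones (d_model must be even).
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe


def temporal_self_attention(x, video_length, use_pe=True):
    """x: (b*f, d, c) -> (b*f, d, c); attention runs over the frame axis only.

    Hypothetical single-head sketch with identity Q/K/V projections.
    """
    bf, d, c = x.shape
    b = bf // video_length
    # "(b f) d c -> (b d) f c"
    h = x.reshape(b, video_length, d, c).transpose(0, 2, 1, 3).reshape(b * d, video_length, c)
    if use_pe:
        h = h + sinusoidal_pe(video_length, c)[None]
    scores = h @ h.transpose(0, 2, 1) / np.sqrt(c)  # (b*d, f, f)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ h  # (b*d, f, c)
    # "(b d) f c -> (b f) d c"
    return out.reshape(b, d, video_length, c).transpose(0, 2, 1, 3).reshape(bf, d, c)
```

Because only the frame axis is attended over, perturbing one spatial location of one frame changes outputs only at that location and only within the same video; tokens at other locations, and other videos in the batch, are untouched.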
================================================
FILE: magicanimate/models/mutual_self_attention.py
================================================
# Copyright 2023 ByteDance and/or its affiliates.
#
# Copyright (2023) MagicAnimate Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import torch
import torch.nn.functional as F
from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.attention import BasicTransformerBlock
from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from .stable_diffusion_controlnet_reference import torch_dfs
class AttentionBase:
def __init__(self):
self.cur_step = 0
self.num_att_layers = -1
self.cur_att_layer = 0
def after_step(self):
pass
def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
# after step
self.after_step()
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
out = torch.einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
return out
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
class MutualSelfAttentionControl(AttentionBase):
def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'):
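"""
Mutual self-attention control for the Stable Diffusion model
Args:
total_steps: the total number of denoising steps
appearance_control_alpha: mutual self-attention intensity (stored as self.alpha)
mode: 'enqueue' caches the current keys/values; 'dequeue' pops cached keys/values and appends them to the current ones
"""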
super().__init__()
self.total_steps = total_steps
self.hijack = hijack_init_state
self.with_negative_guidance = with_negative_guidance
# alpha: mutual self attention intensity
# TODO: make alpha learnable
self.alpha = appearance_control_alpha
self.GLOBAL_ATTN_QUEUE = []
# key/value cache consumed by mutual_self_attn_wq
self.kv_queue = []
assert mode in ['enqueue', 'dequeue']
self.MODE = mode
def attn_batch(self, q, k, v, num_heads, **kwargs):
"""
Performing attention for a batch of queries, keys, and values
"""
b = q.shape[0] // num_heads
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
attn = sim.softmax(-1)
out = torch.einsum("h i j, h j d -> h i d", attn, v)
out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
return out
def mutual_self_attn(self, q, k, v, num_heads, **kwargs):
q_tgt, q_src = q.chunk(2)
k_tgt, k_src = k.chunk(2)
v_tgt, v_src = v.chunk(2)
# out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \
# self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha)
out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs)
out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs)
out = torch.cat([out_tgt, out_src], dim=0)
return out
def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
if self.MODE == 'dequeue' and len(self.kv_queue) > 0:
k_src, v_src = self.kv_queue.pop(0)
out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs)
return out
else:
self.kv_queue.append([k.clone(), v.clone()])
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
def get_queue(self):
return self.GLOBAL_ATTN_QUEUE
def set_queue(self, attn_queue):
self.GLOBAL_ATTN_QUEUE = attn_queue
def clear_queue(self):
self.GLOBAL_ATTN_QUEUE = []
def to(self, dtype):
self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE]
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
class ReferenceAttentionControl():
def __init__(self,
unet,
mode="write",
do_classifier_free_guidance=False,
attention_auto_machine_weight = float('inf'),
gn_auto_machine_weight = 1.0,
style_fidelity = 1.0,
reference_attn=True,
reference_adain=False,
fusion_blocks="midup",
batch_size=1,
) -> None:
# 10. Modify self attention and group norm
self.unet = unet
assert mode in ["read", "write"]
assert fusion_blocks in ["midup", "full"]
self.reference_attn = reference_attn
self.reference_adain = reference_adain
self.fusion_blocks = fusion_blocks
self.register_reference_hooks(
mode,
do_classifier_free_guidance,
attention_auto_machine_weight,
gn_auto_machine_weight,
style_fidelity,
reference_attn,
reference_adain,
fusion_blocks=fusion_blocks,
batch_size=batch_size,
)
def register_reference_hooks(
self,
mode,
do_classifier_free_guidance,
attention_auto_machine_weight,
gn_auto_machine_weight,
style_fidelity,
reference_attn,
reference_adain,
dtype=torch.float16,
batch_size=1,
num_images_per_prompt=1,
device=torch.device("cpu"),
fusion_blocks='midup',
):
MODE = mode
do_classifier_free_guidance = do_classifier_free_guidance
attention_auto_machine_weight = attention_auto_machine_weight
gn_auto_machine_weight = gn_auto_machine_weight
style_fidelity = style_fidelity
reference_attn = reference_attn
reference_adain = reference_adain
fusion_blocks = fusion_blocks
num_images_per_prompt = num_images_per_prompt
dtype=dtype
if do_classifier_free_guidance:
uc_mask = (
torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
.to(device)
.bool()
)
else:
uc_mask = (
torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
.to(device)
.bool()
)
def hacked_basic_transformer_inner_forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
video_length=None,
):
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if self.only_cross_attention:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
else:
if MODE == "write":
self.bank.append(norm_hidden_states.clone())
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if MODE == "read":
self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
hidden_states_uc = self.attn1(norm_hidden_states,
encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
attention_mask=attention_mask) + hidden_states
hidden_states_c = hidden_states_uc.clone()
_uc_mask = uc_mask.clone()
if do_classifier_free_guidance:
if hidden_states.shape[0] != _uc_mask.shape[0]:
_uc_mask = (
torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
.to(device)
.bool()
)
hidden_states_c[_uc_mask] = self.attn1(
norm_hidden_states[_uc_mask],
encoder_hidden_states=norm_hidden_states[_uc_mask],
attention_mask=attention_mask,
) + hidden_states[_uc_mask]
hidden_states = hidden_states_c.clone()
self.bank.clear()
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
def hacked_mid_forward(self, *args, **kwargs):
eps = 1e-6
x = self.original_forward(*args, **kwargs)
if MODE == "write":
if gn_auto_machine_weight >= self.gn_weight:
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
self.mean_bank.append(mean)
self.var_bank.append(var)
if MODE == "read":
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
var_acc = sum(self.var_bank) / float(len(self.var_bank))
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
x_uc = (((x - mean) / std) * std_acc) + mean_acc
x_c = x_uc.clone()
if do_classifier_free_guidance and style_fidelity > 0:
x_c[uc_mask] = x[uc_mask]
x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
self.mean_bank = []
self.var_bank = []
return x
def hack_CrossAttnDownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
eps = 1e-6
# TODO(Patrick, William) - attention mask is not used
output_states = ()
for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if MODE == "write":
if gn_auto_machine_weight >= self.gn_weight:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
self.mean_bank.append([mean])
self.var_bank.append([var])
if MODE == "read":
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
hidden_states_c = hidden_states_uc.clone()
if do_classifier_free_guidance and style_fidelity > 0:
hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
output_states = output_states + (hidden_states,)
if MODE == "read":
self.mean_bank = []
self.var_bank = []
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
eps = 1e-6
output_states = ()
for i, resnet in enumerate(self.resnets):
hidden_states = resnet(hidden_states, temb)
if MODE == "write":
if gn_auto_machine_weight >= self.gn_weight:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
self.mean_bank.append([mean])
self.var_bank.append([var])
if MODE == "read":
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
hidden_states_c = hidden_states_uc.clone()
if do_classifier_free_guidance and style_fidelity > 0:
hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
output_states = output_states + (hidden_states,)
if MODE == "read":
self.mean_bank = []
self.var_bank = []
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def hacked_CrossAttnUpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
eps = 1e-6
# TODO(Patrick, William) - attention mask is not used
for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if MODE == "write":
if gn_auto_machine_weight >= self.gn_weight:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
self.mean_bank.append([mean])
self.var_bank.append([var])
if MODE == "read":
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
hidden_states_c = hidden_states_uc.clone()
if do_classifier_free_guidance and style_fidelity > 0:
hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
if MODE == "read":
self.mean_bank = []
self.var_bank = []
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if MODE == "write":
if gn_auto_machine_weight >= self.gn_weight:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
self.mean_bank.append([mean])
self.var_bank.append([var])
if MODE == "read":
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
hidden_states_c = hidden_states_uc.clone()
if do_classifier_free_guidance and style_fidelity > 0:
hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
if MODE == "read":
self.mean_bank = []
self.var_bank = []
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
if self.reference_attn:
if self.fusion_blocks == "midup":
attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
elif self.fusion_blocks == "full":
attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
for i, module in enumerate(attn_modules):
module._original_inner_forward = module.forward
module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
module.bank = []
module.attn_weight = float(i) / float(len(attn_modules))
if self.reference_adain:
gn_modules = [self.unet.mid_block]
self.unet.mid_block.gn_weight = 0
down_blocks = self.unet.down_blocks
for w, module in enumerate(down_blocks):
module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
gn_modules.append(module)
up_blocks = self.unet.up_blocks
for w, module in enumerate(up_blocks):
module.gn_weight = float(w) / float(len(up_blocks))
gn_modules.append(module)
for i, module in enumerate(gn_modules):
if getattr(module, "original_forward", None) is None:
module.original_forward = module.forward
if i == 0:
# mid_block
module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
elif isinstance(module, CrossAttnDownBlock2D):
module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
elif isinstance(module, DownBlock2D):
module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
elif isinstance(module, CrossAttnUpBlock2D):
module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
elif isinstance(module, UpBlock2D):
module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
module.mean_bank = []
module.var_bank = []
module.gn_weight *= 2
def update(self, writer, dtype=torch.float16):
if self.reference_attn:
if self.fusion_blocks == "midup":
reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
elif self.fusion_blocks == "full":
reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)]
writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
for r, w in zip(reader_attn_modules, writer_attn_modules):
r.bank = [v.clone().to(dtype) for v in w.bank]
# w.bank.clear()
if self.reference_adain:
reader_gn_modules = [self.unet.mid_block]
down_blocks = self.unet.down_blocks
for w, module in enumerate(down_blocks):
reader_gn_modules.append(module)
up_blocks = self.unet.up_blocks
for w, module in enumerate(up_blocks):
reader_gn_modules.append(module)
writer_gn_modules = [writer.unet.mid_block]
down_blocks = writer.unet.down_blocks
for w, module in enumerate(down_blocks):
writer_gn_modules.append(module)
up_blocks = writer.unet.up_blocks
for w, module in enumerate(up_blocks):
writer_gn_modules.append(module)
for r, w in zip(reader_gn_modules, writer_gn_modules):
if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
else:
r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
def clear(self):
if self.reference_attn:
if self.fusion_blocks == "midup":
reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
elif self.fusion_blocks == "full":
reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
for r in reader_attn_modules:
r.bank.clear()
if self.reference_adain:
reader_gn_modules = [self.unet.mid_block]
down_blocks = self.unet.down_blocks
for w, module in enumerate(down_blocks):
reader_gn_modules.append(module)
up_blocks = self.unet.up_blocks
for w, module in enumerate(up_blocks):
reader_gn_modules.append(module)
for r in reader_gn_modules:
r.mean_bank.clear()
r.var_bank.clear()
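`ReferenceAttentionControl` banks features from the reference (writer) UNet and reuses them in the animation (reader) UNet. In `read` mode, each banked tensor is repeated once per frame and concatenated with the frame's own normalized tokens along dim 1, so the self-attention keys and values cover both; with `reference_adain`, block outputs are renormalized to the banked per-channel mean and variance. The NumPy sketch below mirrors those two steps, assuming a single head with identity projections and omitting the classifier-free-guidance masking and `style_fidelity` blending; `reference_self_attention` and `adain` are hypothetical names.

```python
import numpy as np


def softmax(s, axis=-1):
    s = s - s.max(axis=axis, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)


def reference_self_attention(x, bank, video_length):
    """x: (b*f, l, c) frame tokens; bank: list of (b, l_ref, c) reference tokens.

    Mirrors the "read" branch: each bank entry is repeated for every frame
    ("b t l c -> (b t) l c") and appended to the keys/values along the token axis.
    Single head with identity projections (hypothetical simplification).
    """
    banked = [np.repeat(r, video_length, axis=0)[: x.shape[0]] for r in bank]
    kv = np.concatenate([x] + banked, axis=1)  # (b*f, l + sum(l_ref), c)
    attn = softmax(x @ kv.transpose(0, 2, 1) / np.sqrt(x.shape[-1]))
    return attn @ kv  # (b*f, l, c)


def adain(x, ref_mean, ref_var, eps=1e-6):
    # Per-channel statistics over the spatial dims (2, 3); ddof=0 matches correction=0.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    std = np.sqrt(np.maximum(var, eps))
    ref_std = np.sqrt(np.maximum(ref_var, eps))
    return (x - mean) / std * ref_std + ref_mean
```

With an empty bank the first function reduces to ordinary self-attention over each frame's own tokens, which is what the writer UNet computes while filling its bank.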
================================================
FILE: magicanimate/models/orig_attention.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s
# Modifications”). All Bytedance Inc.'s Modifications are Copyright (2023)
# Bytedance Inc.
# *************************************************************************
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
            for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
    embeddings) inputs.

    When the input is continuous: first, project the input (i.e., the embedding) and reshape it to (b, t, d). Then
    apply the standard transformer layers. Finally, reshape the result back to an image.

    When the input is discrete: first, the input (classes of latent pixels) is converted to embeddings, and positional
    embeddings are applied (see `ImagePositionalEmbeddings`). Then apply the standard transformer layers. Finally,
    predict the classes of the unnoised image.

    Note that one of the input classes is assumed to be the masked latent pixel. The predicted classes of the unnoised
    image do not include the masked pixel, because the unnoised image cannot be masked.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time, as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time, as it is
            used to learn a number of embeddings that are added to the hidden states. During inference, you can
            denoise for up to, but not more than, `num_embeds_ada_norm` steps.
        attention_bias (`bool`, *optional*):
            Whether the attention in each TransformerBlock should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
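
The `Transformer2DModel` docstring above describes the continuous-input path in three steps: project the input and reshape it to (b, t, d), apply the transformer layers, then reshape back to an image. The following is a minimal pure-Python sketch of that data flow, not the repository's implementation. It assumes the 1x1 projection can be modeled as a per-pixel matrix multiply, omits the GroupNorm that upstream diffusers applies first, treats the transformer blocks as an arbitrary callable, and adds a residual connection that is assumed to mirror upstream diffusers. In upstream diffusers the projected width is `num_attention_heads * attention_head_dim` (16 * 88 = 1408 with the defaults above); here it is simply the row count of the hypothetical `proj_in` matrix. All helper names (`image_to_tokens`, `tokens_to_image`, `project`, `continuous_forward`) are hypothetical.

```python
# Pure-Python sketch of the continuous-input data flow of Transformer2DModel.
# Assumptions: GroupNorm is omitted, the 1x1 projections are per-pixel matrix
# multiplies, and the transformer blocks are an arbitrary callable.
from typing import Callable, List

Image = List[List[List[List[float]]]]  # (batch, channels, height, width)
Tokens = List[List[List[float]]]  # (batch, tokens, dim)
Matrix = List[List[float]]  # (d_out, d_in)


def image_to_tokens(x: Image) -> Tokens:
    """Reshape (b, c, h, w) to (b, h*w, c): one token per latent pixel."""
    out = []
    for sample in x:
        c, h, w = len(sample), len(sample[0]), len(sample[0][0])
        out.append([[sample[ch][i][j] for ch in range(c)]
                    for i in range(h) for j in range(w)])
    return out


def tokens_to_image(t: Tokens, h: int, w: int) -> Image:
    """Inverse of image_to_tokens: (b, h*w, d) back to (b, d, h, w)."""
    out = []
    for sample in t:
        d = len(sample[0])
        out.append([[[sample[i * w + j][ch] for j in range(w)]
                     for i in range(h)] for ch in range(d)])
    return out


def project(t: Tokens, weight: Matrix) -> Tokens:
    """Apply the same linear map to every token (equivalent to a 1x1 conv)."""
    return [[[sum(row[k] * tok[k] for k in range(len(tok))) for row in weight]
             for tok in sample] for sample in t]


def continuous_forward(x: Image, proj_in: Matrix, proj_out: Matrix,
                       blocks: Callable[[Tokens], Tokens]) -> Image:
    h, w = len(x[0][0]), len(x[0][0][0])
    tokens = project(image_to_tokens(x), proj_in)  # (b, t, inner_dim)
    tokens = blocks(tokens)  # the "standard transformer" step
    tokens = project(tokens, proj_out)  # back to in_channels
    y = tokens_to_image(tokens, h, w)
    # Residual connection back to the input image.
    return [[[[y[b][c][i][j] + x[b][c][i][j] for j in range(w)]
              for i in range(h)] for c in range(len(x[b]))]
            for b in range(len(x))]
```

Because the reshape only reorders values, `tokens_to_image(image_to_tokens(x), h, w)` returns `x` unchanged; all learned behavior lives in the projections and the blocks.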
SYMBOL INDEX (264 symbols across 22 files)
FILE: demo/animate.py
class MagicAnimate (line 45) | class MagicAnimate():
method __init__ (line 46) | def __init__(self, config="configs/prompts/animation.yaml") -> None:
method __call__ (line 128) | def __call__(self, source_image, motion_sequence, random_seed, step, g...
FILE: demo/animate_dist.py
class MagicAnimate (line 42) | class MagicAnimate():
method __init__ (line 43) | def __init__(self, args) -> None:
method predict (line 129) | def predict(self, source_image, motion_sequence, random_seed, step, gu...
function distributed_main (line 190) | def distributed_main(device_id, args):
function run (line 200) | def run(args):
FILE: demo/gradio_animate.py
function animate (line 21) | def animate(reference_image, motion_sequence_state, seed, steps, guidanc...
function read_video (line 54) | def read_video(video):
function read_image (line 59) | def read_image(image, size=512):
FILE: demo/gradio_animate_dist.py
function animate (line 23) | def animate(reference_image, motion_sequence, seed, steps, guidance_scale):
function read_video (line 69) | def read_video(video, size=512):
function read_image (line 80) | def read_image(image, size=512):
FILE: magicanimate/models/appearance_encoder.py
class Identity (line 63) | class Identity(torch.nn.Module):
method __init__ (line 83) | def __init__(self, scale=None, *args, **kwargs) -> None:
method forward (line 86) | def forward(self, input, *args, **kwargs):
class _LoRACompatibleLinear (line 91) | class _LoRACompatibleLinear(nn.Module):
method __init__ (line 96) | def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None...
method set_lora_layer (line 100) | def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
method _fuse_lora (line 103) | def _fuse_lora(self):
method _unfuse_lora (line 106) | def _unfuse_lora(self):
method forward (line 109) | def forward(self, hidden_states, scale=None, lora_scale: int = 1):
class UNet2DConditionOutput (line 114) | class UNet2DConditionOutput(BaseOutput):
class AppearanceEncoderModel (line 126) | class AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoa...
method __init__ (line 217) | def __init__(
method attn_processors (line 636) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 659) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
method set_default_attn_processor (line 693) | def set_default_attn_processor(self):
method set_attention_slice (line 708) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 773) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 777) | def forward(
FILE: magicanimate/models/attention.py
class Transformer3DModelOutput (line 37) | class Transformer3DModelOutput(BaseOutput):
class Transformer3DModel (line 48) | class Transformer3DModel(ModelMixin, ConfigMixin):
method __init__ (line 50) | def __init__(
method forward (line 112) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
class BasicTransformerBlock (line 164) | class BasicTransformerBlock(nn.Module):
method __init__ (line 165) | def __init__(
method set_use_memory_efficient_attention_xformers (line 248) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
method forward (line 276) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
FILE: magicanimate/models/controlnet.py
class ControlNetOutput (line 44) | class ControlNetOutput(BaseOutput):
class ControlNetConditioningEmbedding (line 49) | class ControlNetConditioningEmbedding(nn.Module):
method __init__ (line 59) | def __init__(
method forward (line 81) | def forward(self, conditioning):
class ControlNetModel (line 94) | class ControlNetModel(ModelMixin, ConfigMixin):
method __init__ (line 98) | def __init__(
method from_unet (line 267) | def from_unet(
method set_attention_slice (line 384) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 449) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 453) | def forward(
function zero_module (line 575) | def zero_module(module):
FILE: magicanimate/models/embeddings.py
function get_timestep_embedding (line 28) | def get_timestep_embedding(
function get_2d_sincos_pos_embed (line 71) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra...
function get_2d_sincos_pos_embed_from_grid (line 88) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
function get_1d_sincos_pos_embed_from_grid (line 100) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
class PatchEmbed (line 121) | class PatchEmbed(nn.Module):
method __init__ (line 124) | def __init__(
method forward (line 152) | def forward(self, latent):
class TimestepEmbedding (line 161) | class TimestepEmbedding(nn.Module):
method __init__ (line 162) | def __init__(
method forward (line 206) | def forward(self, sample, condition=None):
class Timesteps (line 221) | class Timesteps(nn.Module):
method __init__ (line 222) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
method forward (line 228) | def forward(self, timesteps):
class GaussianFourierProjection (line 238) | class GaussianFourierProjection(nn.Module):
method __init__ (line 241) | def __init__(
method forward (line 255) | def forward(self, x):
class ImagePositionalEmbeddings (line 268) | class ImagePositionalEmbeddings(nn.Module):
method __init__ (line 292) | def __init__(
method forward (line 310) | def forward(self, index):
class LabelEmbedding (line 333) | class LabelEmbedding(nn.Module):
method __init__ (line 343) | def __init__(self, num_classes, hidden_size, dropout_prob):
method token_drop (line 350) | def token_drop(self, labels, force_drop_ids=None):
method forward (line 361) | def forward(self, labels, force_drop_ids=None):
class CombinedTimestepLabelEmbeddings (line 369) | class CombinedTimestepLabelEmbeddings(nn.Module):
method __init__ (line 370) | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
method forward (line 377) | def forward(self, timestep, class_labels, hidden_dtype=None):
FILE: magicanimate/models/motion_module.py
function zero_module (line 23) | def zero_module(module):
class TemporalTransformer3DModelOutput (line 31) | class TemporalTransformer3DModelOutput(BaseOutput):
function get_motion_module (line 42) | def get_motion_module(
class VanillaTemporalModule (line 53) | class VanillaTemporalModule(nn.Module):
method __init__ (line 54) | def __init__(
method forward (line 82) | def forward(self, input_tensor, temb, encoder_hidden_states, attention...
class TemporalTransformer3DModel (line 90) | class TemporalTransformer3DModel(nn.Module):
method __init__ (line 91) | def __init__(
method forward (line 139) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
class TemporalTransformerBlock (line 166) | class TemporalTransformerBlock(nn.Module):
method __init__ (line 167) | def __init__(
method forward (line 215) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
class PositionalEncoding (line 230) | class PositionalEncoding(nn.Module):
method __init__ (line 231) | def __init__(
method forward (line 246) | def forward(self, x):
class VersatileAttention (line 251) | class VersatileAttention(CrossAttention):
method __init__ (line 252) | def __init__(
method extra_repr (line 272) | def extra_repr(self):
method forward (line 275) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
FILE: magicanimate/models/mutual_self_attention.py
class AttentionBase (line 24) | class AttentionBase:
method __init__ (line 25) | def __init__(self):
method after_step (line 30) | def after_step(self):
method __call__ (line 33) | def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_he...
method forward (line 43) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
method reset (line 48) | def reset(self):
class MutualSelfAttentionControl (line 53) | class MutualSelfAttentionControl(AttentionBase):
method __init__ (line 55) | def __init__(self, total_steps=50, hijack_init_state=True, with_negati...
method attn_batch (line 73) | def attn_batch(self, q, k, v, num_heads, **kwargs):
method mutual_self_attn (line 88) | def mutual_self_attn(self, q, k, v, num_heads, **kwargs):
method mutual_self_attn_wq (line 100) | def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_u...
method get_queue (line 109) | def get_queue(self):
method set_queue (line 112) | def set_queue(self, attn_queue):
method clear_queue (line 115) | def clear_queue(self):
method to (line 118) | def to(self, dtype):
method forward (line 121) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class ReferenceAttentionControl (line 128) | class ReferenceAttentionControl():
method __init__ (line 130) | def __init__(self,
method register_reference_hooks (line 161) | def register_reference_hooks(
method update (line 577) | def update(self, writer, dtype=torch.float16):
method clear (line 619) | def clear(self):
FILE: magicanimate/models/orig_attention.py
class Transformer2DModelOutput (line 36) | class Transformer2DModelOutput(BaseOutput):
class Transformer2DModel (line 54) | class Transformer2DModel(ModelMixin, ConfigMixin):
method __init__ (line 93) | def __init__(
method forward (line 184) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
class AttentionBlock (line 253) | class AttentionBlock(nn.Module):
method __init__ (line 271) | def __init__(
method set_use_memory_efficient_attention_xformers (line 296) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
method reshape_heads_to_batch_dim (line 320) | def reshape_heads_to_batch_dim(self, tensor):
method reshape_batch_dim_to_heads (line 327) | def reshape_batch_dim_to_heads(self, tensor):
method forward (line 334) | def forward(self, hidden_states):
class BasicTransformerBlock (line 388) | class BasicTransformerBlock(nn.Module):
method __init__ (line 405) | def __init__(
method set_use_memory_efficient_attention_xformers (line 458) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
method forward (line 485) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
class CrossAttention (line 516) | class CrossAttention(nn.Module):
method __init__ (line 531) | def __init__(
method reshape_heads_to_batch_dim (line 578) | def reshape_heads_to_batch_dim(self, tensor):
method reshape_batch_dim_to_heads (line 585) | def reshape_batch_dim_to_heads(self, tensor):
method set_attention_slice (line 592) | def set_attention_slice(self, slice_size):
method forward (line 598) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
method _attention (line 655) | def _attention(self, query, key, value, attention_mask=None):
method _sliced_attention (line 686) | def _sliced_attention(self, query, key, value, sequence_length, dim, a...
method _memory_efficient_attention_xformers (line 729) | def _memory_efficient_attention_xformers(self, query, key, value, atte...
class FeedForward (line 739) | class FeedForward(nn.Module):
method __init__ (line 751) | def __init__(
method forward (line 778) | def forward(self, hidden_states):
class GELU (line 784) | class GELU(nn.Module):
method __init__ (line 789) | def __init__(self, dim_in: int, dim_out: int):
method gelu (line 793) | def gelu(self, gate):
method forward (line 799) | def forward(self, hidden_states):
class GEGLU (line 806) | class GEGLU(nn.Module):
method __init__ (line 815) | def __init__(self, dim_in: int, dim_out: int):
method gelu (line 819) | def gelu(self, gate):
method forward (line 825) | def forward(self, hidden_states):
class ApproximateGELU (line 830) | class ApproximateGELU(nn.Module):
method __init__ (line 837) | def __init__(self, dim_in: int, dim_out: int):
method forward (line 841) | def forward(self, x):
class AdaLayerNorm (line 846) | class AdaLayerNorm(nn.Module):
method __init__ (line 851) | def __init__(self, embedding_dim, num_embeddings):
method forward (line 858) | def forward(self, x, timestep):
class DualTransformer2DModel (line 865) | class DualTransformer2DModel(nn.Module):
method __init__ (line 892) | def __init__(
method forward (line 941) | def forward(
FILE: magicanimate/models/resnet.py
class InflatedConv3d (line 30) | class InflatedConv3d(nn.Conv2d):
method forward (line 31) | def forward(self, x):
class Upsample3D (line 41) | class Upsample3D(nn.Module):
method __init__ (line 42) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
method forward (line 56) | def forward(self, hidden_states, output_size=None):
class Downsample3D (line 87) | class Downsample3D(nn.Module):
method __init__ (line 88) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
method forward (line 102) | def forward(self, hidden_states):
class ResnetBlock3D (line 113) | class ResnetBlock3D(nn.Module):
method __init__ (line 114) | def __init__(
method forward (line 177) | def forward(self, input_tensor, temb):
class Mish (line 210) | class Mish(torch.nn.Module):
method forward (line 211) | def forward(self, hidden_states):
FILE: magicanimate/models/stable_diffusion_controlnet_reference.py
function torch_dfs (line 65) | def torch_dfs(model: torch.nn.Module):
class StableDiffusionControlNetReferencePipeline (line 72) | class StableDiffusionControlNetReferencePipeline(StableDiffusionControlN...
method prepare_ref_latents (line 73) | def prepare_ref_latents(self, refimage, batch_size, dtype, device, gen...
method __call__ (line 104) | def __call__(
FILE: magicanimate/models/unet.py
class UNet3DConditionOutput (line 53) | class UNet3DConditionOutput(BaseOutput):
class UNet3DConditionModel (line 57) | class UNet3DConditionModel(ModelMixin, ConfigMixin):
method __init__ (line 61) | def __init__(
method set_attention_slice (line 262) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 327) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 331) | def forward(
method from_pretrained_2d (line 470) | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, une...
FILE: magicanimate/models/unet_3d_blocks.py
function get_down_block (line 30) | def get_down_block(
function get_up_block (line 106) | def get_up_block(
class UNetMidBlock3DCrossAttn (line 181) | class UNetMidBlock3DCrossAttn(nn.Module):
method __init__ (line 182) | def __init__(
method forward (line 276) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None...
class CrossAttnDownBlock3D (line 286) | class CrossAttnDownBlock3D(nn.Module):
method __init__ (line 287) | def __init__(
method forward (line 384) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None...
class DownBlock3D (line 426) | class DownBlock3D(nn.Module):
method __init__ (line 427) | def __init__(
method forward (line 491) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
class CrossAttnUpBlock3D (line 522) | class CrossAttnUpBlock3D(nn.Module):
method __init__ (line 523) | def __init__(
method forward (line 616) | def forward(
class UpBlock3D (line 665) | class UpBlock3D(nn.Module):
method __init__ (line 666) | def __init__(
method forward (line 726) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
FILE: magicanimate/models/unet_controlnet.py
class UNet3DConditionOutput (line 50) | class UNet3DConditionOutput(BaseOutput):
class UNet3DConditionModel (line 54) | class UNet3DConditionModel(ModelMixin, ConfigMixin):
method __init__ (line 58) | def __init__(
method set_attention_slice (line 259) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 324) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 328) | def forward(
method from_pretrained_2d (line 486) | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, une...
FILE: magicanimate/pipelines/animation.py
function main (line 46) | def main(args):
function distributed_main (line 246) | def distributed_main(device_id, args):
function run (line 256) | def run(args):
FILE: magicanimate/pipelines/context.py
function ordered_halving (line 12) | def ordered_halving(val):
function uniform (line 20) | def uniform(
function get_context_scheduler (line 45) | def get_context_scheduler(name: str) -> Callable:
function get_total_steps (line 52) | def get_total_steps(
FILE: magicanimate/pipelines/pipeline_animation.py
class AnimationPipelineOutput (line 69) | class AnimationPipelineOutput(BaseOutput):
class AnimationPipeline (line 73) | class AnimationPipeline(DiffusionPipeline):
method __init__ (line 76) | def __init__(
method enable_vae_slicing (line 152) | def enable_vae_slicing(self):
method disable_vae_slicing (line 155) | def disable_vae_slicing(self):
method enable_sequential_cpu_offload (line 158) | def enable_sequential_cpu_offload(self, gpu_id=0):
method _execution_device (line 172) | def _execution_device(self):
method _encode_prompt (line 184) | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_cla...
method decode_latents (line 273) | def decode_latents(self, latents, rank, decoder_consistency=None):
method prepare_extra_step_kwargs (line 291) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 308) | def check_inputs(self, prompt, height, width, callback_steps):
method prepare_latents (line 323) | def prepare_latents(self, batch_size, num_channels_latents, video_leng...
method prepare_condition (line 352) | def prepare_condition(self, condition, num_videos_per_prompt, device, ...
method next_step (line 361) | def next_step(
method images2latents (line 385) | def images2latents(self, images, dtype):
method invert (line 399) | def invert(
method interpolate_latents (line 461) | def interpolate_latents(self, latents: torch.Tensor, interpolation_fac...
method select_controlnet_res_samples (line 496) | def select_controlnet_res_samples(self, controlnet_res_samples_cache_d...
method __call__ (line 525) | def __call__(
FILE: magicanimate/utils/dist_tools.py
function distributed_init (line 18) | def distributed_init(args):
function get_rank (line 62) | def get_rank():
function is_master (line 72) | def is_master():
function synchronize (line 76) | def synchronize():
function suppress_output (line 81) | def suppress_output(is_master):
FILE: magicanimate/utils/util.py
function save_videos_grid (line 21) | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_r...
function save_images_grid (line 35) | def save_images_grid(images: torch.Tensor, path: str):
function init_prompt (line 45) | def init_prompt(prompt, pipeline):
function next_step (line 64) | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timest...
function get_noise_pred_single (line 77) | def get_noise_pred_single(latents, t, context, unet):
function ddim_loop (line 83) | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
function ddim_inversion (line 97) | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps...
function video2images (line 102) | def video2images(path, step=4, length=16, start=0):
function images2video (line 111) | def images2video(video, path, fps=8):
function get_tensor_interpolation_method (line 118) | def get_tensor_interpolation_method():
function set_tensor_interpolation_method (line 121) | def set_tensor_interpolation_method(is_slerp):
function linear (line 125) | def linear(v1, v2, t):
function slerp (line 128) | def slerp(
FILE: magicanimate/utils/videoreader.py
class VideoReader (line 31) | class VideoReader():
method __init__ (line 37) | def __init__(self, video, num_frames=float("inf"), decode_lossy=False,...
method seek (line 61) | def seek(self, pts, backward=True, any_frame=False):
method _occasional_gc (line 65) | def _occasional_gc(self):
method _read_video (line 73) | def _read_video(self, offset):
method _iter_frames (line 90) | def _iter_frames(self):
method _compute_video_stats (line 95) | def _compute_video_stats(self):
method _get_video_frame_rate (line 113) | def _get_video_frame_rate(self):
method sample (line 116) | def sample(self, debug=False):
method read_frames (line 141) | def read_frames(self, frame_indices):
method read (line 150) | def read(self):
method get_num_frames (line 155) | def get_num_frames(self):
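
The `clear` method whose tail opens this excerpt (`ReferenceAttentionControl.clear` in `mutual_self_attention.py`, per the index above) collects transformer blocks with `torch_dfs`, sorts them by `norm1.normalized_shape[0]` in descending order, empties each block's `bank`, and, when `reference_adain` is set, also empties `mean_bank` and `var_bank` on the mid, down and up blocks. The sketch below reproduces that pattern in pure Python so it runs without torch. `ToyModule`, `dfs` and `clear_banks` are hypothetical; `dfs` is assumed to traverse modules in the same pre-order as `torch_dfs`, and the `is_attn_block` flag stands in for the `isinstance` filter.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    """Hypothetical stand-in for a torch.nn.Module that carries feature banks."""

    name: str
    width: int = 0  # plays the role of norm1.normalized_shape[0]
    is_attn_block: bool = False  # plays the role of the isinstance() filter
    bank: list = field(default_factory=list)
    mean_bank: list = field(default_factory=list)
    var_bank: list = field(default_factory=list)
    children: List["ToyModule"] = field(default_factory=list)


def dfs(module: ToyModule) -> List[ToyModule]:
    """Pre-order depth-first traversal: the module itself, then each subtree."""
    result = [module]
    for child in module.children:
        result += dfs(child)
    return result


def clear_banks(root: ToyModule, gn_modules: List[ToyModule],
                reference_adain: bool = False) -> List[str]:
    """Empty every attention bank (and optionally the AdaIN statistics).

    Returns the names of the attention blocks in the order they were cleared.
    """
    blocks = [m for m in dfs(root) if m.is_attn_block]
    # Widest blocks first, mirroring the sort in the excerpt; the ordering
    # does not change the end result of clearing.
    blocks = sorted(blocks, key=lambda m: -m.width)
    for m in blocks:
        m.bank.clear()
    if reference_adain:
        for m in gn_modules:
            m.mean_bank.clear()
            m.var_bank.clear()
    return [m.name for m in blocks]
```

Clearing in place (rather than reassigning new lists) keeps any other references to the same bank objects consistent, which is also what calling `.clear()` does in the excerpt.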
Condensed preview: 31 files, each showing path, character count, and a content snippet (full structured content: 441K chars).
[
{
"path": ".gitignore",
"chars": 122,
"preview": "__pycache__\n.vscode\nsamples\nxformers\nsrc\nthird_party\nbackup\npretrained_models\n*.nfs*\n./*.png\n./*.mp4\ndemo/tmp\ndemo/outpu"
},
{
"path": "LICENSE",
"chars": 1519,
"preview": "BSD 3-Clause License\n\nCopyright 2023 MagicAnimate Team All rights reserved.\n\nRedistribution and use in source and binary"
},
{
"path": "README.md",
"chars": 4475,
"preview": "<!-- # magic-edit.github.io -->\n\n<p align=\"center\">\n\n <h2 align=\"center\">MagicAnimate: Temporally Consistent Human Imag"
},
{
"path": "configs/inference/inference.yaml",
"chars": 631,
"preview": "unet_additional_kwargs:\n unet_use_cross_frame_attention: false\n unet_use_temporal_attention: false\n use_motion_module"
},
{
"path": "configs/prompts/animation.yaml",
"chars": 1279,
"preview": "pretrained_model_path: \"pretrained_models/stable-diffusion-v1-5\"\npretrained_vae_path: \"pretrained_models/sd-vae-ft-mse\"\n"
},
{
"path": "demo/animate.py",
"chars": 9046,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "demo/animate_dist.py",
"chars": 10860,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "demo/gradio_animate.py",
"chars": 4076,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "demo/gradio_animate_dist.py",
"chars": 5011,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "environment.yaml",
"chars": 5226,
"preview": "name: manimate\nchannels:\n - conda-forge\n - defaults\ndependencies:\n - _libgcc_mutex=0.1=main\n - _openmp_mutex=5.1=1_g"
},
{
"path": "magicanimate/models/appearance_encoder.py",
"chars": 54454,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/attention.py",
"chars": 13183,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/controlnet.py",
"chars": 25435,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/embeddings.py",
"chars": 13194,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/motion_module.py",
"chars": 13113,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/mutual_self_attention.py",
"chars": 31295,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "magicanimate/models/orig_attention.py",
"chars": 43314,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/resnet.py",
"chars": 7799,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/stable_diffusion_controlnet_reference.py",
"chars": 42910,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/unet.py",
"chars": 21745,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/unet_3d_blocks.py",
"chars": 28720,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/models/unet_controlnet.py",
"chars": 22634,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/pipelines/animation.py",
"chars": 11780,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "magicanimate/pipelines/context.py",
"chars": 2271,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/pipelines/pipeline_animation.py",
"chars": 36337,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/utils/dist_tools.py",
"chars": 2977,
"preview": "# Copyright 2023 ByteDance and/or its affiliates.\n#\n# Copyright (2023) MagicAnimate Authors\n#\n# ByteDance, its affiliate"
},
{
"path": "magicanimate/utils/util.py",
"chars": 4979,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "magicanimate/utils/videoreader.py",
"chars": 6093,
"preview": "# *************************************************************************\n# This file may have been modified by Byteda"
},
{
"path": "requirements.txt",
"chars": 2305,
"preview": "absl-py==1.4.0\naccelerate==0.22.0\naiofiles==23.2.1\naiohttp==3.8.5\naiosignal==1.3.1\naltair==5.0.1\nannotated-types==0.5.0\n"
},
{
"path": "scripts/animate.sh",
"chars": 84,
"preview": "python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml\n"
},
{
"path": "scripts/animate_dist.sh",
"chars": 91,
"preview": "python3 -m magicanimate.pipelines.animation --config configs/prompts/animation.yaml --dist\n"
}
]
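
The `num_embeds_ada_norm` entry in the `Transformer2DModel` docstring (in `orig_attention.py`, earlier in this excerpt) says the number of diffusion steps is fixed at training time and that inference cannot go beyond `num_embeds_ada_norm` steps. One way to see why, assuming `AdaLayerNorm` follows upstream diffusers (a timestep-indexed embedding table, then SiLU, then a linear layer whose output is split into `scale` and `shift`, applied to a LayerNorm without affine parameters): the timestep is a row index into a table of fixed size. The pure-Python sketch below is hypothetical (`ToyAdaLayerNorm` is not a repository class), and seeded random weights stand in for learned ones.

```python
import math
import random
from typing import List


class ToyAdaLayerNorm:
    """Hypothetical pure-Python stand-in for an AdaLayerNorm-style layer.

    Assumed structure: timestep -> embedding row -> SiLU -> linear to 2*dim
    -> split into (scale, shift) -> LayerNorm(x) * (1 + scale) + shift.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int, seed: int = 0):
        rng = random.Random(seed)
        self.dim = embedding_dim
        # One row per training timestep; this fixed table size is what limits
        # the timestep indices usable at inference.
        self.table = [[rng.uniform(-1, 1) for _ in range(embedding_dim)]
                      for _ in range(num_embeddings)]
        self.linear = [[rng.uniform(-0.1, 0.1) for _ in range(embedding_dim)]
                       for _ in range(2 * embedding_dim)]

    def __call__(self, x: List[float], timestep: int,
                 eps: float = 1e-5) -> List[float]:
        # Reject out-of-range indices explicitly (Python would otherwise
        # silently wrap negative indices).
        if not 0 <= timestep < len(self.table):
            raise IndexError(f"timestep {timestep} outside [0, {len(self.table)})")
        e = [v / (1 + math.exp(-v)) for v in self.table[timestep]]  # SiLU
        emb = [sum(w * v for w, v in zip(row, e)) for row in self.linear]
        scale, shift = emb[: self.dim], emb[self.dim:]
        # LayerNorm without learned affine parameters.
        mean = sum(x) / len(x)
        var = sum((v - mean) ** 2 for v in x) / len(x)
        normed = [(v - mean) / math.sqrt(var + eps) for v in x]
        return [n * (1 + s) + b for n, s, b in zip(normed, scale, shift)]
```

With `num_embeddings=10`, timesteps 0 through 9 work and timestep 10 raises `IndexError`, which is the sense in which the docstring caps inference at `num_embeds_ada_norm` steps.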
About this extraction
This page contains the full source code of the magic-research/magic-animate GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 31 files (417.0 KB), approximately 94.9k tokens, and a symbol index with 264 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract, built by Nikandr Surkov.