Repository: PKU-YuanGroup/Cycle3D
Branch: main
Commit: 5ddc45691c8e
Files: 34
Total size: 638.0 KB
Directory structure:
gitextract_hayw1hga/
├── LICENSE
├── README.md
├── acc_configs/
│ ├── gpu1.yaml
│ ├── gpu4.yaml
│ ├── gpu6.yaml
│ ├── gpu7.yaml
│ ├── gpu8.yaml
│ ├── hostfile
│ ├── multi_node.yaml
│ └── zero2.json
├── core/
│ ├── __init__.py
│ ├── attention.py
│ ├── control.py
│ ├── diffuser_utils.py
│ ├── gs.py
│ ├── masactrl.py
│ ├── masactrl_utils.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── transformer_mv2d.py
│ │ ├── unet_mv2d_blocks.py
│ │ ├── unet_mv2d_condition.py
│ │ ├── unet_mv2d_condition_depth.py
│ │ ├── unet_mv2d_condition_depth_diffusion.py
│ │ └── unet_mv2d_condition_depth_diffusion_test.py
│ ├── models_LGM_compos_diffusion.py
│ ├── models_LGM_compos_diffusion_validate_inversion_2_masa.py
│ ├── options_latents_diffusion.py
│ ├── provider_Gobjaverse_latent_diffusion_insert.py
│ ├── unet_LGM_compos.py
│ └── utils.py
├── infer_ours_masa.py
├── main_resume_compose.py
└── mvdream/
├── mv_unet.py
└── pipeline_mvdream.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 PKU-YUAN-Lab (袁粒课题组-北大信工)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
<h2 align="center"> <a href="https://github.com/PKU-YuanGroup/Cycle3D">Cycle3D: High-quality and Consistent Image-to-3D Generation via
Generation-Reconstruction Cycle</a></h2>
<h5 align="center"> If you like our project, please give us a star ⭐ on GitHub for the latest updates. </h5>
<h5 align="center">
[](https://PKU-YuanGroup.github.io/Cycle3D/)
[](https://arxiv.org/abs/2407.19548)
[](https://github.com/PKU-YuanGroup/repaint123/blob/main/LICENSE)
</h5>
## [Project page](https://PKU-YuanGroup.github.io/Cycle3D/) | [Paper](https://arxiv.org/abs/2407.19548) | [Live Demo (Coming Soon)]()

## 😮 Highlights
### 🔥 Generation-Reconstruction cycle for the unified diffusion process
- The pre-trained 2D diffusion model trained on billions of web images can generate high-quality texture.
- The reconstruction model can ensure consistency across multi-views.
- We cyclically utilize a 2D diffusion-based generation module and a feed-forward 3D reconstruction module during the multi-step diffusion process.
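
As a rough illustration of the cycle described above, here is a toy, pure-Python sketch. It is not the repository's implementation: `toy_denoiser` and `toy_reconstruct` are hypothetical stand-ins for the 2D diffusion model and the feed-forward 3D reconstruction model, each view is reduced to a single number, and the deterministic DDIM-style update is an assumption made for the example.

```python
import math


def predict_x0(x_t, eps, alpha_bar_t):
    """Estimate the clean sample from a noisy sample and a noise prediction."""
    return (x_t - math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_bar_t)


def toy_denoiser(x_t, alpha_bar_t, target):
    """Hypothetical 2D generator: returns the noise that would turn `target` into `x_t`."""
    return (x_t - math.sqrt(alpha_bar_t) * target) / math.sqrt(1.0 - alpha_bar_t)


def toy_reconstruct(x0_views):
    """Hypothetical 3D reconstructor: forces all views to agree by averaging them."""
    mean = sum(x0_views) / len(x0_views)
    return [mean] * len(x0_views)


def generation_reconstruction_cycle(noisy_views, per_view_targets, alpha_bars):
    """Alternate generation and reconstruction at every denoising step.

    `alpha_bars` lists the cumulative alpha of each step, from most to least
    noisy, with every value strictly between 0 and 1.
    """
    views = list(noisy_views)
    for i, a_t in enumerate(alpha_bars):
        a_prev = alpha_bars[i + 1] if i + 1 < len(alpha_bars) else 1.0
        # Generation: each view is denoised independently, so views may disagree.
        eps = [toy_denoiser(x, a_t, c) for x, c in zip(views, per_view_targets)]
        x0 = [predict_x0(x, e, a_t) for x, e in zip(views, eps)]
        # Reconstruction: replace the per-view estimates with consistent ones.
        x0 = toy_reconstruct(x0)
        # Deterministic DDIM-style move to the next, less noisy step.
        views = [math.sqrt(a_prev) * c + math.sqrt(1.0 - a_prev) * e for c, e in zip(x0, eps)]
    return views


views = generation_reconstruction_cycle(
    noisy_views=[0.3, -1.2, 2.0],
    per_view_targets=[0.9, 1.0, 1.1],
    alpha_bars=[0.1, 0.4, 0.7, 0.95],
)
print(views)
```

In this toy setup the per-view generator pulls each view toward a slightly different target, while the reconstruction step overrides those estimates at every step, so all three views finish at the same value (the average of the targets, approximately 1.0).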
## 🚩 **Updates**
Welcome to **watch** 👀 this repository for the latest updates.
✅ **[2024.7.28]** : We have released our paper, Cycle3D, on [arXiv](https://arxiv.org/abs/2407.19548).
✅ **[2024.7.28]** : Release [project page](https://PKU-YuanGroup.github.io/Cycle3D/).
- [ ] Code release.
- [ ] Online Demo.
## 🤗 Demo
Coming soon!
## 🚀 Image-to-3D Results
### Qualitative comparison

### Quantitative comparison

## 👍 **Acknowledgement**
This work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing!
* [LGM](https://github.com/3DTopia/LGM)
* [MasaCtrl](https://github.com/TencentARC/MasaCtrl)
* [Diffusers](https://github.com/huggingface/diffusers)
## ✏️ Citation
If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:.
```BibTeX
@misc{tang2024cycle3dhighqualityconsistentimageto3d,
title={Cycle3D: High-quality and Consistent Image-to-3D Generation via Generation-Reconstruction Cycle},
author={Zhenyu Tang and Junwu Zhang and Xinhua Cheng and Wangbo Yu and Chaoran Feng and Yatian Pang and Bin Lin and Li Yuan},
year={2024},
eprint={2407.19548},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2407.19548},
}
```
================================================
FILE: acc_configs/gpu1.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
# distributed_type: DEEPSPEED
# deepspeed_config:
# gradient_clipping: 1.0
# zero_stage: 2
================================================
FILE: acc_configs/gpu4.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: acc_configs/gpu6.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 6
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2
================================================
FILE: acc_configs/gpu7.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 7
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2
================================================
FILE: acc_configs/gpu8.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2
================================================
FILE: acc_configs/hostfile
================================================
gpu147 slots=8
gpu176 slots=8
gpu47 slots=8
gpu117 slots=8
================================================
FILE: acc_configs/multi_node.yaml
================================================
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2
  deepspeed_hostfile: /remote-home1/yeyang/aigc/aigc/LGM/acc_configs/hostfile
fsdp_config: {}
machine_rank: 0
main_process_ip: 219.223.196.147
main_process_port: 29504
main_training_function: main
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: acc_configs/zero2.json
================================================
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
================================================
FILE: core/__init__.py
================================================
================================================
FILE: core/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import os
import warnings

from torch import Tensor
from torch import nn

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = unbind(qkv, 2)
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_q: int,
        dim_k: int,
        dim_v: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # q: [B, N, Cq]
        # k: [B, M, Ck]
        # v: [B, M, Cv]
        # return: [B, N, C]
        B, N, _ = q.shape
        M = k.shape[1]
        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
        attn = q @ k.transpose(-2, -1) # [B, nh, N, M]
        attn = attn.softmax(dim=-1) # [B, nh, N, M]
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MemEffCrossAttention(CrossAttention):
    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Fall back to the plain implementation; it takes q, k and v (there is no `x` here).
            return super().forward(q, k, v)
        B, N, _ = q.shape
        M = k.shape[1]
        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
================================================
FILE: core/control.py
================================================
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
from typing import Callable, List, Optional, Union, Dict, Any
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from tqdm import tqdm
from transformers import (
CLIPTextModel,
CLIPTokenizer,
DPTFeatureExtractor,
DPTForDepthEstimation,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
from diffusers.pipelines import StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
import kiui
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
class ControlNetPipeline(StableDiffusionControlNetPipeline):
    def pred_x0(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta: float = 0.0,
        verbose=False,
    ):
        """
        Predict the clean sample x0 from the noisy sample `x` and the model's noise prediction.

        Despite the `int` annotation, `timestep` is used to index `alphas_cumprod` as a batch,
        so it must be a 1-D tensor with one timestep per sample.
        """
        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x.device)
        alpha_prod_t = alphas_cumprod[timestep]
        B = alpha_prod_t.shape[0]
        alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1)
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        return pred_x0

    def next_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """
        Inverse sampling step for DDIM inversion.
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
data = None,
LGM_unet = None,
opt = None,
pos_act = None,
scale_act = None,
opacity_act = None,
rot_act = None,
rgb_act = None,
gs = None,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
where a list of image lists can be passed to batch for each prompt and each ControlNet.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that is called every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
The ControlNet encoder tries to recognize the content of the input image even if you remove all
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
self.opt = opt
self.pos_act = pos_act
self.scale_act = scale_act
self.opacity_act = opacity_act
self.rot_act = rot_act
self.rgb_act = rgb_act
self.gs = gs
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
# 1. Check inputs. Raise error if not correct
# self.check_inputs(
# prompt,
# image,
# callback_steps,
# negative_prompt,
# prompt_embeds,
# negative_prompt_embeds,
# ip_adapter_image,
# ip_adapter_image_embeds,
# controlnet_conditioning_scale,
# control_guidance_start,
# control_guidance_end,
# callback_on_step_end_tensor_inputs,
# )
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
if prompt_embeds is not None:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
elif isinstance(controlnet, MultiControlNetModel):
images = []
# Nested lists as ControlNet condition
if isinstance(image[0], list):
# Transpose the nested image list
image = [list(t) for t in zip(*image)]
for image_ in image:
image_ = self.prepare_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
image = images
height, width = image[0].shape[-2:]
else:
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6.5 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)
# 7.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
# predict the noise residual
noise_pred, blocks_sample, tembpred_noise = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
pred_x0 = self.pred_x0(noise_pred, timestep, latent_model_input)
images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5
images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False)
images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].flatten(0, 1).to(self.opt.weight_dtype) ], dim=1)
# perform guidance
# compute the previous noisy sample x_t -> x_t-1
#latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
has_nsfw_concept = None
#image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return image
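# Illustrative sketch (not part of the pipeline): the keep rule from step 7.2,
# isolated as a standalone function. The name `controlnet_keep_schedule` and its
# arguments are hypothetical; only the inequality is taken from the code above.
# A ControlNet is kept at step i when i / N >= start and (i + 1) / N <= end,
# where N is the number of timesteps.
def controlnet_keep_schedule(num_steps, starts, ends):
    """Return, for every step, one keep weight (1.0 or 0.0) per ControlNet."""
    schedule = []
    for i in range(num_steps):
        keeps = [
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(starts, ends)
        ]
        schedule.append(keeps)
    return schedule
# Usage (hypothetical values): controlnet_keep_schedule(10, [0.0], [0.5]) keeps
# the single ControlNet for steps 0-4 and drops it for steps 5-9.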
================================================
FILE: core/diffuser_utils.py
================================================
"""
Utility functions based on the Diffusers framework.
"""
import os
import torch
import cv2
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision.utils import save_image
from torchvision.io import read_image
from diffusers import StableDiffusionPipeline
from pytorch_lightning import seed_everything
class MasaCtrlPipeline(StableDiffusionPipeline):
def next_step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
eta=0.,
verbose=False
):
"""
Inverse sampling for DDIM Inversion
"""
if verbose:
print("timestep: ", timestep)
next_step = timestep
timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
return x_next, pred_x0
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
eta: float=0.0,
verbose=False,
):
"""
Predict the sample at the next step of the denoising process.
"""
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
return x_prev, pred_x0
@torch.no_grad()
def image2latent(self, image):
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# `Image` is the PIL module here, so compare against the `Image.Image` class
if isinstance(image, Image.Image):
image = np.array(image)
image = torch.from_numpy(image).float() / 127.5 - 1
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
# input image intensity range is [-1, 1]
latents = self.vae.encode(image)['latent_dist'].mean
latents = latents * 0.18215
return latents
@torch.no_grad()
def latent2image(self, latents, return_type='np'):
latents = 1 / 0.18215 * latents.detach()
image = self.vae.decode(latents)['sample']
if return_type == 'np':
image = (image / 2 + 0.5).clamp(0, 1)
image = image.to(torch.float).cpu().permute(0, 2, 3, 1).numpy()[0]
image = (image * 255).astype(np.uint8)
elif return_type == "pt":
image = (image / 2 + 0.5).clamp(0, 1)
return image
def latent2image_grad(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)['sample']
return image # range [-1, 1]
@torch.no_grad()
def __call__(
self,
prompt,
batch_size=1,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
eta=0.0,
latents=None,
unconditioning=None,
neg_prompt=None,
ref_intermediate_latents=None,
return_intermediates=False,
**kwds):
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if isinstance(prompt, list):
batch_size = len(prompt)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size
# text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=77,
return_tensors="pt"
)
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
print("input text embeddings :", text_embeddings.shape)
if kwds.get("dir"):
dir = text_embeddings[-2] - text_embeddings[-1]
u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
print(u.shape)
print(v.shape)
# define initial latents
latents_shape = (batch_size, self.unet.config.in_channels, height//8, width//8)
if latents is None:
latents = torch.randn(latents_shape, device=DEVICE)
else:
assert latents.shape == latents_shape, f"The input latent tensor has shape {latents.shape}, but expected {latents_shape}."
# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
max_length = text_input.input_ids.shape[-1]
if neg_prompt:
uc_text = neg_prompt
else:
uc_text = ""
# uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
unconditional_input = self.tokenizer(
[uc_text] * batch_size,
padding="max_length",
max_length=77,
return_tensors="pt"
)
# unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
print("latents shape: ", latents.shape)
# iterative sampling
self.scheduler.set_timesteps(num_inference_steps)
# print("Valid timesteps: ", reversed(self.scheduler.timesteps))
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
if ref_intermediate_latents is not None:
# note that the batch_size >= 2
latents_ref = ref_intermediate_latents[-1 - i]
_, latents_cur = latents.chunk(2)
latents = torch.cat([latents_ref, latents_cur])
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
if unconditioning is not None and isinstance(unconditioning, list):
_, text_embeddings = text_embeddings.chunk(2)
text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
# predict the noise
noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
# compute the previous noisy sample x_t -> x_t-1
latents, pred_x0 = self.step(noise_pred, t, latents)
latents_list.append(latents)
pred_x0_list.append(pred_x0)
image = self.latent2image(latents, return_type="pt")
if return_intermediates:
pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
return image, pred_x0_list, latents_list
return image
@torch.no_grad()
def invert(
self,
image: torch.Tensor,
prompt,
num_inference_steps=50,
guidance_scale=7.5,
eta=0.0,
return_intermediates=False,
path = None,
**kwds):
"""
Invert a real image into a noise map with deterministic DDIM inversion.
"""
DEVICE = image.device
batch_size = image.shape[0]
if isinstance(prompt, list):
if batch_size == 1:
image = image.expand(len(prompt), -1, -1, -1)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size
# text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=77,
return_tensors="pt"
)
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
print("input text embeddings :", text_embeddings.shape)
# define initial latents
latents = self.image2latent(image)
start_latents = latents
# print(latents)
# exit()
# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
max_length = text_input.input_ids.shape[-1]
unconditional_input = self.tokenizer(
[""] * batch_size,
padding="max_length",
max_length=77,
return_tensors="pt"
)
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
print("latents shape: ", latents.shape)
# iterative sampling
self.scheduler.set_timesteps(num_inference_steps)
print("Valid timesteps: ", reversed(self.scheduler.timesteps))
# print("attributes: ", self.scheduler.__dict__)
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
# predict the noise
noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
# compute the next noisy sample x_t-1 -> x_t
latents, pred_x0 = self.next_step(noise_pred, t, latents)
#Image.fromarray(self.latent2image(latents[:1])).save(os.path.join(path, str(i)+'_8.png'))
# if kwds.get("workspace"):
# Image.fromarray(self.latent2image(pred_x0[:1])).save(kwds.get("workspace")+'/'+str(i)+'_8.png')
latents_list.append(latents)
pred_x0_list.append(pred_x0)
if return_intermediates:
# return the intermediate latents during inversion
# pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
return latents, latents_list
return latents, start_latents
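# Illustrative sketch (not part of MasaCtrlPipeline): `next_step` and `step`
# above apply the same deterministic DDIM update (eta = 0), only in opposite
# directions along the noise schedule. The helper names below are hypothetical
# and use plain floats instead of tensors. The round trip is exact only when the
# same noise prediction is used in both directions; in the pipeline the U-Net is
# re-evaluated at each step, so the inversion is approximate.
import math

def ddim_transfer(x, eps, alpha_from, alpha_to):
    """Move a sample from cumulative alpha `alpha_from` to `alpha_to`."""
    pred_x0 = (x - math.sqrt(1.0 - alpha_from) * eps) / math.sqrt(alpha_from)
    x_to = math.sqrt(alpha_to) * pred_x0 + math.sqrt(1.0 - alpha_to) * eps
    return x_to, pred_x0

def demo_ddim_round_trip():
    """Invert one step, then denoise it again with the same noise estimate."""
    x_t, eps = 0.3, -1.2                 # hypothetical sample and noise values
    alpha_t, alpha_next = 0.9, 0.6       # hypothetical cumulative alphas
    x_next, x0_inv = ddim_transfer(x_t, eps, alpha_t, alpha_next)      # like next_step
    x_back, x0_den = ddim_transfer(x_next, eps, alpha_next, alpha_t)   # like step
    return x_t, x_back, x0_inv, x0_den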
================================================
FILE: core/gs.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from core.options import Options
import kiui
class GaussianRenderer:
def __init__(self, opt: Options):
self.opt = opt
self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
# intrinsics
self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
self.proj_matrix[0, 0] = 1 / self.tan_half_fov
self.proj_matrix[1, 1] = 1 / self.tan_half_fov
self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
self.proj_matrix[2, 3] = 1
def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):
# gaussians: [B, N, 14]
# cam_view, cam_view_proj: [B, V, 4, 4]
# cam_pos: [B, V, 3]
device = gaussians.device
B, V = cam_view.shape[:2]
# loop of loop...
images = []
alphas = []
for b in range(B):
# pos, opacity, scale, rotation, shs
means3D = gaussians[b, :, 0:3].contiguous().float()
opacity = gaussians[b, :, 3:4].contiguous().float()
scales = gaussians[b, :, 4:7].contiguous().float()
rotations = gaussians[b, :, 7:11].contiguous().float()
rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]
for v in range(V):
# render novel views
view_matrix = cam_view[b, v].float()
view_proj_matrix = cam_view_proj[b, v].float()
campos = cam_pos[b, v].float()
raster_settings = GaussianRasterizationSettings(
image_height=self.opt.output_size,
image_width=self.opt.output_size,
tanfovx=self.tan_half_fov,
tanfovy=self.tan_half_fov,
bg=self.bg_color if bg_color is None else bg_color,
scale_modifier=scale_modifier,
viewmatrix=view_matrix,
projmatrix=view_proj_matrix,
sh_degree=0,
campos=campos,
prefiltered=False,
debug=False,
)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
means3D=means3D,
means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device),
shs=None,
colors_precomp=rgbs,
opacities=opacity,
scales=scales,
rotations=rotations,
cov3D_precomp=None,
)
rendered_image = rendered_image.clamp(0, 1)
images.append(rendered_image)
alphas.append(rendered_alpha)
images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size)
alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
return {
"image": images, # [B, V, 3, H, W]
"alpha": alphas, # [B, V, 1, H, W]
}
def save_ply(self, gaussians, path, compatible=True):
# gaussians: [B, N, 14]
# compatible: save pre-activated gaussians as in the original paper
assert gaussians.shape[0] == 1, 'only support batch size 1'
from plyfile import PlyData, PlyElement
means3D = gaussians[0, :, 0:3].contiguous().float()
opacity = gaussians[0, :, 3:4].contiguous().float()
scales = gaussians[0, :, 4:7].contiguous().float()
rotations = gaussians[0, :, 7:11].contiguous().float()
shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
# prune by opacity
mask = opacity.squeeze(-1) >= 0.005
means3D = means3D[mask]
opacity = opacity[mask]
scales = scales[mask]
rotations = rotations[mask]
shs = shs[mask]
# invert activation to make it compatible with the original ply format
if compatible:
opacity = kiui.op.inverse_sigmoid(opacity)
scales = torch.log(scales + 1e-8)
shs = (shs - 0.5) / 0.28209479177387814
xyzs = means3D.detach().cpu().numpy()
f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
opacities = opacity.detach().cpu().numpy()
scales = scales.detach().cpu().numpy()
rotations = rotations.detach().cpu().numpy()
l = ['x', 'y', 'z']
# DC color channels (f_dc_0 .. f_dc_2)
for i in range(f_dc.shape[1]):
l.append('f_dc_{}'.format(i))
l.append('opacity')
for i in range(scales.shape[1]):
l.append('scale_{}'.format(i))
for i in range(rotations.shape[1]):
l.append('rot_{}'.format(i))
dtype_full = [(attribute, 'f4') for attribute in l]
elements = np.empty(xyzs.shape[0], dtype=dtype_full)
attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
elements[:] = list(map(tuple, attributes))
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write(path)
def load_ply(self, path, compatible=True):
from plyfile import PlyData, PlyElement
plydata = PlyData.read(path)
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
np.asarray(plydata.elements[0]["y"]),
np.asarray(plydata.elements[0]["z"])), axis=1)
print("Number of points at loading : ", xyz.shape[0])
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
shs = np.zeros((xyz.shape[0], 3))
shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
scales = np.zeros((xyz.shape[0], len(scale_names)))
for idx, attr_name in enumerate(scale_names):
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
rots = np.zeros((xyz.shape[0], len(rot_names)))
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
gaussians = torch.from_numpy(gaussians).float() # cpu
if compatible:
gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
return gaussians
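# Illustrative sketch (not part of GaussianRenderer): the `compatible` branches
# of save_ply and load_ply above are inverses of each other. The helpers below
# are hypothetical and work on single floats. The inverse sigmoid is written as
# the standard logit, which is what kiui.op.inverse_sigmoid is assumed to compute.
import math

SH_C0 = 0.28209479177387814  # zeroth-order SH constant used above

def to_ply_space(opacity, scale, rgb):
    """Activated values -> pre-activation values stored in the PLY file."""
    logit = math.log(opacity / (1.0 - opacity))
    log_scale = math.log(scale + 1e-8)
    f_dc = (rgb - 0.5) / SH_C0
    return logit, log_scale, f_dc

def from_ply_space(logit, log_scale, f_dc):
    """Pre-activation values read from the PLY file -> activated values."""
    opacity = 1.0 / (1.0 + math.exp(-logit))
    scale = math.exp(log_scale)
    rgb = SH_C0 * f_dc + 0.5
    return opacity, scale, rgb
# Usage: from_ply_space(*to_ply_space(0.8, 0.05, 0.25)) recovers the inputs up to
# floating-point error and the 1e-8 offset added before the logarithm.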
================================================
FILE: core/masactrl.py
================================================
import os
import torch
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from core.masactrl_utils import AttentionBase
from torchvision.utils import save_image
class MutualSelfAttentionControl(AttentionBase):
MODEL_TYPE = {
"SD": 16,
"SDXL": 70
}
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
"""
Mutual self-attention control for Stable-Diffusion model
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list of the steps to apply mutual self-attention control
total_steps: the total number of steps
model_type: the model type, SD or SDXL
"""
super().__init__()
self.total_steps = total_steps
self.total_layers = self.MODEL_TYPE.get(model_type, 16)
self.start_step = start_step
self.start_layer = start_layer
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
print("MasaCtrl at denoising steps: ", self.step_idx)
print("MasaCtrl at U-Net layers: ", self.layer_idx)
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Performing attention for a batch of queries, keys, and values
"""
b = q.shape[0] // num_heads
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
attn = sim.softmax(-1)
out = torch.einsum("h i j, h j d -> h i d", attn, v)
out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
qu, qc = q.chunk(2)
ku, kc = k.chunk(2)
vu, vc = v.chunk(2)
attnu, attnc = attn.chunk(2)
out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
out = torch.cat([out_u, out_c], dim=0)
return out
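# Illustrative sketch (not part of MutualSelfAttentionControl): in `forward`
# above, every sample queries ku[:num_heads] / vu[:num_heads], i.e. the keys and
# values of the first (source) sample. The single-head, pure-Python helpers
# below are hypothetical and only show that idea: target queries attend to
# source keys and values.
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values, scale):
    """Scaled dot-product attention on lists of vectors."""
    out = []
    for q in queries:
        weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        out.append([
            sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))
        ])
    return out

def mutual_self_attention(q_target, k_source, v_source, scale):
    """Target queries read content from the source image's keys and values."""
    return attention(q_target, k_source, v_source, scale)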
class MutualSelfAttention3DControl(AttentionBase):
MODEL_TYPE = {
"SD": 16,
"SDXL": 70
}
def __init__(self, start_steps=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
"""
Mutual self-attention control for Stable-Diffusion model
Args:
start_steps: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list of the steps to apply mutual self-attention control
total_steps: the total number of steps
model_type: the model type, SD or SDXL
"""
super().__init__()
self.total_steps = total_steps
self.total_layers = self.MODEL_TYPE.get(model_type, 16)
self.start_step = start_steps
self.start_layer = start_layer
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
self.step_idx = step_idx if step_idx is not None else list(range(start_steps, total_steps))
print("MasaCtrl at denoising steps: ", self.step_idx)
print("MasaCtrl at U-Net layers: ", self.layer_idx)
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Performing attention for a batch of queries, keys, and values
"""
b = q.shape[0] // num_heads
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
attn = sim.softmax(-1)
out = torch.einsum("h i j, h j d -> h i d", attn, v)
out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
# qu, qc = q.chunk(2)
# ku, kc = k.chunk(2)
# vu, vc = v.chunk(2)
# attnu, attnc = attn.chunk(2)
q_t1, q_t2, q_t3, q_t4, q_s= q.chunk(5)
k_t1, k_t2, k_t3, k_t4, k_s = k.chunk(5)
v_t1, v_t2, v_t3, v_t4, v_s= v.chunk(5)
attn_t1, attn_t2, attn_t3, attn_t4, attn_s= attn.chunk(5)
out_s = super().forward(q_s, k_s, v_s, sim, attn_s, is_cross, place_in_unet, num_heads, **kwargs)
out_t1 = self.attn_batch(q_t1, torch.cat([k_s, k_t1]), torch.cat([v_s, v_t1]), sim[:num_heads], attn_t1, is_cross, place_in_unet, num_heads, **kwargs)
out_t2 = self.attn_batch(q_t2, torch.cat([k_s, k_t2]), torch.cat([v_s, v_t2]), sim[:num_heads], attn_t2, is_cross, place_in_unet, num_heads, **kwargs)
out_t3 = self.attn_batch(q_t3, torch.cat([k_s, k_t3]), torch.cat([v_s, v_t3]), sim[:num_heads], attn_t3, is_cross, place_in_unet, num_heads, **kwargs)
out_t4 = self.attn_batch(q_t4, torch.cat([k_s, k_t4]), torch.cat([v_s, v_t4]), sim[:num_heads], attn_t4, is_cross, place_in_unet, num_heads, **kwargs)
# out_t1 = self.attn_batch(q_t1, k_s, v_s, sim[:num_heads], attn_t1, is_cross, place_in_unet, num_heads, **kwargs)
# out_t2 = self.attn_batch(q_t2, k_s, v_s, sim[:num_heads], attn_t2, is_cross, place_in_unet, num_heads, **kwargs)
# out_t3 = self.attn_batch(q_t3, k_s, v_s, sim[:num_heads], attn_t3, is_cross, place_in_unet, num_heads, **kwargs)
# out_t4 = self.attn_batch(q_t4, k_s, v_s, sim[:num_heads], attn_t4, is_cross, place_in_unet, num_heads, **kwargs)
# out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
# out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
out = torch.cat([out_t1, out_t2, out_t3, out_t4, out_s], dim=0)
return out
class MutualSelfAttentionControlUnion(MutualSelfAttentionControl):
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
"""
Mutual self-attention control for Stable-Diffusion model with the union of source and target [K, V]
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list of the steps to apply mutual self-attention control
total_steps: the total number of steps
model_type: the model type, SD or SDXL
"""
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
qu_s, qu_t, qc_s, qc_t = q.chunk(4)
ku_s, ku_t, kc_s, kc_t = k.chunk(4)
vu_s, vu_t, vc_s, vc_t = v.chunk(4)
attnu_s, attnu_t, attnc_s, attnc_t = attn.chunk(4)
# source image branch
out_u_s = super().forward(qu_s, ku_s, vu_s, sim, attnu_s, is_cross, place_in_unet, num_heads, **kwargs)
out_c_s = super().forward(qc_s, kc_s, vc_s, sim, attnc_s, is_cross, place_in_unet, num_heads, **kwargs)
# target image branch, concatenating source and target [K, V]
out_u_t = self.attn_batch(qu_t, torch.cat([ku_s, ku_t]), torch.cat([vu_s, vu_t]), sim[:num_heads], attnu_t, is_cross, place_in_unet, num_heads, **kwargs)
out_c_t = self.attn_batch(qc_t, torch.cat([kc_s, kc_t]), torch.cat([vc_s, vc_t]), sim[:num_heads], attnc_t, is_cross, place_in_unet, num_heads, **kwargs)
out = torch.cat([out_u_s, out_u_t, out_c_s, out_c_t], dim=0)
return out
class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None, model_type="SD"):
"""
Mask-guided MasaCtrl to alleviate confusion between foreground and background
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list of the steps to apply mutual self-attention control
total_steps: the total number of steps
mask_s: source mask with shape (h, w)
mask_t: target mask with same shape as source mask
mask_save_dir: the path to save the mask image
model_type: the model type, SD or SDXL
"""
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
self.mask_s = mask_s # source mask with shape (h, w)
self.mask_t = mask_t # target mask with same shape as source mask
print("Using mask-guided MasaCtrl")
if mask_save_dir is not None:
os.makedirs(mask_save_dir, exist_ok=True)
save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_s.png"))
save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_t.png"))
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
B = q.shape[0] // num_heads
H = W = int(np.sqrt(q.shape[1]))
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
if kwargs.get("is_mask_attn") and self.mask_s is not None:
print("masked attention")
mask = self.mask_s.unsqueeze(0).unsqueeze(0)
mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0)
mask = mask.flatten()
# background
sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
# object
sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
sim = torch.cat([sim_fg, sim_bg], dim=0)
attn = sim.softmax(-1)
if len(attn) == 2 * len(v):
v = torch.cat([v] * 2)
out = torch.einsum("h i j, h j d -> h i d", attn, v)
out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
B = q.shape[0] // num_heads // 2
H = W = int(np.sqrt(q.shape[1]))
qu, qc = q.chunk(2)
ku, kc = k.chunk(2)
vu, vc = v.chunk(2)
attnu, attnc = attn.chunk(2)
out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
if self.mask_s is not None and self.mask_t is not None:
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0)
out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0)
mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W))
mask = mask.reshape(-1, 1) # (hw, 1)
out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - mask)
out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask)
out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
return out
class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None, model_type="SD"):
"""
MasaCtrl with mask auto generation from cross-attention map
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list of the steps to apply mutual self-attention control
total_steps: the total number of steps
thres: the threshold used to binarize the masks
ref_token_idx: the token index list used to aggregate the cross-attention map of the source (reference) image
cur_token_idx: the token index list used to aggregate the cross-attention map of the target image
mask_save_dir: the path to save the mask image
"""
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
print("Using MutualSelfAttentionControlMaskAuto")
self.thres = thres
self.ref_token_idx = ref_token_idx
self.cur_token_idx = cur_token_idx
self.self_attns = []
self.cross_attns = []
self.cross_attns_mask = None
self.self_attns_mask = None
self.mask_save_dir = mask_save_dir
if self.mask_save_dir is not None:
os.makedirs(self.mask_save_dir, exist_ok=True)
def after_step(self):
self.self_attns = []
self.cross_attns = []
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Performing attention for a batch of queries, keys, and values
"""
B = q.shape[0] // num_heads
H = W = int(np.sqrt(q.shape[1]))
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
if self.self_attns_mask is not None:
# binarize the mask
mask = self.self_attns_mask
thres = self.thres
mask[mask >= thres] = 1
mask[mask < thres] = 0
sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
sim = torch.cat([sim_fg, sim_bg])
attn = sim.softmax(-1)
if len(attn) == 2 * len(v):
v = torch.cat([v] * 2)
out = torch.einsum("h i j, h j d -> h i d", attn, v)
out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
return out
def aggregate_cross_attn_map(self, idx):
attn_map = torch.stack(self.cross_attns, dim=1).mean(1) # (B, N, dim)
B = attn_map.shape[0]
res = int(np.sqrt(attn_map.shape[-2]))
attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1])
image = attn_map[..., idx]
if isinstance(idx, list):
image = image.sum(-1)
image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]
image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0]
image = (image - image_min) / (image_max - image_min)
return image
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
"""
Attention forward function
"""
if is_cross:
# save cross attention map with res 16 * 16
if attn.shape[1] == 16 * 16:
self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1))
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
B = q.shape[0] // num_heads // 2
H = W = int(np.sqrt(q.shape[1]))
qu, qc = q.chunk(2)
ku, kc = k.chunk(2)
vu, vc = v.chunk(2)
attnu, attnc = attn.chunk(2)
out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
if len(self.cross_attns) == 0:
self.self_attns_mask = None
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
else:
mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx) # (2, H, W)
mask_source = mask[-2] # (H, W)
res = int(np.sqrt(q.shape[1]))
self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten()
if self.mask_save_dir is not None:
H = W = int(np.sqrt(self.self_attns_mask.shape[0]))
mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0)
save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_s_{self.cur_step}_{self.cur_att_layer}.png"))
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
if self.self_attns_mask is not None:
mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx) # (2, H, W)
mask_target = mask[-1] # (H, W)
res = int(np.sqrt(q.shape[1]))
spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1)
if self.mask_save_dir is not None:
H = W = int(np.sqrt(spatial_mask.shape[0]))
mask_image = spatial_mask.reshape(H, W).unsqueeze(0)
save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_t_{self.cur_step}_{self.cur_att_layer}.png"))
# binarize the mask
thres = self.thres
spatial_mask[spatial_mask >= thres] = 1
spatial_mask[spatial_mask < thres] = 0
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2)
out_c_target_fg, out_c_target_bg = out_c_target.chunk(2)
out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask)
out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask)
# reset the self-attention mask to None
self.self_attns_mask = None
out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
return out
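# Illustrative sketch (not part of the original module): the foreground/background
# masking used by the mask-aware controllers above, reduced to a single query row
# in pure Python. Keys outside the selected region get a very large negative
# similarity, so they receive zero weight after the softmax; the two results are
# then blended with the target mask. NEG, softmax, masked_attention and blend are
# hypothetical names introduced only for this example.

```python
import math

NEG = -1e30  # stand-in for torch.finfo(dtype).min


def softmax(row):
    # Numerically stable softmax over a list of floats.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_attention(sim_row, values, key_mask, keep):
    # Attend only to keys whose mask entry equals `keep` (1 = foreground, 0 = background).
    biased = [s if m == keep else NEG for s, m in zip(sim_row, key_mask)]
    weights = softmax(biased)
    return sum(w * v for w, v in zip(weights, values))


def blend(sim_row, values, source_mask, target_is_fg):
    # Foreground queries read from foreground keys, background queries from background keys.
    fg = masked_attention(sim_row, values, source_mask, keep=1)
    bg = masked_attention(sim_row, values, source_mask, keep=0)
    return fg * target_is_fg + bg * (1 - target_is_fg)
```

# With uniform similarities, a foreground query simply averages the foreground
# values, and a background query averages the background values.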
================================================
FILE: core/masactrl_utils.py
================================================
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Tuple, List, Callable, Dict
from torchvision.utils import save_image
from einops import rearrange, repeat
class AttentionBase:
def __init__(self):
self.cur_step = 0
self.num_att_layers = -1
self.cur_att_layer = 0
def after_step(self):
pass
def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
# after step
self.after_step()
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
out = torch.einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
return out
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
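# Illustrative sketch (not part of the original module): how the counters in
# AttentionBase.__call__ advance. Every attention call bumps cur_att_layer; once
# num_att_layers calls have been seen, the layer counter wraps to 0, cur_step is
# incremented and after_step() runs. CountingEditor is a hypothetical toy class
# that keeps only this bookkeeping and drops the attention math.

```python
class CountingEditor:
    """Toy stand-in for AttentionBase that only tracks the counters."""

    def __init__(self, num_att_layers):
        self.num_att_layers = num_att_layers
        self.cur_att_layer = 0
        self.cur_step = 0
        self.finished_steps = []

    def after_step(self):
        # Record which step just completed.
        self.finished_steps.append(self.cur_step)

    def __call__(self):
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.after_step()


editor = CountingEditor(num_att_layers=3)
for _ in range(7):  # two full denoising steps plus one extra layer call
    editor()
```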
class AttentionStore(AttentionBase):
def __init__(self, res=[32], min_step=0, max_step=1000):
super().__init__()
self.res = res
self.min_step = min_step
self.max_step = max_step
self.valid_steps = 0
self.self_attns = [] # attention maps accumulated (summed) over all valid steps
self.cross_attns = []
self.self_attns_step = [] # attention maps collected during the current step
self.cross_attns_step = []
def after_step(self):
if self.cur_step > self.min_step and self.cur_step < self.max_step:
self.valid_steps += 1
if len(self.self_attns) == 0:
self.self_attns = self.self_attns_step
self.cross_attns = self.cross_attns_step
else:
for i in range(len(self.self_attns)):
self.self_attns[i] += self.self_attns_step[i]
self.cross_attns[i] += self.cross_attns_step[i]
self.self_attns_step.clear()
self.cross_attns_step.clear()
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
if attn.shape[1] <= 64 ** 2: # avoid OOM
if is_cross:
self.cross_attns_step.append(attn)
else:
self.self_attns_step.append(attn)
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
def regiter_attention_editor_diffusers(unet, editor: AttentionBase):
"""
Register an attention editor to a Diffusers UNet, adapted from [Prompt-to-Prompt]
"""
def ca_forward(self, place_in_unet):
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
"""
This mirrors the original forward pass of the LDM CrossAttention class,
except that the attention output is computed by the registered `editor`.
"""
if encoder_hidden_states is not None:
context = encoder_hidden_states
if attention_mask is not None:
mask = attention_mask
to_out = self.to_out
if isinstance(to_out, nn.modules.container.ModuleList):
to_out = self.to_out[0]
else:
to_out = self.to_out
h = self.heads
q = self.to_q(x)
is_cross = context is not None
context = context if is_cross else x
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
attn = sim.softmax(dim=-1)
# the only difference
out = editor(
q, k, v, sim, attn, is_cross, place_in_unet,
self.heads, scale=self.scale)
return to_out(out)
return forward
def register_editor(net, count, place_in_unet):
for name, subnet in net.named_children():
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
net.forward = ca_forward(net, place_in_unet)
return count + 1
elif hasattr(net, 'children'):
count = register_editor(subnet, count, place_in_unet)
return count
cross_att_count = 0
for net_name, net in unet.named_children():
if "down" in net_name:
cross_att_count += register_editor(net, 0, "down")
elif "mid" in net_name:
cross_att_count += register_editor(net, 0, "mid")
elif "up" in net_name:
cross_att_count += register_editor(net, 0, "up")
editor.num_att_layers = cross_att_count
def regiter_attention_editor_ldm(model, editor: AttentionBase):
"""
Register an attention editor to a Stable Diffusion (LDM) model, adapted from [Prompt-to-Prompt]
"""
def ca_forward(self, place_in_unet):
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
"""
This mirrors the original forward pass of the LDM CrossAttention class,
except that the attention output is computed by the registered `editor`.
"""
if encoder_hidden_states is not None:
context = encoder_hidden_states
if attention_mask is not None:
mask = attention_mask
to_out = self.to_out
if isinstance(to_out, nn.modules.container.ModuleList):
to_out = self.to_out[0]
else:
to_out = self.to_out
h = self.heads
q = self.to_q(x)
is_cross = context is not None
context = context if is_cross else x
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
attn = sim.softmax(dim=-1)
# the only difference
out = editor(
q, k, v, sim, attn, is_cross, place_in_unet,
self.heads, scale=self.scale)
return to_out(out)
return forward
def register_editor(net, count, place_in_unet):
for name, subnet in net.named_children():
if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer
net.forward = ca_forward(net, place_in_unet)
return count + 1
elif hasattr(net, 'children'):
count = register_editor(subnet, count, place_in_unet)
return count
cross_att_count = 0
for net_name, net in model.model.diffusion_model.named_children():
if "input" in net_name:
cross_att_count += register_editor(net, 0, "input")
elif "middle" in net_name:
cross_att_count += register_editor(net, 0, "middle")
elif "output" in net_name:
cross_att_count += register_editor(net, 0, "output")
editor.num_att_layers = cross_att_count
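# Illustrative sketch (not part of the original module): the general idea behind
# the two registration helpers above, shown on a toy module tree in pure Python.
# The tree is walked recursively, `forward` is replaced on every node whose class
# name matches, and the number of patched nodes is returned. Node, Attention,
# patch_attention and editor are hypothetical names; the real helpers operate on
# torch.nn.Module children and also store the count in editor.num_att_layers.

```python
class Node:
    """Toy module with a `children` list, standing in for nn.Module."""

    def __init__(self, *children):
        self.children = list(children)

    def forward(self, x):
        return x


class Attention(Node):
    """Toy attention layer; only its class name matters for the walk."""


def patch_attention(node, editor, target="Attention"):
    # Replace `forward` on every node whose class name is `target`; return the count.
    if node.__class__.__name__ == target:
        original = node.forward
        node.forward = lambda x: editor(original(x))
        return 1
    return sum(patch_attention(child, editor, target) for child in node.children)


calls = []


def editor(x):
    # Record the value seen by the editor and return a modified result.
    calls.append(x)
    return x * 2


attn_a, attn_b = Attention(), Attention()
unet = Node(Node(attn_a), Node(Node(attn_b)), Node())
num_patched = patch_attention(unet, editor)
```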
================================================
FILE: core/models/__init__.py
================================================
================================================
FILE: core/models/transformer_mv2d.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph
from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
import pdb
import random
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module, or a linear layer when `dims == 0`.
"""
if dims == 0:
return nn.Linear(*args, **kwargs)
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def my_repeat(tensor, num_repeats):
"""
Repeat each element of a 3D or 4D tensor `num_repeats` times along the leading (batch) dimension
"""
if len(tensor.shape) == 3:
return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
elif len(tensor.shape) == 4:
return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
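# Illustrative sketch (not part of the original module): the ordering produced by
# the "(b v)" pattern in my_repeat, shown with plain lists. Because `b` is the
# outer index of the merged axis, each batch element is repeated consecutively
# (interleaved) rather than the whole batch being tiled. repeat_interleave_batch
# and tile_batch are hypothetical helpers used only for this comparison.

```python
def repeat_interleave_batch(items, num_repeats):
    # Same ordering as "b ... -> (b v) ...": a, a, b, b
    return [item for item in items for _ in range(num_repeats)]


def tile_batch(items, num_repeats):
    # Ordering of "b ... -> (v b) ...", for contrast: a, b, a, b
    return [item for _ in range(num_repeats) for item in items]
```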
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
"""
The output of [`TransformerMV2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class TransformerMV2DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
num_views: int = 1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicMVTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
num_views=num_views,
cd_attention_last=cd_attention_last,
cd_attention_mid=cd_attention_mid,
multiview_attention=multiview_attention,
sparse_mv_attention=sparse_mv_attention,
mvcd_attention=mvcd_attention
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches:
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
self.post_init()
def post_init(self):
conv_block = self.proj_in
conv_params = {
k: getattr(conv_block, k)
for k in [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
]
}
conv_params["in_channels"] += 6
conv_params["dims"] = 2
conv_params["device"] = conv_block.weight.device
inflated_proj_in = conv_nd(**conv_params)
inp_weight = conv_block.weight.data
feat_shape = inp_weight.shape
feat_weight = torch.zeros(
(feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device
)
inflated_proj_in.weight.data.copy_(
torch.cat([inp_weight, feat_weight], dim=1)
)
inflated_proj_in.bias.data.copy_(conv_block.bias.data)
self.proj_in = inflated_proj_in
self.post_intialized = True
def post_linear_init(self):
linear_block = self.proj_in
linear_params = {
k: getattr(linear_block, k)
for k in [
"in_features",
"out_features"
]
}
linear_params["in_features"] += 6
linear_params["dims"] = 0
linear_params["device"] = linear_block.weight.device
inflated_proj_in = conv_nd(**linear_params)
inp_weight = linear_block.weight.data
feat_shape = inp_weight.shape
feat_weight = torch.zeros(
(feat_shape[0], 6), device=inp_weight.device
)
inflated_proj_in.weight.data.copy_(
torch.cat([inp_weight, feat_weight], dim=1)
)
inflated_proj_in.bias.data.copy_(linear_block.bias.data)
self.proj_in = inflated_proj_in
self.post_intialized = True
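# Illustrative sketch (not part of the original class): why post_init and
# post_linear_init pad the new input weights with zeros. Appending zero-valued
# weights for the 6 extra input features leaves the layer output unchanged at
# initialization, so pretrained behaviour is preserved until training updates
# the new weights. linear and inflate_in_features are hypothetical pure-Python
# helpers, and the sketch assumes the extra inputs are finite.

```python
def linear(weight, bias, x):
    # y = W x + b for a weight matrix stored as a list of rows.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def inflate_in_features(weight, extra):
    # Append `extra` zero-initialized input weights to every output row.
    return [row + [0.0] * extra for row in weight]


weight = [[1.0, 2.0], [3.0, 4.0]]
bias = [0.5, -0.5]
x = [1.0, 1.0]
ray = [9.0, -7.0, 3.0, 2.0, 8.0, -1.0]  # 6 extra features, e.g. a ray embedding

before = linear(weight, bias, x)
after = linear(inflate_in_features(weight, 6), bias, x + ray)
```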
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
ray_embedding: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`TransformerMV2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerNormZero`.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
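ray_embedding (`torch.Tensor` of shape `(batch size, 6, H, W)`):
Per-pixel ray (Plücker) embedding. For continuous inputs it is bilinearly resized to the spatial size
of `hidden_states` and concatenated along the channel dimension before `proj_in`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain tuple.
Returns:
If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.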
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.post_intialized:
#ray_embedding = rearrange(ray_embedding, "n v c h w -> (n v) c h w")
ray_embedding_interpolated = F.interpolate(ray_embedding, size=hidden_states.shape[-2:], align_corners=False, mode="bilinear")
#ray_embedding_interpolated = rearrange(ray_embedding_interpolated, "(n v) c h w -> n v c h w", v=4)
# concat plucker to x
hidden_states = torch.cat([hidden_states, ray_embedding_interpolated], dim=1)
#hidden_states = rearrange(hidden_states, "n v c h w -> (n v) c h w")
# x = self.proj_in(x)
# x = rearrange(x, "(n v) c h w -> n v c h w", v=4)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
inner_dim = inner_dim - 6
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return TransformerMV2DModelOutput(sample=output)
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
num_views: int = 1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.multiview_attention = multiview_attention
self.sparse_mv_attention = sparse_mv_attention
self.mvcd_attention = mvcd_attention
self.attn1 = CustomAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
processor=MVAttnProcessor()
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
self.num_views = num_views
self.cd_attention_last = cd_attention_last
if self.cd_attention_last:
# Joint-task attention
self.attn_joint_last = CustomJointAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
processor=JointAttnProcessor()
)
nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
self.cd_attention_mid = cd_attention_mid
if self.cd_attention_mid:
# Joint-task attention (cross-domain attention in the middle of the block)
self.attn_joint_mid = CustomJointAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
processor=JointAttnProcessor()
)
nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
class_labels: Optional[torch.LongTensor] = None,
):
assert attention_mask is None # not supported yet
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
num_views=self.num_views,
multiview_attention=self.multiview_attention,
sparse_mv_attention=self.sparse_mv_attention,
mvcd_attention=self.mvcd_attention,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 1b. Cross-domain joint attention (mid), applied right after self-attention
if self.cd_attention_mid:
norm_hidden_states = (
self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
)
hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
# 4. Cross-domain joint attention (last), applied after the feed-forward block
if self.cd_attention_last:
norm_hidden_states = (
self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
)
hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
return hidden_states
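# Illustrative sketch (not part of the model): the chunked feed-forward path above, written
# with NumPy. `chunked_apply` is a hypothetical helper name. Because the feed-forward acts on
# each token independently, splitting along the chunk axis and concatenating the results should
# match one full-size call while lowering peak memory.

```python
import numpy as np


def chunked_apply(fn, x, chunk_size, axis):
    """Apply `fn` to equal slices of `x` along `axis` and concatenate the results."""
    if x.shape[axis] % chunk_size != 0:
        raise ValueError(
            f"dimension to be chunked ({x.shape[axis]}) must be divisible by chunk size {chunk_size}"
        )
    num_chunks = x.shape[axis] // chunk_size
    parts = np.split(x, num_chunks, axis=axis)
    return np.concatenate([fn(part) for part in parts], axis=axis)
```

# Usage: chunked_apply(ff, hidden, chunk_size=2, axis=1) for any position-wise function `ff`.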
class CustomAttention(Attention):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
):
# Honor the flag: fall back to the plain processor when xformers is disabled.
processor = XFormersMVAttnProcessor() if use_memory_efficient_attention_xformers else MVAttnProcessor()
self.set_processor(processor)
class CustomJointAttention(Attention):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
):
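# Honor the flag: fall back to the plain processor when xformers is disabled.
processor = XFormersJointAttnProcessor() if use_memory_efficient_attention_xformers else JointAttnProcessor()
self.set_processor(processor)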
class MVAttnProcessor:
r"""
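Multi-view attention processor using standard (non-xformers) attention.
When `multiview_attention` is set, the keys and values of all `num_views` views of a sample are
concatenated along the token axis, so every view attends to all views.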
"""
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
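num_views=1,
multiview_attention=True,
sparse_mv_attention=False, # accepted because BasicMVTransformerBlock passes it; not implemented in this processor
mvcd_attention=False, # accepted because BasicMVTransformerBlock passes it; unused in this processor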
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# query, key, value shapes: (b * num_views, tokens, channels), e.g. (b * 4, 1024, 320)
# multi-view self-attention
if multiview_attention:
key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
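# Illustrative sketch (not part of the model): the multi-view key/value layout used above,
# written with NumPy. `multiview_keys` is a hypothetical helper name. The reshape stands in for
# einops `rearrange(key, "(b t) d c -> b (t d) c", t=num_views)` and `np.repeat` for
# `repeat_interleave(num_views, dim=0)`, assuming views are the inner factor of the batch axis.

```python
import numpy as np


def multiview_keys(key, num_views):
    """(b * t, d, c) -> (b * t, t * d, c): every view sees the tokens of all t views."""
    batch_views, tokens, channels = key.shape
    batch = batch_views // num_views
    # Merge the view axis into the token axis for each sample.
    merged = key.reshape(batch, num_views * tokens, channels)
    # Give each of the t views of a sample its own copy of the merged tokens.
    return np.repeat(merged, num_views, axis=0)
```

# The query keeps its (b * t, d, c) shape, so attention cost per view grows from d x d to d x (t * d).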
class XFormersMVAttnProcessor:
r"""
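Multi-view attention processor using xformers memory-efficient attention.
With `multiview_attention`, every view attends to all views of its sample; with
`sparse_mv_attention`, every view attends only to itself and the first view.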
"""
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_views=1,
multiview_attention=True,
sparse_mv_attention=False,
mvcd_attention=False,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# from yuancheng; here attention_mask is None
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key_raw = attn.to_k(encoder_hidden_states)
value_raw = attn.to_v(encoder_hidden_states)
# query, key_raw, value_raw shapes: (b * num_views, tokens, channels), e.g. (b * 4, 1024, 320)
# multi-view self-attention
if multiview_attention:
if not sparse_mv_attention:
key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
else:
key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
value = torch.cat([value_front, value_raw], dim=1)
else:
# print("don't use multiview attention.")
key = key_raw
value = value_raw
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class XFormersJointAttnProcessor:
r"""
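Joint (cross-domain) attention processor using xformers memory-efficient attention.
The batch is split into two task halves along dim 0; the keys and values of both halves are
concatenated along the token axis so that each task attends to both tasks.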
"""
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_tasks=2
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# from yuancheng; here attention_mask is None
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
assert num_tasks == 2 # only support two tasks now
key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class JointAttnProcessor:
r"""
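Joint (cross-domain) attention processor using standard (non-xformers) attention.
The batch is split into two task halves along dim 0; the keys and values of both halves are
concatenated along the token axis so that each task attends to both tasks.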
"""
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
num_tasks=2
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
assert num_tasks == 2 # only support two tasks now
key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
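# Illustrative sketch (not part of the model): the two-task key/value layout built by the joint
# attention processors above, written with NumPy. `joint_task_keys` is a hypothetical helper
# name; `np.split` stands in for `torch.chunk(..., chunks=2)` and assumes an even batch size
# whose first half holds task 0 and second half holds task 1.

```python
import numpy as np


def joint_task_keys(key):
    """(2 * n, d, c) -> (2 * n, 2 * d, c): both task halves share the same merged tokens."""
    key_0, key_1 = np.split(key, 2, axis=0)           # one half of the batch per task
    merged = np.concatenate([key_0, key_1], axis=1)   # (n, 2 * d, c)
    return np.concatenate([merged, merged], axis=0)   # one copy for each task's queries
```

# Row i of task 0 and row i of task 1 therefore attend to the same 2 * d tokens.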
================================================
FILE: core/models/unet_mv2d_blocks.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import is_torch_version, logging
from diffusers.models.attention import AdaGroupNorm
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from diffusers.models.dual_transformer_2d import DualTransformer2DModel
from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
from .transformer_mv2d import TransformerMV2DModel
from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class IdentityMLP(nn.Module):
"""A single linear layer initialized to the identity map (weight = I, bias = 0)."""
def __init__(self, size):
super().__init__()
self.linear = nn.Linear(size, size)
self.init_identity()
def forward(self, x):
return self.linear(x)
def init_identity(self):
# Initialize the weights to an identity matrix and biases to zero
identity_matrix = torch.eye(self.linear.in_features)
self.linear.weight.data.copy_(identity_matrix)
self.linear.bias.data.zero_()
def get_down_block(
down_block_type,
num_layers,
in_channels,
out_channels,
temb_channels,
add_downsample,
resnet_eps,
resnet_act_fn,
transformer_layers_per_block=1,
num_attention_heads=None,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
attention_head_dim=None,
downsample_type=None,
num_views=1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warning(
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
return DownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "ResnetDownsampleBlock2D":
return ResnetDownsampleBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif down_block_type == "AttnDownBlock2D":
if add_downsample is False:
downsample_type = None
else:
downsample_type = downsample_type or "conv" # default to 'conv'
return AttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
)
elif down_block_type == "CrossAttnDownBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
return CrossAttnDownBlock2D(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
# custom MV2D attention block
elif down_block_type == "CrossAttnDownBlockMV2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
return CrossAttnDownBlockMV2D(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
num_views=num_views,
cd_attention_last=cd_attention_last,
cd_attention_mid=cd_attention_mid,
multiview_attention=multiview_attention,
sparse_mv_attention=sparse_mv_attention,
mvcd_attention=mvcd_attention
)
elif down_block_type == "SimpleCrossAttnDownBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
return SimpleCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "AttnSkipDownBlock2D":
return AttnSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "DownEncoderBlock2D":
return DownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "AttnDownEncoderBlock2D":
return AttnDownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "KDownBlock2D":
return KDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif down_block_type == "KCrossAttnDownBlock2D":
return KCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
add_self_attention=not add_downsample,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type,
num_layers,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
transformer_layers_per_block=1,
num_attention_heads=None,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
attention_head_dim=None,
upsample_type=None,
num_views=1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warning(
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
)
attention_head_dim = num_attention_heads
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
return UpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "ResnetUpsampleBlock2D":
return ResnetUpsampleBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
return CrossAttnUpBlock2D(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
# custom MV2D attention block
elif up_block_type == "CrossAttnUpBlockMV2D":
# if cross_attention_dim is None:
# raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
return CrossAttnUpBlockMV2D(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
num_views=num_views,
cd_attention_last=cd_attention_last,
cd_attention_mid=cd_attention_mid,
multiview_attention=multiview_attention,
sparse_mv_attention=sparse_mv_attention,
mvcd_attention=mvcd_attention
)
elif up_block_type == "SimpleCrossAttnUpBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
return SimpleCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif up_block_type == "AttnUpBlock2D":
if add_upsample is False:
upsample_type = None
else:
upsample_type = upsample_type or "conv" # default to 'conv'
return AttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
)
elif up_block_type == "SkipUpBlock2D":
return SkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "AttnSkipUpBlock2D":
return AttnSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "UpDecoderBlock2D":
return UpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
elif up_block_type == "AttnUpDecoderBlock2D":
return AttnUpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
elif up_block_type == "KUpBlock2D":
return KUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "KCrossAttnUpBlock2D":
return KCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlockMV2DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
num_views: int = 1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
super().__init__()
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
if not dual_cross_attention:
attentions.append(
TransformerMV2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
num_views=num_views,
cd_attention_last=cd_attention_last,
cd_attention_mid=cd_attention_mid,
multiview_attention=multiview_attention,
sparse_mv_attention=sparse_mv_attention,
mvcd_attention=mvcd_attention
)
)
else:
raise NotImplementedError
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
ray_embedding: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
ray_embedding=ray_embedding,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
return hidden_states
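# Illustrative sketch (not part of the original module): the layer ordering that
# UNetMidBlockMV2DCrossAttn builds and runs, reduced to plain strings. The helper name
# `mid_block_order` and the "resnetN"/"attnN" labels are hypothetical; only the
# bookkeeping (one leading resnet, then one attention + resnet pair per layer) is
# taken from the class above.

```python
from typing import List


def mid_block_order(num_layers: int) -> List[str]:
    """Return the order in which the mid block applies its sub-modules."""
    # __init__ always creates one resnet up front, then one (attention, resnet) pair per layer.
    resnets = ["resnet0"] + [f"resnet{i + 1}" for i in range(num_layers)]
    attentions = [f"attn{i}" for i in range(num_layers)]

    # forward() runs resnets[0] first, then zips attentions with resnets[1:].
    order = [resnets[0]]
    for attn, resnet in zip(attentions, resnets[1:]):
        order += [attn, resnet]
    return order


print(mid_block_order(1))  # ['resnet0', 'attn0', 'resnet1']
print(mid_block_order(2))
```

# With the default num_layers=1 the block is therefore resnet, attention, resnet.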
class CrossAttnUpBlockMV2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
num_views: int = 1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
super().__init__()
resnets = []
attentions = []
mlps = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
mlps.append(IdentityMLP(res_skip_channels))
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
TransformerMV2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
num_views=num_views,
cd_attention_last=cd_attention_last,
cd_attention_mid=cd_attention_mid,
multiview_attention=multiview_attention,
sparse_mv_attention=sparse_mv_attention,
mvcd_attention=mvcd_attention
)
)
else:
raise NotImplementedError
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.mlps = nn.ModuleList(mlps)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
ray_embedding: Optional[torch.Tensor] = None,
):
for resnet, attn, mlp in zip(self.resnets, self.attentions, self.mlps):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# flatten the skip tensor to (B, H*W, C) tokens, apply the per-layer MLP, then restore (B, C, H, W)
B, C, H, W = res_hidden_states.shape
res_hidden_states = res_hidden_states.permute(0, 2, 3, 1).reshape(B, H * W, C)
res_hidden_states = mlp(res_hidden_states)
res_hidden_states = res_hidden_states.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
ray_embedding,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
ray_embedding=ray_embedding,
return_dict=False,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
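# Illustrative sketch (not part of the original module): the channel arithmetic from
# CrossAttnUpBlockMV2D.__init__, isolated so the resnet input widths can be checked
# without building any modules. The helper name `up_block_resnet_channels` and the
# example channel sizes are assumptions chosen for illustration.

```python
from typing import List, Tuple


def up_block_resnet_channels(
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    num_layers: int,
) -> List[Tuple[int, int]]:
    """Return (resnet_input_channels, skip_channels) for each layer of the up block."""
    plan = []
    for i in range(num_layers):
        # Only the last layer's skip tensor has `in_channels` channels; earlier ones have `out_channels`.
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        # Only the first layer sees the previous block's output width.
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        # forward() concatenates hidden states and the skip tensor along dim=1.
        plan.append((resnet_in_channels + res_skip_channels, res_skip_channels))
    return plan


plan = up_block_resnet_channels(
    in_channels=320, out_channels=640, prev_output_channel=1280, num_layers=3
)
print(plan)  # [(1920, 640), (1280, 640), (960, 320)]
```

# Each tuple gives the width of the concatenated tensor fed to that layer's ResnetBlock2D
# and the width handed to the matching IdentityMLP.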
class CrossAttnDownBlockMV2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
num_views: int = 1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
TransformerMV2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
num_views=num_views,
cd_attention_last=cd_attention_last,
cd_attention_mid=cd_attention_mid,
multiview_attention=multiview_attention,
sparse_mv_attention=sparse_mv_attention,
mvcd_attention=mvcd_attention
)
)
else:
raise NotImplementedError
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals=None,
):
output_states = ()
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
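# Illustrative sketch (not part of the original module): which tensors
# CrossAttnDownBlockMV2D.forward collects into `output_states` for the skip connections.
# The helper name `down_block_output_names` and the string labels are hypothetical; the
# real method stores tensors, not names.

```python
from typing import Tuple


def down_block_output_names(num_layers: int, add_downsample: bool) -> Tuple[str, ...]:
    """Mimic how forward() grows the output_states tuple."""
    output_states: Tuple[str, ...] = ()
    for i in range(num_layers):
        # One entry after every resnet + attention pair.
        output_states = output_states + (f"resnet{i}+attn{i}",)
    if add_downsample:
        # One extra entry for the downsampled result.
        output_states = output_states + ("downsample",)
    return output_states


print(down_block_output_names(2, True))
print(len(down_block_output_names(2, False)))  # 2
```

# A block therefore contributes num_layers skip tensors, plus one more when it downsamples.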
================================================
FILE: core/models/unet_mv2d_condition.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import torch
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import (
GaussianFourierProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
UpBlock2D,
)
from diffusers.utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
_get_model_file,
deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
)
from diffusers import __version__
from .unet_mv2d_blocks import (
CrossAttnDownBlockMV2D,
CrossAttnUpBlockMV2D,
UNetMidBlockMV2DCrossAttn,
get_down_block,
get_up_block,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNetMV2DConditionOutput(BaseOutput):
"""
The output of [`UNetMV2DConditionModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor = None
class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`):
Block type for the middle of the UNet, for example `UNetMidBlockMV2DCrossAttn`, `UNetMidBlock2DCrossAttn` or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`):
The tuple of upsample blocks to use.
only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
addition_time_embed_dim (`int`, *optional*, defaults to `None`):
Dimension for the timestep embeddings.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, defaults to `None`):
An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest of
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
otherwise.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockMV2D",
"CrossAttnDownBlockMV2D",
"CrossAttnDownBlockMV2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
time_embedding_type: str = "positional",
time_embedding_dim: Optional[int] = None,
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
addition_embed_type_num_heads=64,
num_views: int = 1,
cd_attention_last: bool = False,
cd_attention_mid: bool = False,
multiview_attention: bool = True,
sparse_mv_attention: bool = False,
mvcd_attention: bool = False
):
super().__init__()
self.sample_size = sample_size
if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
)
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
post_act_fn=timestep_post_act,
cond_proj_dim=time_cond_proj_dim,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# it is set to `cross_attention_dim` here, as this is exactly the required dimension for the currently only use
# case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_di
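
The `encoder_hid_dim_type` handling in the excerpt above (default to `"text_proj"`, reject a type without a dimension, reject unknown types) can be looked at in isolation. What follows is a minimal, dependency-free sketch of that control flow, not code from the repository: `resolve_encoder_hid_dim_type` and `KNOWN_TYPES` are hypothetical names introduced only for illustration, and the sketch returns the resolved type string instead of building `nn.Linear`, `TextImageProjection`, or `ImageProjection` layers.

```python
from typing import Optional

# Type strings handled by the if/elif chain in the excerpt (None is also allowed).
# KNOWN_TYPES is a hypothetical name, not part of the repository.
KNOWN_TYPES = ("text_proj", "text_image_proj", "image_proj")


def resolve_encoder_hid_dim_type(
    encoder_hid_dim_type: Optional[str],
    encoder_hid_dim: Optional[int],
) -> Optional[str]:
    """Hypothetical helper mirroring the validation order of the excerpt."""
    # A dimension without an explicit type falls back to the plain text projection.
    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"

    # A type without a dimension cannot be used to size any projection layer.
    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(
            "`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` "
            f"is set to {encoder_hid_dim_type}."
        )

    # Anything other than None or a known type string is rejected.
    if encoder_hid_dim_type is not None and encoder_hid_dim_type not in KNOWN_TYPES:
        raise ValueError(f"Unsupported encoder_hid_dim_type: {encoder_hid_dim_type}")

    return encoder_hid_dim_type


if __name__ == "__main__":
    print(resolve_encoder_hid_dim_type(None, 1024))          # falls back to text_proj
    print(resolve_encoder_hid_dim_type("image_proj", 1280))  # kept as given
    print(resolve_encoder_hid_dim_type(None, None))          # no projection at all
```

Separating validation from layer construction like this is only one possible design; the excerpt itself interleaves the two inside `__init__`.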
SYMBOL INDEX (268 symbols across 22 files)
FILE: core/attention.py
class Attention (line 31) | class Attention(nn.Module):
method __init__ (line 32) | def __init__(
method forward (line 51) | def forward(self, x: Tensor) -> Tensor:
class MemEffAttention (line 67) | class MemEffAttention(Attention):
method forward (line 68) | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
class CrossAttention (line 87) | class CrossAttention(nn.Module):
method __init__ (line 88) | def __init__(
method forward (line 113) | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
class MemEffCrossAttention (line 137) | class MemEffCrossAttention(CrossAttention):
method forward (line 138) | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> ...
FILE: core/control.py
function retrieve_timesteps (line 54) | def retrieve_timesteps(
class ControlNetPipeline (line 97) | class ControlNetPipeline(StableDiffusionControlNetPipeline):
method pred_x0 (line 99) | def pred_x0(
method next_step (line 120) | def next_step(
method __call__ (line 144) | def __call__(
FILE: core/diffuser_utils.py
class MasaCtrlPipeline (line 22) | class MasaCtrlPipeline(StableDiffusionPipeline):
method next_step (line 24) | def next_step(
method step (line 47) | def step(
method image2latent (line 68) | def image2latent(self, image):
method latent2image (line 80) | def latent2image(self, latents, return_type='np'):
method latent2image_grad (line 92) | def latent2image_grad(self, latents):
method __call__ (line 99) | def __call__(
method invert (line 201) | def invert(
FILE: core/gs.py
class GaussianRenderer (line 16) | class GaussianRenderer:
method __init__ (line 17) | def __init__(self, opt: Options):
method render (line 31) | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color...
method save_ply (line 101) | def save_ply(self, gaussians, path, compatible=True):
method load_ply (line 154) | def load_ply(self, path, compatible=True):
FILE: core/masactrl.py
class MutualSelfAttentionControl (line 14) | class MutualSelfAttentionControl(AttentionBase):
method __init__ (line 20) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_...
method attn_batch (line 41) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method forward (line 56) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttention3DControl (line 74) | class MutualSelfAttention3DControl(AttentionBase):
method __init__ (line 80) | def __init__(self, start_steps=4, start_layer=10, layer_idx=None, step...
method attn_batch (line 101) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method forward (line 116) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttentionControlUnion (line 150) | class MutualSelfAttentionControlUnion(MutualSelfAttentionControl):
method __init__ (line 151) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_...
method forward (line 164) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttentionControlMask (line 189) | class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
method __init__ (line 190) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step...
method attn_batch (line 213) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method forward (line 238) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttentionControlMaskAuto (line 271) | class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
method __init__ (line 272) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_...
method after_step (line 302) | def after_step(self):
method attn_batch (line 306) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method aggregate_cross_attn_map (line 335) | def aggregate_cross_attn_map(self, idx):
method forward (line 348) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
FILE: core/masactrl_utils.py
class AttentionBase (line 14) | class AttentionBase:
method __init__ (line 15) | def __init__(self):
method after_step (line 20) | def after_step(self):
method __call__ (line 23) | def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_he...
method forward (line 33) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
method reset (line 38) | def reset(self):
class AttentionStore (line 43) | class AttentionStore(AttentionBase):
method __init__ (line 44) | def __init__(self, res=[32], min_step=0, max_step=1000):
method after_step (line 57) | def after_step(self):
method forward (line 70) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
function regiter_attention_editor_diffusers (line 79) | def regiter_attention_editor_diffusers(unet, editor: AttentionBase):
function regiter_attention_editor_ldm (line 147) | def regiter_attention_editor_ldm(model, editor: AttentionBase):
FILE: core/models/transformer_mv2d.py
function conv_nd (line 34) | def conv_nd(dims, *args, **kwargs):
function my_repeat (line 54) | def my_repeat(tensor, num_repeats):
class TransformerMV2DModelOutput (line 65) | class TransformerMV2DModelOutput(BaseOutput):
class TransformerMV2DModel (line 78) | class TransformerMV2DModel(ModelMixin, ConfigMixin):
method __init__ (line 107) | def __init__(
method post_init (line 255) | def post_init(self):
method post_linear_init (line 283) | def post_linear_init(self):
method forward (line 308) | def forward(
class BasicMVTransformerBlock (line 461) | class BasicMVTransformerBlock(nn.Module):
method __init__ (line 482) | def __init__(
method set_chunk_feed_forward (line 610) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
method forward (line 615) | def forward(
class CustomAttention (line 711) | class CustomAttention(Attention):
method set_use_memory_efficient_attention_xformers (line 712) | def set_use_memory_efficient_attention_xformers(
class CustomJointAttention (line 720) | class CustomJointAttention(Attention):
method set_use_memory_efficient_attention_xformers (line 721) | def set_use_memory_efficient_attention_xformers(
class MVAttnProcessor (line 728) | class MVAttnProcessor:
method __call__ (line 733) | def __call__(
class XFormersMVAttnProcessor (line 804) | class XFormersMVAttnProcessor:
method __call__ (line 809) | def __call__(
class XFormersJointAttnProcessor (line 904) | class XFormersJointAttnProcessor:
method __call__ (line 909) | def __call__(
class JointAttnProcessor (line 992) | class JointAttnProcessor:
method __call__ (line 997) | def __call__(
FILE: core/models/unet_mv2d_blocks.py
class IdentityMLP (line 34) | class IdentityMLP(nn.Module):
method __init__ (line 35) | def __init__(self, size):
method forward (line 40) | def forward(self, x):
method init_identity (line 44) | def init_identity(self):
function get_down_block (line 51) | def get_down_block(
function get_up_block (line 281) | def get_up_block(
class UNetMidBlockMV2DCrossAttn (line 515) | class UNetMidBlockMV2DCrossAttn(nn.Module):
method __init__ (line 516) | def __init__(
method forward (line 604) | def forward(
class CrossAttnUpBlockMV2D (line 630) | class CrossAttnUpBlockMV2D(nn.Module):
method __init__ (line 631) | def __init__(
method forward (line 720) | def forward(
class CrossAttnDownBlockMV2D (line 791) | class CrossAttnDownBlockMV2D(nn.Module):
method __init__ (line 792) | def __init__(
method forward (line 882) | def forward(
FILE: core/models/unet_mv2d_condition.py
class UNetMV2DConditionOutput (line 76) | class UNetMV2DConditionOutput(BaseOutput):
class UNetMV2DConditionModel (line 88) | class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoa...
method __init__ (line 179) | def __init__(
method attn_processors (line 634) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 657) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
method set_default_attn_processor (line 691) | def set_default_attn_processor(self):
method set_attention_slice (line 697) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 762) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 766) | def forward(
method from_pretrained_2d (line 1063) | def from_pretrained_2d(
method _load_pretrained_model_2d (line 1394) | def _load_pretrained_model_2d(
FILE: core/models/unet_mv2d_condition_depth.py
class UNetMV2DConditionOutput (line 76) | class UNetMV2DConditionOutput(BaseOutput):
class UNetMV2DConditionModel (line 88) | class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoa...
method __init__ (line 179) | def __init__(
method attn_processors (line 634) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 657) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
method set_default_attn_processor (line 691) | def set_default_attn_processor(self):
method set_attention_slice (line 697) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 762) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 766) | def forward(
method from_pretrained_2d (line 1063) | def from_pretrained_2d(
method _load_pretrained_model_2d (line 1394) | def _load_pretrained_model_2d(
FILE: core/models/unet_mv2d_condition_depth_diffusion.py
class IdentityMLP (line 74) | class IdentityMLP(nn.Module):
method __init__ (line 75) | def __init__(self, size):
method forward (line 80) | def forward(self, x):
method init_identity (line 83) | def init_identity(self):
class UNetMV2DConditionOutput (line 90) | class UNetMV2DConditionOutput(BaseOutput):
class UNetMV2DConditionModel (line 102) | class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoa...
method __init__ (line 193) | def __init__(
method attn_processors (line 648) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 671) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
method set_default_attn_processor (line 705) | def set_default_attn_processor(self):
method set_attention_slice (line 711) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 776) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 780) | def forward(
method from_pretrained_2d (line 1080) | def from_pretrained_2d(
method _load_pretrained_model_2d (line 1413) | def _load_pretrained_model_2d(
FILE: core/models/unet_mv2d_condition_depth_diffusion_test.py
class IdentityMLP (line 74) | class IdentityMLP(nn.Module):
method __init__ (line 75) | def __init__(self, size):
method forward (line 80) | def forward(self, x):
method init_identity (line 83) | def init_identity(self):
class UNetMV2DConditionOutput (line 90) | class UNetMV2DConditionOutput(BaseOutput):
class UNetMV2DConditionModel (line 102) | class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoa...
method __init__ (line 193) | def __init__(
method attn_processors (line 648) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 671) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
method set_default_attn_processor (line 705) | def set_default_attn_processor(self):
method set_attention_slice (line 711) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 776) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 780) | def forward(
method from_pretrained_2d (line 1080) | def from_pretrained_2d(
method _load_pretrained_model_2d (line 1414) | def _load_pretrained_model_2d(
FILE: core/models_LGM_compos_diffusion.py
class LGM (line 20) | class LGM(nn.Module):
method __init__ (line 21) | def __init__(
method state_dict (line 68) | def state_dict(self, **kwargs):
method prepare_default_rays (line 77) | def prepare_default_rays(self, device, elevation=0):
method forward_gaussians (line 104) | def forward_gaussians(self, images, encoder_hidden_states, data):
method pred_x0 (line 143) | def pred_x0(
method encode_prompt (line 164) | def encode_prompt(
method compute_snr (line 216) | def compute_snr(self, timesteps):
method forward (line 240) | def forward(self, data, step_ratio=1):
FILE: core/models_LGM_compos_diffusion_validate_inversion_2_masa.py
class LGM (line 25) | class LGM(nn.Module):
method __init__ (line 26) | def __init__(
method state_dict (line 85) | def state_dict(self, **kwargs):
method prepare_default_rays (line 94) | def prepare_default_rays(self, device, elevation=0, proj_matrix=None):
method prepare_default_rays_zero123 (line 126) | def prepare_default_rays_zero123(self, device, elevation=0, proj_matri...
method unet_step (line 163) | def unet_step(
method forward_gaussians (line 183) | def forward_gaussians(self, images, encoder_hidden_states, data, uncon...
method pred_x0 (line 231) | def pred_x0(
method step (line 252) | def step(
method encode_prompt (line 282) | def encode_prompt(
method compute_snr (line 334) | def compute_snr(self, timesteps):
method forward (line 358) | def forward(self, data, step_ratio=1):
method next_step (line 464) | def next_step(
method image2latent (line 500) | def image2latent(self, image):
method invert (line 512) | def invert(
method validate (line 596) | def validate(self, data, num_inference_steps=30, single_image=True):
FILE: core/options_latents_diffusion.py
class Options (line 7) | class Options:
FILE: core/provider_Gobjaverse_latent_diffusion_insert.py
class GobjaverseDataset (line 22) | class GobjaverseDataset(Dataset):
method _warn (line 24) | def _warn(self):
method __init__ (line 27) | def __init__(self, opt: Options, training=True):
method __len__ (line 59) | def __len__(self):
method __getitem__ (line 63) | def __getitem__(self, idx):
FILE: core/unet_LGM_compos.py
class MVAttention (line 11) | class MVAttention(nn.Module):
method __init__ (line 12) | def __init__(
method forward (line 35) | def forward(self, x):
class UnetAttention (line 51) | class UnetAttention(nn.Module):
method __init__ (line 52) | def __init__(
method post_init (line 78) | def post_init(self):
method forward (line 82) | def forward(self, x, unet_x):
class ResnetBlock (line 99) | class ResnetBlock(nn.Module):
method __init__ (line 100) | def __init__(
method post_init (line 137) | def post_init(self):
method forward (line 141) | def forward(self, x, temb=None):
class DownBlock (line 164) | class DownBlock(nn.Module):
method __init__ (line 165) | def __init__(
method forward (line 208) | def forward(self, x, unet_xs=None, temb=None):
class MidBlock (line 232) | class MidBlock(nn.Module):
method __init__ (line 233) | def __init__(
method forward (line 257) | def forward(self, x, temb=None):
class UpBlock (line 266) | class UpBlock(nn.Module):
method __init__ (line 267) | def __init__(
method forward (line 298) | def forward(self, x, xs, temb=None):
class UNet (line 316) | class UNet(nn.Module):
method __init__ (line 317) | def __init__(
method forward (line 386) | def forward(self, x, unet_xss=None, temb=None):
FILE: core/utils.py
function get_rays (line 10) | def get_rays(pose, h, w, fovy, opengl=True):
function orbit_camera_jitter (line 45) | def orbit_camera_jitter(poses, strength=0.1):
function grid_distortion (line 63) | def grid_distortion(images, strength=0.5):
FILE: infer_ours_masa.py
function process (line 72) | def process(opt: Options, path):
FILE: main_resume_compose.py
function main (line 17) | def main():
FILE: mvdream/mv_unet.py
function get_camera (line 20) | def get_camera(
function timestep_embedding (line 42) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 70) | def zero_module(module):
function conv_nd (line 79) | def conv_nd(dims, *args, **kwargs):
function avg_pool_nd (line 92) | def avg_pool_nd(dims, *args, **kwargs):
function default (line 105) | def default(val, d):
class GEGLU (line 111) | class GEGLU(nn.Module):
method __init__ (line 112) | def __init__(self, dim_in, dim_out):
method forward (line 116) | def forward(self, x):
class FeedForward (line 121) | class FeedForward(nn.Module):
method __init__ (line 122) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
method forward (line 136) | def forward(self, x):
class MemoryEfficientCrossAttention (line 140) | class MemoryEfficientCrossAttention(nn.Module):
method __init__ (line 142) | def __init__(
method forward (line 176) | def forward(self, x, context=None):
class BasicTransformerBlock3D (line 230) | class BasicTransformerBlock3D(nn.Module):
method __init__ (line 232) | def __init__(
method forward (line 267) | def forward(self, x, context=None, num_frames=1):
class SpatialTransformer3D (line 276) | class SpatialTransformer3D(nn.Module):
method __init__ (line 278) | def __init__(
method forward (line 318) | def forward(self, x, context=None, num_frames=1):
class PerceiverAttention (line 335) | class PerceiverAttention(nn.Module):
method __init__ (line 336) | def __init__(self, *, dim, dim_head=64, heads=8):
method forward (line 350) | def forward(self, x, latents):
class Resampler (line 386) | class Resampler(nn.Module):
method __init__ (line 387) | def __init__(
method forward (line 420) | def forward(self, x):
class CondSequential (line 431) | class CondSequential(nn.Sequential):
method forward (line 437) | def forward(self, x, emb, context=None, num_frames=1):
class Upsample (line 448) | class Upsample(nn.Module):
method __init__ (line 457) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 468) | def forward(self, x):
class Downsample (line 481) | class Downsample(nn.Module):
method __init__ (line 490) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 510) | def forward(self, x):
class ResBlock (line 515) | class ResBlock(nn.Module):
method __init__ (line 530) | def __init__(
method forward (line 592) | def forward(self, x, emb):
class MultiViewUNetModel (line 615) | class MultiViewUNetModel(ModelMixin, ConfigMixin):
method __init__ (line 645) | def __init__(
method forward (line 944) | def forward(
FILE: mvdream/pipeline_mvdream.py
class MVDreamPipeline (line 23) | class MVDreamPipeline(DiffusionPipeline):
method __init__ (line 27) | def __init__(
method enable_vae_slicing (line 84) | def enable_vae_slicing(self):
method disable_vae_slicing (line 93) | def disable_vae_slicing(self):
method enable_vae_tiling (line 100) | def enable_vae_tiling(self):
method disable_vae_tiling (line 109) | def disable_vae_tiling(self):
method enable_sequential_cpu_offload (line 116) | def enable_sequential_cpu_offload(self, gpu_id=0):
method enable_model_cpu_offload (line 140) | def enable_model_cpu_offload(self, gpu_id=0):
method _execution_device (line 170) | def _execution_device(self):
method _encode_prompt (line 187) | def _encode_prompt(
method decode_latents (line 339) | def decode_latents(self, latents):
method prepare_extra_step_kwargs (line 347) | def prepare_extra_step_kwargs(self, generator, eta):
method prepare_latents (line 368) | def prepare_latents(
method encode_image (line 402) | def encode_image(self, image, device, num_images_per_prompt):
method encode_image_latents (line 416) | def encode_image_latents(self, image, device, num_images_per_prompt):
method __call__ (line 432) | def __call__(
Condensed preview: 34 files, each showing path, character count, and a content snippet (full structured content: 673K chars).
[
{
"path": "LICENSE",
"chars": 1082,
"preview": "MIT License\n\nCopyright (c) 2024 PKU-YUAN-Lab (袁粒课题组-北大信工)\n\nPermission is hereby granted, free of charge, to any person o"
},
{
"path": "README.md",
"chars": 2810,
"preview": "<h2 align=\"center\"> <a href=\"https://github.com/PKU-YuanGroup/Cycle3D\">Cycle3D: High-quality and Consistent Image-to-3D "
},
{
"path": "acc_configs/gpu1.yaml",
"chars": 395,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: 'NO'\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training"
},
{
"path": "acc_configs/gpu4.yaml",
"chars": 306,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_tra"
},
{
"path": "acc_configs/gpu6.yaml",
"chars": 392,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_tra"
},
{
"path": "acc_configs/gpu7.yaml",
"chars": 392,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_tra"
},
{
"path": "acc_configs/gpu8.yaml",
"chars": 392,
"preview": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_tra"
},
{
"path": "acc_configs/hostfile",
"chars": 58,
"preview": "gpu147 slots=8\ngpu176 slots=8\ngpu47 slots=8\ngpu117 slots=8"
},
{
"path": "acc_configs/multi_node.yaml",
"chars": 462,
"preview": "compute_environment: LOCAL_MACHINE\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n gradient_clipping: 1.0\n zero_stage: "
},
{
"path": "acc_configs/zero2.json",
"chars": 556,
"preview": "{\n \"fp16\": {\n \"enabled\": \"auto\",\n \"loss_scale\": 0,\n \"loss_scale_window\": 1000,\n \"initial_"
},
{
"path": "core/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "core/attention.py",
"chars": 5233,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
},
{
"path": "core/control.py",
"chars": 30423,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/diffuser_utils.py",
"chars": 11198,
"preview": "\"\"\"\nUtil functions based on Diffuser framework.\n\"\"\"\n\n\nimport os\nimport torch\nimport cv2\nimport numpy as np\n\nimport torch"
},
{
"path": "core/gs.py",
"chars": 7424,
"preview": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diff_gaussian_rasterization"
},
{
"path": "core/masactrl.py",
"chars": 21224,
"preview": "import os\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom einops import rearrange\n\nfrom core.masa"
},
{
"path": "core/masactrl_utils.py",
"chars": 7907,
"preview": "import os\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom typing "
},
{
"path": "core/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "core/models/transformer_mv2d.py",
"chars": 46550,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/models/unet_mv2d_blocks.py",
"chars": 38426,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/models/unet_mv2d_condition.py",
"chars": 76507,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/models/unet_mv2d_condition_depth.py",
"chars": 76507,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/models/unet_mv2d_condition_depth_diffusion.py",
"chars": 77522,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/models/unet_mv2d_condition_depth_diffusion_test.py",
"chars": 77519,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "core/models_LGM_compos_diffusion.py",
"chars": 15738,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport kiui\nfrom kiui.lpips impor"
},
{
"path": "core/models_LGM_compos_diffusion_validate_inversion_2_masa.py",
"chars": 36691,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport kiui\nfrom kiui.lpips impor"
},
{
"path": "core/options_latents_diffusion.py",
"chars": 8269,
"preview": "import tyro\nfrom dataclasses import dataclass\nfrom typing import Tuple, Literal, Dict, Optional\n\n\n@dataclass\nclass Optio"
},
{
"path": "core/provider_Gobjaverse_latent_diffusion_insert.py",
"chars": 9928,
"preview": "import os\nimport cv2\nimport random\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as "
},
{
"path": "core/unet_LGM_compos.py",
"chars": 14110,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport numpy as np\nfrom typing import Tuple, Literal"
},
{
"path": "core/utils.py",
"chars": 3353,
"preview": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport roma\nfrom kiui.op import "
},
{
"path": "infer_ours_masa.py",
"chars": 8278,
"preview": "\nimport os\nimport tyro\nimport glob\nimport imageio\nimport numpy as np\nimport tqdm\nimport torch\nimport torch.nn as nn\nimpo"
},
{
"path": "main_resume_compose.py",
"chars": 15109,
"preview": "import tyro\nimport time\nimport random\n\nimport torch\nfrom core.options_latents_diffusion import AllConfigs\nfrom core.mode"
},
{
"path": "mvdream/mv_unet.py",
"chars": 34403,
"preview": "import math\nimport numpy as np\nfrom inspect import isfunction\nfrom typing import Optional, Any, List\n\nimport torch\nimpor"
},
{
"path": "mvdream/pipeline_mvdream.py",
"chars": 24188,
"preview": "import torch\nimport torch.nn.functional as F\nimport inspect\nimport numpy as np\nfrom typing import Callable, List, Option"
}
]
About this extraction
This page contains the full source code of the PKU-YuanGroup/Cycle3D GitHub repository, extracted and formatted as plain text. The extraction includes 34 files (638.0 KB), approximately 150.0k tokens, and a symbol index with 268 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.