Repository: PKU-YuanGroup/Cycle3D
Branch: main
Commit: 5ddc45691c8e
Files: 34
Total size: 638.0 KB

Directory structure:
gitextract_hayw1hga/
├── LICENSE
├── README.md
├── acc_configs/
│   ├── gpu1.yaml
│   ├── gpu4.yaml
│   ├── gpu6.yaml
│   ├── gpu7.yaml
│   ├── gpu8.yaml
│   ├── hostfile
│   ├── multi_node.yaml
│   └── zero2.json
├── core/
│   ├── __init__.py
│   ├── attention.py
│   ├── control.py
│   ├── diffuser_utils.py
│   ├── gs.py
│   ├── masactrl.py
│   ├── masactrl_utils.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── transformer_mv2d.py
│   │   ├── unet_mv2d_blocks.py
│   │   ├── unet_mv2d_condition.py
│   │   ├── unet_mv2d_condition_depth.py
│   │   ├── unet_mv2d_condition_depth_diffusion.py
│   │   └── unet_mv2d_condition_depth_diffusion_test.py
│   ├── models_LGM_compos_diffusion.py
│   ├── models_LGM_compos_diffusion_validate_inversion_2_masa.py
│   ├── options_latents_diffusion.py
│   ├── provider_Gobjaverse_latent_diffusion_insert.py
│   ├── unet_LGM_compos.py
│   └── utils.py
├── infer_ours_masa.py
├── main_resume_compose.py
└── mvdream/
    ├── mv_unet.py
    └── pipeline_mvdream.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 PKU-YUAN-Lab (袁粒课题组-北大信工)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================

Cycle3D: High-quality and Consistent Image-to-3D Generation via Generation-Reconstruction Cycle

If you like our project, please give us a star ⭐ on GitHub for the latest updates.
[![webpage](https://img.shields.io/badge/Webpage-blue)](https://PKU-YuanGroup.github.io/Cycle3D/) [![arXiv](https://img.shields.io/badge/Arxiv-2407.19548-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2407.19548) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/PKU-YuanGroup/Cycle3D/blob/main/LICENSE)
## [Project page](https://PKU-YuanGroup.github.io/Cycle3D/) | [Paper](https://arxiv.org/abs/2407.19548) | [Live Demo (Coming Soon)]()

![image](https://github.com/user-attachments/assets/d6870ef6-4631-4fc2-a2dc-382c054afe0d)

## 😮 Highlights

### 🔥 Generation-Reconstruction cycle for the unified diffusion process
- The pre-trained 2D diffusion model, trained on billions of web images, can generate high-quality texture.
- The reconstruction model can ensure consistency across multiple views.
- We cyclically utilize a 2D diffusion-based generation module and a feed-forward 3D reconstruction module during the multi-step diffusion process.

## 🚩 **Updates**

Welcome to **watch** 👀 this repository for the latest updates.

✅ **[2024.7.28]** : We have released our paper, Cycle3D, on [arXiv](https://arxiv.org/abs/2407.19548).

✅ **[2024.7.28]** : Released the [project page](https://PKU-YuanGroup.github.io/Cycle3D/).
- [ ] Code release.
- [ ] Online Demo.

## 🤗 Demo
Coming soon!

## 🚀 Image-to-3D Results

### Qualitative comparison

![image](https://github.com/user-attachments/assets/ce4f0c0c-793b-4354-b3fa-7d30e97a8ddf)

### Quantitative comparison

![image](https://github.com/user-attachments/assets/25a9e1d2-124c-426d-a1a4-54a44aa7d0fc)

## 👍 **Acknowledgement**
This work is built on many amazing research works and open-source projects; thanks a lot to all the authors for sharing!
* [LGM](https://github.com/3DTopia/LGM)
* [MasaCtrl](https://github.com/TencentARC/MasaCtrl)
* [Diffusers](https://github.com/huggingface/diffusers)

## ✏️ Citation
If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:.
```BibTeX
@misc{tang2024cycle3dhighqualityconsistentimageto3d,
      title={Cycle3D: High-quality and Consistent Image-to-3D Generation via Generation-Reconstruction Cycle},
      author={Zhenyu Tang and Junwu Zhang and Xinhua Cheng and Wangbo Yu and Chaoran Feng and Yatian Pang and Bin Lin and Li Yuan},
      year={2024},
      eprint={2407.19548},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.19548},
}
```

================================================
FILE: acc_configs/gpu1.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
# distributed_type: DEEPSPEED
# deepspeed_config:
#   gradient_clipping: 1.0
#   zero_stage: 2

================================================
FILE: acc_configs/gpu4.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

================================================
FILE: acc_configs/gpu6.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 6
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2

================================================
FILE: acc_configs/gpu7.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 7
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2

================================================
FILE: acc_configs/gpu8.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2

================================================
FILE: acc_configs/hostfile
================================================
gpu147 slots=8
gpu176 slots=8
gpu47 slots=8
gpu117 slots=8

================================================
FILE: acc_configs/multi_node.yaml
================================================
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_clipping: 1.0
  zero_stage: 2
  deepspeed_hostfile: /remote-home1/yeyang/aigc/aigc/LGM/acc_configs/hostfile
fsdp_config: {}
machine_rank: 0
main_process_ip: 219.223.196.147
main_process_port: 29504
main_training_function: main
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

================================================
FILE: acc_configs/zero2.json
================================================
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    }
}

================================================
FILE: core/__init__.py
================================================

================================================
FILE: core/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import os
import warnings

from torch import Tensor
from torch import nn

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_q: int,
        dim_k: int,
        dim_v: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # q: [B, N, Cq]
        # k: [B, M, Ck]
        # v: [B, M, Cv]
        # return: [B, N, C]

        B, N, _ = q.shape
        M = k.shape[1]

        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3)  # [B, nh, N, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3)  # [B, nh, M, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3)  # [B, nh, M, C/nh]

        attn = q @ k.transpose(-2, -1)  # [B, nh, N, M]

        attn = attn.softmax(dim=-1)  # [B, nh, N, M]
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)  # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffCrossAttention(CrossAttention):
    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Fall back to the plain cross-attention path with the same q, k, v inputs.
            return super().forward(q, k, v)

        B, N, _ = q.shape
        M = k.shape[1]

        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads)  # [B, N, nh, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads)  # [B, M, nh, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads)  # [B, M, nh, C/nh]

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

================================================
FILE: core/control.py
================================================
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
from typing import Callable, List, Optional, Union, Dict, Any
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from tqdm import tqdm
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    DPTFeatureExtractor,
    DPTForDepthEstimation,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
from diffusers.pipelines import StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
import kiui

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class ControlNetPipeline(StableDiffusionControlNetPipeline):
    def pred_x0(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta: float=0.0,
        verbose=False,
    ):
        """
        Predict the clean sample x0 from the model's noise prediction at the given timestep(s).
        """
        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x.device)
        alpha_prod_t = alphas_cumprod[timestep]
        B = alpha_prod_t.shape[0]
        alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1)
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        return pred_x0

    def next_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """
        Inverse sampling for DDIM Inversion
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start:
Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], data = None, LGM_unet = None, opt = None, pos_act = None, scale_act = None, opacity_act = None, rot_act = None, rgb_act = None, gs = None, **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): The ControlNet encoder tries to recognize the content of the input image even if you remove all prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, *optional*): A function that is called at the end of each denoising step during inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ self.opt = opt self.pos_act = pos_act self.scale_act = scale_act self.opacity_act = opacity_act self.rot_act = rot_act self.rgb_act = rgb_act self.gs = gs callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs. 
Raise error if not correct # self.check_inputs( # prompt, # image, # callback_steps, # negative_prompt, # prompt_embeds, # negative_prompt_embeds, # ip_adapter_image, # ip_adapter_image_embeds, # controlnet_conditioning_scale, # control_guidance_start, # control_guidance_end, # callback_on_step_end_tensor_inputs, # ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) if prompt_embeds is not None: prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] # Nested lists as ControlNet condition if isinstance(image[0], list): # Transpose the nested image list image = [list(t) for t in zip(*image)] for image_ in image: image_ = self.prepare_image( image=image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) images.append(image_) image = images height, width = image[0].shape[-2:] else: assert False # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) # 6. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and self.do_classifier_free_guidance: # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred, blocks_sample, tembpred_noise = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, return_dict=False, ) if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) pred_x0 = self.pred_x0(noise_pred, timestep, latent_model_input) images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5 images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False) images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].flatten(0, 1).to(self.opt.weight_dtype) ], dim=1) # perform guidance # compute the previous noisy sample x_t -> x_t-1 #latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = 
self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 ] has_nsfw_concept = None #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return images ================================================ FILE: core/diffuser_utils.py ================================================ """ Util functions based on Diffuser framework. """ import os import torch import cv2 import numpy as np import torch.nn.functional as F from tqdm import tqdm from PIL import Image from torchvision.utils import save_image from torchvision.io import read_image from diffusers import StableDiffusionPipeline from pytorch_lightning import seed_everything class MasaCtrlPipeline(StableDiffusionPipeline): def next_step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0., verbose=False ): """ Inverse sampling for DDIM Inversion """ if verbose: print("timestep: ", timestep) next_step = timestep timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir return x_next, pred_x0 def step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta: float=0.0, 
verbose=False, ): """ predict the sample at the next step in the denoising process. """ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir return x_prev, pred_x0 @torch.no_grad() def image2latent(self, image): DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # `Image` is the PIL module, so compare against the `Image.Image` class if isinstance(image, Image.Image): image = np.array(image) image = torch.from_numpy(image).float() / 127.5 - 1 image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE) # input image density range [-1, 1] latents = self.vae.encode(image)['latent_dist'].mean latents = latents * 0.18215 return latents @torch.no_grad() def latent2image(self, latents, return_type='np'): latents = 1 / 0.18215 * latents.detach() image = self.vae.decode(latents)['sample'] if return_type == 'np': image = (image / 2 + 0.5).clamp(0, 1) image = image.to(torch.float).cpu().permute(0, 2, 3, 1).numpy()[0] image = (image * 255).astype(np.uint8) elif return_type == "pt": image = (image / 2 + 0.5).clamp(0, 1) return image def latent2image_grad(self, latents): latents = 1 / 0.18215 * latents image = self.vae.decode(latents)['sample'] return image # range [-1, 1] @torch.no_grad() def __call__( self, prompt, batch_size=1, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, eta=0.0, latents=None, unconditioning=None, neg_prompt=None, ref_intermediate_latents=None, return_intermediates=False, **kwds): DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if isinstance(prompt, list): batch_size = len(prompt) elif isinstance(prompt, str): if batch_size >
1: prompt = [prompt] * batch_size # text embeddings text_input = self.tokenizer( prompt, padding="max_length", max_length=77, return_tensors="pt" ) text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] print("input text embeddings :", text_embeddings.shape) if kwds.get("dir"): dir = text_embeddings[-2] - text_embeddings[-1] u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True) text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v print(u.shape) print(v.shape) # define initial latents latents_shape = (batch_size, self.unet.in_channels, height//8, width//8) if latents is None: latents = torch.randn(latents_shape, device=DEVICE) else: assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one." # unconditional embedding for classifier free guidance if guidance_scale > 1.: max_length = text_input.input_ids.shape[-1] if neg_prompt: uc_text = neg_prompt else: uc_text = "" # uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face" unconditional_input = self.tokenizer( [uc_text] * batch_size, padding="max_length", max_length=77, return_tensors="pt" ) # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:] unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0) print("latents shape: ", latents.shape) # iterative sampling self.scheduler.set_timesteps(num_inference_steps) # print("Valid timesteps: ", reversed(self.scheduler.timesteps)) latents_list = [latents] pred_x0_list = [latents] for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")): if ref_intermediate_latents is not None: # note that the batch_size >= 2 latents_ref = ref_intermediate_latents[-1 - i] _, latents_cur = latents.chunk(2) latents = torch.cat([latents_ref, latents_cur]) if 
guidance_scale > 1.: model_inputs = torch.cat([latents] * 2) else: model_inputs = latents if unconditioning is not None and isinstance(unconditioning, list): _, text_embeddings = text_embeddings.chunk(2) text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings]) # predict the noise noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample if guidance_scale > 1.: noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) # compute the previous noise sample x_t -> x_t-1 latents, pred_x0 = self.step(noise_pred, t, latents) latents_list.append(latents) pred_x0_list.append(pred_x0) image = self.latent2image(latents, return_type="pt") if return_intermediates: pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list] latents_list = [self.latent2image(img, return_type="pt") for img in latents_list] return image, pred_x0_list, latents_list return image @torch.no_grad() def invert( self, image: torch.Tensor, prompt, num_inference_steps=50, guidance_scale=7.5, eta=0.0, return_intermediates=False, path = None, **kwds): """ invert a real image into a noise map with deterministic DDIM inversion """ DEVICE = image.device batch_size = image.shape[0] if isinstance(prompt, list): if batch_size == 1: image = image.expand(len(prompt), -1, -1, -1) elif isinstance(prompt, str): if batch_size > 1: prompt = [prompt] * batch_size # text embeddings text_input = self.tokenizer( prompt, padding="max_length", max_length=77, return_tensors="pt" ) text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] print("input text embeddings :", text_embeddings.shape) # define initial latents latents = self.image2latent(image) start_latents = latents # print(latents) # exit() # unconditional embedding for classifier free guidance if guidance_scale > 1.: max_length = text_input.input_ids.shape[-1] unconditional_input =
self.tokenizer( [""] * batch_size, padding="max_length", max_length=77, return_tensors="pt" ) unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0) print("latents shape: ", latents.shape) # iterative sampling self.scheduler.set_timesteps(num_inference_steps) print("Valid timesteps: ", reversed(self.scheduler.timesteps)) # print("attributes: ", self.scheduler.__dict__) latents_list = [latents] pred_x0_list = [latents] for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): if guidance_scale > 1.: model_inputs = torch.cat([latents] * 2) else: model_inputs = latents # predict the noise noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample if guidance_scale > 1.: noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) # compute the next noisy sample x_t-1 -> x_t latents, pred_x0 = self.next_step(noise_pred, t, latents) #Image.fromarray(self.latent2image(latents[:1])).save(os.path.join(path, str(i)+'_8.png')) # if kwds.get("workspace"): # Image.fromarray(self.latent2image(pred_x0[:1])).save(kwds.get("workspace")+'/'+str(i)+'_8.png') latents_list.append(latents) pred_x0_list.append(pred_x0) if return_intermediates: # return the intermediate latents during inversion # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list] return latents, latents_list return latents, start_latents ================================================ FILE: core/gs.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from diff_gaussian_rasterization import ( GaussianRasterizationSettings, GaussianRasterizer, ) from core.options import Options import kiui class GaussianRenderer: def __init__(self, opt: Options):
self.opt = opt self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") # intrinsics self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) self.proj_matrix[0, 0] = 1 / self.tan_half_fov self.proj_matrix[1, 1] = 1 / self.tan_half_fov self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) self.proj_matrix[2, 3] = 1 def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): # gaussians: [B, N, 14] # cam_view, cam_view_proj: [B, V, 4, 4] # cam_pos: [B, V, 3] device = gaussians.device B, V = cam_view.shape[:2] # loop of loop... images = [] alphas = [] for b in range(B): # pos, opacity, scale, rotation, shs means3D = gaussians[b, :, 0:3].contiguous().float() opacity = gaussians[b, :, 3:4].contiguous().float() scales = gaussians[b, :, 4:7].contiguous().float() rotations = gaussians[b, :, 7:11].contiguous().float() rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] for v in range(V): # render novel views view_matrix = cam_view[b, v].float() view_proj_matrix = cam_view_proj[b, v].float() campos = cam_pos[b, v].float() raster_settings = GaussianRasterizationSettings( image_height=self.opt.output_size, image_width=self.opt.output_size, tanfovx=self.tan_half_fov, tanfovy=self.tan_half_fov, bg=self.bg_color if bg_color is None else bg_color, scale_modifier=scale_modifier, viewmatrix=view_matrix, projmatrix=view_proj_matrix, sh_degree=0, campos=campos, prefiltered=False, debug=False, ) rasterizer = GaussianRasterizer(raster_settings=raster_settings) # Rasterize visible Gaussians to image, obtain their radii (on screen). 
rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( means3D=means3D, means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), shs=None, colors_precomp=rgbs, opacities=opacity, scales=scales, rotations=rotations, cov3D_precomp=None, ) rendered_image = rendered_image.clamp(0, 1) images.append(rendered_image) alphas.append(rendered_alpha) images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) return { "image": images, # [B, V, 3, H, W] "alpha": alphas, # [B, V, 1, H, W] } def save_ply(self, gaussians, path, compatible=True): # gaussians: [B, N, 14] # compatible: save pre-activated gaussians as in the original paper assert gaussians.shape[0] == 1, 'only support batch size 1' from plyfile import PlyData, PlyElement means3D = gaussians[0, :, 0:3].contiguous().float() opacity = gaussians[0, :, 3:4].contiguous().float() scales = gaussians[0, :, 4:7].contiguous().float() rotations = gaussians[0, :, 7:11].contiguous().float() shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] # prune by opacity mask = opacity.squeeze(-1) >= 0.005 means3D = means3D[mask] opacity = opacity[mask] scales = scales[mask] rotations = rotations[mask] shs = shs[mask] # invert activation to make it compatible with the original ply format if compatible: opacity = kiui.op.inverse_sigmoid(opacity) scales = torch.log(scales + 1e-8) shs = (shs - 0.5) / 0.28209479177387814 xyzs = means3D.detach().cpu().numpy() f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() opacities = opacity.detach().cpu().numpy() scales = scales.detach().cpu().numpy() rotations = rotations.detach().cpu().numpy() l = ['x', 'y', 'z'] # All channels except the 3 DC for i in range(f_dc.shape[1]): l.append('f_dc_{}'.format(i)) l.append('opacity') for i in range(scales.shape[1]): l.append('scale_{}'.format(i)) for i in 
range(rotations.shape[1]): l.append('rot_{}'.format(i)) dtype_full = [(attribute, 'f4') for attribute in l] elements = np.empty(xyzs.shape[0], dtype=dtype_full) attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) elements[:] = list(map(tuple, attributes)) el = PlyElement.describe(elements, 'vertex') PlyData([el]).write(path) def load_ply(self, path, compatible=True): from plyfile import PlyData, PlyElement plydata = PlyData.read(path) xyz = np.stack((np.asarray(plydata.elements[0]["x"]), np.asarray(plydata.elements[0]["y"]), np.asarray(plydata.elements[0]["z"])), axis=1) print("Number of points at loading : ", xyz.shape[0]) opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] shs = np.zeros((xyz.shape[0], 3)) shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] scales = np.zeros((xyz.shape[0], len(scale_names))) for idx, attr_name in enumerate(scale_names): scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] rots = np.zeros((xyz.shape[0], len(rot_names))) for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) gaussians = torch.from_numpy(gaussians).float() # cpu if compatible: gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 return gaussians ================================================ FILE: core/masactrl.py ================================================ import os import torch import torch.nn.functional as F import numpy as np from einops import rearrange from 
core.masactrl_utils import AttentionBase from torchvision.utils import save_image class MutualSelfAttentionControl(AttentionBase): MODEL_TYPE = { "SD": 16, "SDXL": 70 } def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"): """ Mutual self-attention control for Stable-Diffusion model Args: start_step: the step to start mutual self-attention control start_layer: the layer to start mutual self-attention control layer_idx: list of the layers to apply mutual self-attention control step_idx: list the steps to apply mutual self-attention control total_steps: the total number of steps model_type: the model type, SD or SDXL """ super().__init__() self.total_steps = total_steps self.total_layers = self.MODEL_TYPE.get(model_type, 16) self.start_step = start_step self.start_layer = start_layer self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers)) self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) print("MasaCtrl at denoising steps: ", self.step_idx) print("MasaCtrl at U-Net layers: ", self.layer_idx) def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Performing attention for a batch of queries, keys, and values """ b = q.shape[0] // num_heads q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") attn = sim.softmax(-1) out = torch.einsum("h i j, h j d -> h i d", attn, v) out = rearrange(out, "h (b n) d -> b n (h d)", b=b) return out def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Attention forward function """ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: return super().forward(q, k, v, sim, attn, is_cross, 
place_in_unet, num_heads, **kwargs) qu, qc = q.chunk(2) ku, kc = k.chunk(2) vu, vc = v.chunk(2) attnu, attnc = attn.chunk(2) out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) out = torch.cat([out_u, out_c], dim=0) return out class MutualSelfAttention3DControl(AttentionBase): MODEL_TYPE = { "SD": 16, "SDXL": 70 } def __init__(self, start_steps=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"): """ Mutual self-attention control for Stable-Diffusion model Args: start_step: the step to start mutual self-attention control start_layer: the layer to start mutual self-attention control layer_idx: list of the layers to apply mutual self-attention control step_idx: list the steps to apply mutual self-attention control total_steps: the total number of steps model_type: the model type, SD or SDXL """ super().__init__() self.total_steps = total_steps self.total_layers = self.MODEL_TYPE.get(model_type, 16) self.start_step = start_steps self.start_layer = start_layer self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers)) self.step_idx = step_idx if step_idx is not None else list(range(start_steps, total_steps)) print("MasaCtrl at denoising steps: ", self.step_idx) print("MasaCtrl at U-Net layers: ", self.layer_idx) def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Performing attention for a batch of queries, keys, and values """ b = q.shape[0] // num_heads q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") attn = sim.softmax(-1) out = torch.einsum("h i j, h j d -> h i 
d", attn, v) out = rearrange(out, "h (b n) d -> b n (h d)", b=b) return out def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Attention forward function """ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) # qu, qc = q.chunk(2) # ku, kc = k.chunk(2) # vu, vc = v.chunk(2) # attnu, attnc = attn.chunk(2) q_t1, q_t2, q_t3, q_t4, q_s= q.chunk(5) k_t1, k_t2, k_t3, k_t4, k_s = k.chunk(5) v_t1, v_t2, v_t3, v_t4, v_s= v.chunk(5) attn_t1, attn_t2, attn_t3, attn_t4, attn_s= attn.chunk(5) out_s = super().forward(q_s, k_s, v_s, sim, attn_s, is_cross, place_in_unet, num_heads, **kwargs) out_t1 = self.attn_batch(q_t1, torch.cat([k_s, k_t1]), torch.cat([v_s, v_t1]), sim[:num_heads], attn_t1, is_cross, place_in_unet, num_heads, **kwargs) out_t2 = self.attn_batch(q_t2, torch.cat([k_s, k_t2]), torch.cat([v_s, v_t2]), sim[:num_heads], attn_t2, is_cross, place_in_unet, num_heads, **kwargs) out_t3 = self.attn_batch(q_t3, torch.cat([k_s, k_t3]), torch.cat([v_s, v_t3]), sim[:num_heads], attn_t3, is_cross, place_in_unet, num_heads, **kwargs) out_t4 = self.attn_batch(q_t4, torch.cat([k_s, k_t4]), torch.cat([v_s, v_t4]), sim[:num_heads], attn_t4, is_cross, place_in_unet, num_heads, **kwargs) # out_t1 = self.attn_batch(q_t1, k_s, v_s, sim[:num_heads], attn_t1, is_cross, place_in_unet, num_heads, **kwargs) # out_t2 = self.attn_batch(q_t2, k_s, v_s, sim[:num_heads], attn_t2, is_cross, place_in_unet, num_heads, **kwargs) # out_t3 = self.attn_batch(q_t3, k_s, v_s, sim[:num_heads], attn_t3, is_cross, place_in_unet, num_heads, **kwargs) # out_t4 = self.attn_batch(q_t4, k_s, v_s, sim[:num_heads], attn_t4, is_cross, place_in_unet, num_heads, **kwargs) # out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) # out_c = self.attn_batch(qc, 
kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) out = torch.cat([out_t1, out_t2, out_t3, out_t4, out_s], dim=0) return out class MutualSelfAttentionControlUnion(MutualSelfAttentionControl): def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"): """ Mutual self-attention control for Stable-Diffusion model with the union of source and target [K, V] Args: start_step: the step to start mutual self-attention control start_layer: the layer to start mutual self-attention control layer_idx: list of the layers to apply mutual self-attention control step_idx: list of the steps to apply mutual self-attention control total_steps: the total number of steps model_type: the model type, SD or SDXL """ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type) def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Attention forward function """ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) qu_s, qu_t, qc_s, qc_t = q.chunk(4) ku_s, ku_t, kc_s, kc_t = k.chunk(4) vu_s, vu_t, vc_s, vc_t = v.chunk(4) attnu_s, attnu_t, attnc_s, attnc_t = attn.chunk(4) # source image branch out_u_s = super().forward(qu_s, ku_s, vu_s, sim, attnu_s, is_cross, place_in_unet, num_heads, **kwargs) out_c_s = super().forward(qc_s, kc_s, vc_s, sim, attnc_s, is_cross, place_in_unet, num_heads, **kwargs) # target image branch, concatenating source and target [K, V] out_u_t = self.attn_batch(qu_t, torch.cat([ku_s, ku_t]), torch.cat([vu_s, vu_t]), sim[:num_heads], attnu_t, is_cross, place_in_unet, num_heads, **kwargs) out_c_t = self.attn_batch(qc_t, torch.cat([kc_s, kc_t]), torch.cat([vc_s, vc_t]), sim[:num_heads], attnc_t, is_cross, place_in_unet, num_heads, **kwargs) out = torch.cat([out_u_s, out_u_t, out_c_s, 
out_c_t], dim=0) return out class MutualSelfAttentionControlMask(MutualSelfAttentionControl): def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None, model_type="SD"): """ Mask-guided MasaCtrl to alleviate the problem of fore- and background confusion Args: start_step: the step to start mutual self-attention control start_layer: the layer to start mutual self-attention control layer_idx: list of the layers to apply mutual self-attention control step_idx: list of the steps to apply mutual self-attention control total_steps: the total number of steps mask_s: source mask with shape (h, w) mask_t: target mask with same shape as source mask mask_save_dir: the path to save the mask image model_type: the model type, SD or SDXL """ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type) self.mask_s = mask_s # source mask with shape (h, w) self.mask_t = mask_t # target mask with same shape as source mask print("Using mask-guided MasaCtrl") if mask_save_dir is not None: os.makedirs(mask_save_dir, exist_ok=True) save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_s.png")) save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_t.png")) def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): B = q.shape[0] // num_heads H = W = int(np.sqrt(q.shape[1])) q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") if kwargs.get("is_mask_attn") and self.mask_s is not None: print("masked attention") mask = self.mask_s.unsqueeze(0).unsqueeze(0) mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0) mask = mask.flatten() # background sim_bg = sim + mask.masked_fill(mask == 1, 
torch.finfo(sim.dtype).min) # object sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min) sim = torch.cat([sim_fg, sim_bg], dim=0) attn = sim.softmax(-1) if len(attn) == 2 * len(v): v = torch.cat([v] * 2) out = torch.einsum("h i j, h j d -> h i d", attn, v) out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads) return out def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Attention forward function """ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) B = q.shape[0] // num_heads // 2 H = W = int(np.sqrt(q.shape[1])) qu, qc = q.chunk(2) ku, kc = k.chunk(2) vu, vc = v.chunk(2) attnu, attnc = attn.chunk(2) out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs) out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs) if self.mask_s is not None and self.mask_t is not None: out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0) out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0) mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W)) mask = mask.reshape(-1, 1) # (hw, 1) out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - mask) out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask) out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0) return out class 
MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl): def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None, model_type="SD"): """ MasaCtrl with mask auto generation from cross-attention map Args: start_step: the step to start mutual self-attention control start_layer: the layer to start mutual self-attention control layer_idx: list of the layers to apply mutual self-attention control step_idx: list of the steps to apply mutual self-attention control total_steps: the total number of steps thres: the threshold used to binarize the mask ref_token_idx: the token index list for cross-attention map aggregation (source mask) cur_token_idx: the token index list for cross-attention map aggregation (target mask) mask_save_dir: the path to save the mask image model_type: the model type, SD or SDXL """ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type) print("Using MutualSelfAttentionControlMaskAuto") self.thres = thres self.ref_token_idx = ref_token_idx self.cur_token_idx = cur_token_idx self.self_attns = [] self.cross_attns = [] self.cross_attns_mask = None self.self_attns_mask = None self.mask_save_dir = mask_save_dir if self.mask_save_dir is not None: os.makedirs(self.mask_save_dir, exist_ok=True) def after_step(self): self.self_attns = [] self.cross_attns = [] def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Performing attention for a batch of queries, keys, and values """ B = q.shape[0] // num_heads H = W = int(np.sqrt(q.shape[1])) q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") if self.self_attns_mask is not None: # binarize the mask mask = self.self_attns_mask thres = self.thres mask[mask >= thres] = 1 mask[mask < thres] = 0 sim_fg = sim +
mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min) sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min) sim = torch.cat([sim_fg, sim_bg]) attn = sim.softmax(-1) if len(attn) == 2 * len(v): v = torch.cat([v] * 2) out = torch.einsum("h i j, h j d -> h i d", attn, v) out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads) return out def aggregate_cross_attn_map(self, idx): attn_map = torch.stack(self.cross_attns, dim=1).mean(1) # (B, N, dim) B = attn_map.shape[0] res = int(np.sqrt(attn_map.shape[-2])) attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1]) image = attn_map[..., idx] if isinstance(idx, list): image = image.sum(-1) image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0] image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0] image = (image - image_min) / (image_max - image_min) return image def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): """ Attention forward function """ if is_cross: # save cross attention map with res 16 * 16 if attn.shape[1] == 16 * 16: self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1)) if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) B = q.shape[0] // num_heads // 2 H = W = int(np.sqrt(q.shape[1])) qu, qc = q.chunk(2) ku, kc = k.chunk(2) vu, vc = v.chunk(2) attnu, attnc = attn.chunk(2) out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) if len(self.cross_attns) == 0: self.self_attns_mask = None out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, 
place_in_unet, num_heads, **kwargs) out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) else: mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx) # (2, H, W) mask_source = mask[-2] # (H, W) res = int(np.sqrt(q.shape[1])) self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten() if self.mask_save_dir is not None: H = W = int(np.sqrt(self.self_attns_mask.shape[0])) mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0) save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_s_{self.cur_step}_{self.cur_att_layer}.png")) out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) if self.self_attns_mask is not None: mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx) # (2, H, W) mask_target = mask[-1] # (H, W) res = int(np.sqrt(q.shape[1])) spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1) if self.mask_save_dir is not None: H = W = int(np.sqrt(spatial_mask.shape[0])) mask_image = spatial_mask.reshape(H, W).unsqueeze(0) save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_t_{self.cur_step}_{self.cur_att_layer}.png")) # binarize the mask thres = self.thres spatial_mask[spatial_mask >= thres] = 1 spatial_mask[spatial_mask < thres] = 0 out_u_target_fg, out_u_target_bg = out_u_target.chunk(2) out_c_target_fg, out_c_target_bg = out_c_target.chunk(2) out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask) out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask) # reset the self-attention mask to None self.self_attns_mask = None out = torch.cat([out_u_source,
out_u_target, out_c_source, out_c_target], dim=0) return out ================================================ FILE: core/masactrl_utils.py ================================================ import os import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Union, Tuple, List, Callable, Dict from torchvision.utils import save_image from einops import rearrange, repeat class AttentionBase: def __init__(self): self.cur_step = 0 self.num_att_layers = -1 self.cur_att_layer = 0 def after_step(self): pass def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) self.cur_att_layer += 1 if self.cur_att_layer == self.num_att_layers: self.cur_att_layer = 0 self.cur_step += 1 # after step self.after_step() return out def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): out = torch.einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads) return out def reset(self): self.cur_step = 0 self.cur_att_layer = 0 class AttentionStore(AttentionBase): def __init__(self, res=[32], min_step=0, max_step=1000): super().__init__() self.res = res self.min_step = min_step self.max_step = max_step self.valid_steps = 0 self.self_attns = [] # store the all attns self.cross_attns = [] self.self_attns_step = [] # store the attns in each step self.cross_attns_step = [] def after_step(self): if self.cur_step > self.min_step and self.cur_step < self.max_step: self.valid_steps += 1 if len(self.self_attns) == 0: self.self_attns = self.self_attns_step self.cross_attns = self.cross_attns_step else: for i in range(len(self.self_attns)): self.self_attns[i] += self.self_attns_step[i] self.cross_attns[i] += self.cross_attns_step[i] self.self_attns_step.clear() self.cross_attns_step.clear() def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, 
num_heads, **kwargs): if attn.shape[1] <= 64 ** 2: # avoid OOM if is_cross: self.cross_attns_step.append(attn) else: self.self_attns_step.append(attn) return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) def regiter_attention_editor_diffusers(unet, editor: AttentionBase): """ Register an attention editor on a Diffusers UNet, adapted from [Prompt-to-Prompt] """ def ca_forward(self, place_in_unet): def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): """ Mirrors the original LDM CrossAttention forward pass, except that the attention output is computed by the editor """ if encoder_hidden_states is not None: context = encoder_hidden_states if attention_mask is not None: mask = attention_mask to_out = self.to_out if isinstance(to_out, nn.modules.container.ModuleList): to_out = self.to_out[0] else: to_out = self.to_out h = self.heads q = self.to_q(x) is_cross = context is not None context = context if is_cross else x k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale if mask is not None: mask = rearrange(mask, 'b ...
-> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) mask = mask[:, None, :].repeat(h, 1, 1) sim.masked_fill_(~mask, max_neg_value) attn = sim.softmax(dim=-1) # the only difference out = editor( q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale) return to_out(out) return forward def register_editor(net, count, place_in_unet): for name, subnet in net.named_children(): if net.__class__.__name__ == 'Attention': # spatial Transformer layer net.forward = ca_forward(net, place_in_unet) return count + 1 elif hasattr(net, 'children'): count = register_editor(subnet, count, place_in_unet) return count cross_att_count = 0 for net_name, net in unet.named_children(): if "down" in net_name: cross_att_count += register_editor(net, 0, "down") elif "mid" in net_name: cross_att_count += register_editor(net, 0, "mid") elif "up" in net_name: cross_att_count += register_editor(net, 0, "up") editor.num_att_layers = cross_att_count def regiter_attention_editor_ldm(model, editor: AttentionBase): """ Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt] """ def ca_forward(self, place_in_unet): def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): """ The attention is similar to the original implementation of LDM CrossAttention class except adding some modifications on the attention """ if encoder_hidden_states is not None: context = encoder_hidden_states if attention_mask is not None: mask = attention_mask to_out = self.to_out if isinstance(to_out, nn.modules.container.ModuleList): to_out = self.to_out[0] else: to_out = self.to_out h = self.heads q = self.to_q(x) is_cross = context is not None context = context if is_cross else x k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale if mask is not None: mask = 
rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) mask = mask[:, None, :].repeat(h, 1, 1) sim.masked_fill_(~mask, max_neg_value) attn = sim.softmax(dim=-1) # the only difference out = editor( q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale) return to_out(out) return forward def register_editor(net, count, place_in_unet): for name, subnet in net.named_children(): if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer net.forward = ca_forward(net, place_in_unet) return count + 1 elif hasattr(net, 'children'): count = register_editor(subnet, count, place_in_unet) return count cross_att_count = 0 for net_name, net in model.model.diffusion_model.named_children(): if "input" in net_name: cross_att_count += register_editor(net, 0, "input") elif "middle" in net_name: cross_att_count += register_editor(net, 0, "middle") elif "output" in net_name: cross_att_count += register_editor(net, 0, "output") editor.num_att_layers = cross_att_count ================================================ FILE: core/models/__init__.py ================================================ ================================================ FILE: core/models/transformer_mv2d.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass from typing import Any, Dict, Optional import torch import torch.nn.functional as F from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.embeddings import ImagePositionalEmbeddings from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention from diffusers.models.embeddings import PatchEmbed from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear from diffusers.models.modeling_utils import ModelMixin from diffusers.utils.import_utils import is_xformers_available from einops import rearrange, repeat import pdb import random def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 0: return nn.Linear(*args, **kwargs) if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") if is_xformers_available(): import xformers import xformers.ops else: xformers = None def my_repeat(tensor, num_repeats): """ Repeat a tensor along a given dimension """ if len(tensor.shape) == 3: return repeat(tensor, "b d c -> (b v) d c", v=num_repeats) elif len(tensor.shape) == 4: return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats) @dataclass class TransformerMV2DModelOutput(BaseOutput): """ The output of [`Transformer2DModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. 
""" sample: torch.FloatTensor class TransformerMV2DModel(ModelMixin, ConfigMixin): """ A 2D Transformer model for image-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. num_vector_embeds (`int`, *optional*): The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). Includes the class for the masked latent pixel. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. num_embeds_ada_norm ( `int`, *optional*): The number of diffusion steps used during training. Pass if at least one of the norm_layers is `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. 
""" @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, num_views: int = 1, cd_attention_last: bool=False, cd_attention_mid: bool=False, multiview_attention: bool=True, sparse_mv_attention: bool = False, mvcd_attention: bool=False ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = (in_channels is not None) and (patch_size is None) self.is_input_vectorized = num_vector_embeds is not None self.is_input_patches = in_channels is not None and patch_size is not None if norm_type == "layer_norm" and num_embeds_ada_norm is not None: deprecation_message = ( f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" " results in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it" " would be very nice if you could open a Pull request for the `transformer/config.json` file" ) deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) norm_type = "ada_norm" if self.is_input_continuous and self.is_input_vectorized: raise ValueError( f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" " sure that either `in_channels` or `num_vector_embeds` is None." ) elif self.is_input_vectorized and self.is_input_patches: raise ValueError( f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" " sure that either `num_vector_embeds` or `num_patches` is None." ) elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: raise ValueError( f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." ) # 2. 
Define input layers if self.is_input_continuous: self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) else: self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" self.height = sample_size self.width = sample_size self.num_vector_embeds = num_vector_embeds self.num_latent_pixels = self.height * self.width self.latent_image_embedding = ImagePositionalEmbeddings( num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width ) elif self.is_input_patches: assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" self.height = sample_size self.width = sample_size self.patch_size = patch_size self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, ) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicMVTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) for d in range(num_layers) ] ) # 4. 
Define output layers self.out_channels = in_channels if out_channels is None else out_channels if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) else: self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) elif self.is_input_patches: self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) self.post_init() def post_init(self): conv_block = self.proj_in conv_params = { k: getattr(conv_block, k) for k in [ "in_channels", "out_channels", "kernel_size", "stride", "padding", ] } conv_params["in_channels"] += 6 conv_params["dims"] = 2 conv_params["device"] = conv_block.weight.device inflated_proj_in = conv_nd(**conv_params) inp_weight = conv_block.weight.data feat_shape = inp_weight.shape feat_weight = torch.zeros( (feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device ) inflated_proj_in.weight.data.copy_( torch.cat([inp_weight, feat_weight], dim=1) ) inflated_proj_in.bias.data.copy_(conv_block.bias.data) self.proj_in = inflated_proj_in self.post_intialized = True def post_linear_init(self): linear_block = self.proj_in linear_params = { k: getattr(linear_block, k) for k in [ "in_features", "out_features" ] } linear_params["in_features"] += 6 linear_params["dims"] = 0 linear_params["device"] = linear_block.weight.device inflated_proj_in = conv_nd(**linear_params) inp_weight = linear_block.weight.data feat_shape = inp_weight.shape feat_weight = torch.zeros( (feat_shape[0], 6), device=inp_weight.device ) inflated_proj_in.weight.data.copy_( torch.cat([inp_weight, feat_weight], dim=1) ) 
inflated_proj_in.bias.data.copy_(linear_block.bias.data) self.proj_in = inflated_proj_in self.post_intialized = True def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ray_embedding: Optional[torch.Tensor] = None, return_dict: bool = True, ): """ The [`Transformer2DModel`] forward method. Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): Input `hidden_states`. encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. encoder_attention_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batch, sequence_length)` True = keep, False = discard. * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 1. 
Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if self.post_intialized: #ray_embedding = rearrange(ray_embedding, "n v c h w -> (n v) c h w") ray_embedding_interpolated = F.interpolate(ray_embedding, size=hidden_states.shape[-2:], align_corners=False, mode="bilinear") #ray_embedding_interpolated = rearrange(ray_embedding_interpolated, "(n v) c h w -> n v c h w", v=4) # concat plucker to x hidden_states = torch.cat([hidden_states, ray_embedding_interpolated], dim=1) #hidden_states = rearrange(hidden_states, "n v c h w -> (n v) c h w") # x = self.proj_in(x) # x = rearrange(x, "(n v) c h w -> n v c h w", v=4) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = self.proj_in(hidden_states) inner_dim = inner_dim -6 elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: hidden_states = self.pos_embed(hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. 
Output if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual elif self.is_input_vectorized: hidden_states = self.norm_out(hidden_states) logits = self.out(hidden_states) # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) logits = logits.permute(0, 2, 1) # log(p(x_0)) output = F.log_softmax(logits.double(), dim=1).float() elif self.is_input_patches: # TODO: cleanup! conditioning = self.transformer_blocks[0].norm1.emb( timestep, class_labels, hidden_dtype=hidden_states.dtype ) shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] hidden_states = self.proj_out_2(hidden_states) # unpatchify height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) if not return_dict: return (output,) return TransformerMV2DModelOutput(sample=output) @maybe_allow_in_graph class BasicMVTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", final_dropout: bool = False, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) # Define 3 blocks. Each block has its own normalization layer. # 1. 
Self-Attn if self.use_ada_layer_norm: self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_zero: self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.multiview_attention = multiview_attention self.sparse_mv_attention = sparse_mv_attention self.mvcd_attention = mvcd_attention self.attn1 = CustomAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, processor=MVAttnProcessor() ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. self.norm2 = ( AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) ) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None # 3. 
Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 self.num_views = num_views self.cd_attention_last = cd_attention_last if self.cd_attention_last: # Joint task -Attn self.attn_joint_last = CustomJointAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, processor=JointAttnProcessor() ) nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data) self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) self.cd_attention_mid = cd_attention_mid if self.cd_attention_mid: # print("cross-domain attn in the middle") # Joint task -Attn self.attn_joint_mid = CustomJointAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, processor=JointAttnProcessor() ) nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data) self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): assert attention_mask is None # not 
supported yet # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) else: norm_hidden_states = self.norm1(hidden_states) cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, num_views=self.num_views, multiview_attention=self.multiview_attention, sparse_mv_attention=self.sparse_mv_attention, mvcd_attention=self.mvcd_attention, **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states # joint attention twice if self.cd_attention_mid: norm_hidden_states = ( self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states) ) hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states # 2. Cross-Attention if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 3. 
Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: raise ValueError( f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], dim=self._chunk_dim, ) else: ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states if self.cd_attention_last: norm_hidden_states = ( self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states) ) hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states return hidden_states class CustomAttention(Attention): def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, *args, **kwargs ): processor = XFormersMVAttnProcessor() self.set_processor(processor) # print("using xformers attention processor") class CustomJointAttention(Attention): def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, *args, **kwargs ): processor = XFormersJointAttnProcessor() self.set_processor(processor) # print("using xformers attention processor") class MVAttnProcessor: r""" Default processor for performing attention-related computations. 
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        # Accepted for parity with XFormersMVAttnProcessor, because
        # BasicMVTransformerBlock.forward passes both keywords to attn1.
        # This processor does not implement either option.
        sparse_mv_attention=False,
        mvcd_attention=False,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # print('query', query.shape, 'key', key.shape, 'value', value.shape)
        # ([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
        # pdb.set_trace()
        # multi-view self-attention
        if multiview_attention:
            key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
            value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class XFormersMVAttnProcessor:
    r"""
    Processor for multi-view attention using xFormers memory-efficient attention.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,  # was the float literal `1.`; `rearrange(..., t=num_views)` needs an int
        multiview_attention=True,
        sparse_mv_attention=False,
        mvcd_attention=False,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # from yuancheng; here attention_mask is None
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key_raw = attn.to_k(encoder_hidden_states) value_raw = attn.to_v(encoder_hidden_states) # print('query', query.shape, 'key', key.shape, 'value', value.shape) #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) # pdb.set_trace() # multi-view self-attention if multiview_attention: if not sparse_mv_attention: key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views) else: key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c] value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c value = torch.cat([value_front, value_raw], dim=1) else: # print("don't use multiview attention.") key = key_raw value = value_raw query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return 
hidden_states class XFormersJointAttnProcessor: r""" Default processor for performing attention-related computations. """ def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, num_tasks=2 ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # from yuancheng; here attention_mask is None if attention_mask is not None: # expand our mask's singleton query_tokens dimension: # [batch*heads, 1, key_tokens] -> # [batch*heads, query_tokens, key_tokens] # so that it can be added as a bias onto the attention scores that xformers computes: # [batch*heads, query_tokens, key_tokens] # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
_, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) assert num_tasks == 2 # only support two tasks now key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c value_0, value_1 = torch.chunk(value, dim=0, chunks=2) key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c value = torch.cat([value]*2, dim=0) # (2 b t) 2d c query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class JointAttnProcessor: r""" Default processor for performing attention-related computations. 
""" def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, num_tasks=2 ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) assert num_tasks == 2 # only support two tasks now key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c value_0, value_1 = torch.chunk(value, dim=0, chunks=2) key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c value = torch.cat([value]*2, dim=0) # (2 b t) 2d c query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if 
attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states ================================================ FILE: core/models/unet_mv2d_blocks.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn from diffusers.utils import is_torch_version, logging from diffusers.models.attention import AdaGroupNorm from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from diffusers.models.dual_transformer_2d import DualTransformer2DModel from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_mv2d import TransformerMV2DModel from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D 
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class IdentityMLP(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear = nn.Linear(size, size)
        self.init_identity()

    def forward(self, x):
        return self.linear(x)

    def init_identity(self):
        # Initialize the weights to an identity matrix and biases to zero
        identity_matrix = torch.eye(self.linear.in_features)
        self.linear.weight.data.copy_(identity_matrix)
        self.linear.bias.data.zero_()


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    downsample_type=None,
    num_views=1,
    cd_attention_last: bool = False,
    cd_attention_mid: bool = False,
    multiview_attention: bool = True,
    sparse_mv_attention: bool = False,
    mvcd_attention: bool = False,
):
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is deprecated in the standard library; use `warning`.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
) attention_head_dim = num_attention_heads down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": return DownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "ResnetDownsampleBlock2D": return ResnetDownsampleBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": if add_downsample is False: downsample_type = None else: downsample_type = downsample_type or "conv" # default to 'conv' return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, 
cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) # custom MV2D attention block elif down_block_type == "CrossAttnDownBlockMV2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") return CrossAttnDownBlockMV2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") return SimpleCrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, 
only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnSkipDownBlock2D": return AttnSkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownEncoderBlock2D": return AttnDownEncoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "KDownBlock2D": return KDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) elif down_block_type == "KCrossAttnDownBlock2D": return KCrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, 
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # Self-attention is added only when the block does not downsample.
            add_self_attention=not add_downsample,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    upsample_type=None,
    num_views=1,
    cd_attention_last: bool = False,
    cd_attention_mid: bool = False,
    multiview_attention: bool = True,
    sparse_mv_attention: bool = False,
    mvcd_attention: bool = False,
):
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is deprecated in the standard library; use `warning`.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
) attention_head_dim = num_attention_heads up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": return UpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "ResnetUpsampleBlock2D": return ResnetUpsampleBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) # custom MV2D attention block elif up_block_type == "CrossAttnUpBlockMV2D": # if cross_attention_dim is None: # raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") return CrossAttnUpBlockMV2D( num_layers=num_layers, 
transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") return SimpleCrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": if add_upsample is False: upsample_type = None else: upsample_type = upsample_type or "conv" # default to 'conv' return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, 
resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnSkipUpBlock2D": return AttnSkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, ) elif up_block_type == "AttnUpDecoderBlock2D": return AttnUpDecoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) elif up_block_type == "KCrossAttnUpBlock2D": return KCrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, 
resnet_act_fn=resnet_act_fn, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, ) raise ValueError(f"{up_block_type} does not exist.") class UNetMidBlockMV2DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool=False ): super().__init__() self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] for _ in range(num_layers): if not dual_cross_attention: attentions.append( TransformerMV2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) ) else: raise 
NotImplementedError resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ray_embedding: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ray_embedding=ray_embedding, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) return hidden_states class CrossAttnUpBlockMV2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: 
bool=False ): super().__init__() resnets = [] attentions = [] mlps = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels mlps.append(IdentityMLP(res_skip_channels)) resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( TransformerMV2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) ) else: raise NotImplementedError self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.mlps = nn.ModuleList(mlps) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, 
encoder_attention_mask: Optional[torch.FloatTensor] = None, ray_embedding: Optional[torch.Tensor] = None, ): for resnet, attn, mlp in zip(self.resnets, self.attentions, self.mlps): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] B, _, H, W = res_hidden_states.shape res_hidden_states = res_hidden_states.permute(0, 2, 3, 1).reshape(B, H * W, _) res_hidden_states = mlp(res_hidden_states) res_hidden_states = res_hidden_states.reshape(B, H, W, _).permute(0, 3, 1, 2).contiguous() hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, None, # timestep None, # class_labels cross_attention_kwargs, attention_mask, encoder_attention_mask, ray_embedding, **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ray_embedding = ray_embedding, return_dict=False, )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class CrossAttnDownBlockMV2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, 
transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool=False ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( TransformerMV2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) ) else: raise NotImplementedError self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 
) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, additional_residuals=None, ): output_states = () blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, None, # timestep None, # class_labels cross_attention_kwargs, attention_mask, encoder_attention_mask, **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, 
output_states ================================================ FILE: core/models/unet_mv2d_condition.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import os import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput, logging from diffusers.models.activations import get_activation from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, UpBlock2D, ) from diffusers.utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, logging, ) 
from diffusers import __version__ from .unet_mv2d_blocks import ( CrossAttnDownBlockMV2D, CrossAttnUpBlockMV2D, UNetMidBlockMV2DCrossAttn, get_down_block, get_up_block, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNetMV2DConditionOutput(BaseOutput): """ The output of [`UNetMV2DConditionModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: torch.FloatTensor = None class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`): Block type for the middle of the UNet; it can be `UNetMidBlockMV2DCrossAttn`, `UNetMidBlock2DCrossAttn`, or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`): The tuple of upsample blocks to use. only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers are skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim`. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None`, `"text"`, `"text_image"`, `"text_time"`, `"image"`, or `"image_hint"`. `"text"` will use the `TextTimeEmbedding` layer. addition_time_embed_dim (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False` otherwise.
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: 
Optional[str] = None, addition_embed_type_num_heads=64, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. 
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1) self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act =
get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, 
cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) # custom MV2D attention block elif mid_block_type == "UNetMidBlockMV2DCrossAttn": self.mid_block = UNetMidBlockMV2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up ### FIXME #up_cross_attention_dim = (None, None, None, None) reversed_block_out_channels = list(reversed(block_out_channels)) 
reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, 
                mvcd_attention=mvcd_attention
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
""" count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNetMV2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. 
                Mask will be converted into a bias, which adds large negative values to the attention scores
                corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNetMV2DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`UNetMV2DConditionOutput`] or `tuple`:
                If `return_dict` is True, a [`UNetMV2DConditionOutput`] is returned, otherwise a `tuple` is returned
                where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g.
xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = 
            time_embeds.reshape((text_embeds.shape[0], -1))
            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            #
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4.
mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. 
post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNetMV2DConditionOutput(sample=sample) @classmethod def from_pretrained_2d( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 4, sample_size: int = 64, zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False, in_channels: int = 10, out_channels: int = 15, **kwargs ): r""" Instantiate a pretrained PyTorch model from a pretrained model configuration. The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To train the model, set it back in training mode with `model.train()`. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`~ModelMixin.save_pretrained`]. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. 
                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`.
                For more information about each option see
                [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if `device_map` contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`.
You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Example: ```py from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors`"
            )

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        # Check if we can handle device_map and dispatching the weights
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            user_agent=user_agent,
            **kwargs,
        )

        # modify config
        config["_class_name"] = cls.__name__
        config['in_channels'] = in_channels
        config['out_channels'] = out_channels
        config['sample_size'] = sample_size  # training resolution
        config['num_views'] = num_views
        config['cd_attention_last'] = cd_attention_last
        config['cd_attention_mid'] = cd_attention_mid
        config['multiview_attention'] = multiview_attention
        config['sparse_mv_attention'] = sparse_mv_attention
        config['mvcd_attention'] = mvcd_attention
        config["down_block_types"] = [
            "CrossAttnDownBlockMV2D",
            "CrossAttnDownBlockMV2D",
            "CrossAttnDownBlockMV2D",
            "DownBlock2D"
        ]
config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" config["up_block_types"] = [ "UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D" ] #config['class_embed_type'] = 'projection' if camera_embedding_type == 'e_de_da_sincos': config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 else: raise NotImplementedError # load model model_file = None if from_flax: raise NotImplementedError else: if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) except IOError as e: if not allow_pickle: raise e pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) import copy state_dict_v0 = load_state_dict(model_file, variant=variant) state_dict = copy.deepcopy(state_dict_v0) # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid for key in state_dict_v0: if 'attn_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp) if 'norm_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp) if 'attn_joint_twice.' 
in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp) if 'norm_joint_twice.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp) model._convert_deprecated_attention_blocks(state_dict) conv_in_weight = state_dict['conv_in.weight'] conv_out_weight = state_dict['conv_out.weight'] model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( model, state_dict, model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=True, ) if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): # initialize from the original SD structure model.conv_in.weight.data[:,:4] = conv_in_weight # whether to place all zero to new layers? if zero_init_conv_in: model.conv_in.weight.data[:,4:] = 0. if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): # initialize from the original SD structure model.conv_out.weight.data[-4:, ] = conv_out_weight # model.conv_out.weight.data[:,:4] = conv_out_weight # if out_channels == 8: # copy for the last 4 channels # model.conv_out.weight.data[:, 4:] = conv_out_weight if zero_init_camera_projection: for p in model.class_embedding.parameters(): torch.nn.init.zeros_(p) loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
) elif torch_dtype is not None: model = model.to(torch_dtype) model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: return model, loading_info return model @classmethod def _load_pretrained_model_2d( cls, model, state_dict, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes=False, ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() loaded_keys = list(state_dict.keys()) expected_keys = list(model_state_dict.keys()) original_loaded_keys = loaded_keys missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) # Make sure we are able to load base models as well as derived models (with heads) model_to_load = model def _find_mismatched_keys( state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: for checkpoint_key in loaded_keys: model_key = checkpoint_key if ( model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape ): mismatched_keys.append( (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) ) del state_dict[checkpoint_key] return mismatched_keys if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) if "size mismatch" in error_msg: error_msg += ( "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" " BertForPreTraining model).\n- This IS NOT expected if you are initializing" f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" " identical (initializing a BertForSequenceClassification model from a" " BertForSequenceClassification model)." ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" " without further training." 
) if len(mismatched_keys) > 0: mismatched_warning = "\n".join( [ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys ] ) logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" " able to use it for predictions and inference." ) return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs ================================================ FILE: core/models/unet_mv2d_condition_depth.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import os import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput, logging from diffusers.models.activations import get_activation from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, UpBlock2D, ) from diffusers.utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, logging, ) from diffusers import __version__ from .unet_mv2d_blocks import ( CrossAttnDownBlockMV2D, CrossAttnUpBlockMV2D, UNetMidBlockMV2DCrossAttn, get_down_block, get_up_block, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNetMV2DConditionOutput(BaseOutput): """ The output of [`UNetMV2DConditionModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
""" sample: torch.FloatTensor = None class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample-shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`): Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlockMV2DCrossAttn`, or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`): The tuple of upsample blocks to use. only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. 
layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers are skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim`. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 
class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. 
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: 
str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. 
`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1) self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = 
get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, 
cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) # custom MV2D attention block elif mid_block_type == "UNetMidBlockMV2DCrossAttn": self.mid_block = UNetMidBlockMV2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up ### FIXME #up_cross_attention_dim = (None, None, None, None) reversed_block_out_channels = list(reversed(block_out_channels)) 
reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, 
mvcd_attention=mvcd_attention ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model, indexed by weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. 
""" count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNetMV2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`UNetMV2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`UNetMV2DConditionOutput`] or `tuple`: If `return_dict` is True, a [`UNetMV2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be at least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = 
time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # 
Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. 
mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. 
post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNetMV2DConditionOutput(sample=sample) @classmethod def from_pretrained_2d( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 4, sample_size: int = 64, zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False, in_channels: int = 10, out_channels: int = 13, **kwargs ): r""" Instantiate a pretrained PyTorch model from a pretrained model configuration. The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To train the model, set it back in training mode with `model.train()`. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`~ModelMixin.save_pretrained`]. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. 
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if `device_map` contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. variant (`str`, *optional*): Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. 
You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Example: ```py from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors`" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True if device_map is not None and not is_accelerate_available(): raise NotImplementedError( "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" " `device_map=None`. You can install accelerate with `pip install accelerate`." ) # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } # load config config, unused_kwargs, commit_hash = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, device_map=device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) # modify config config["_class_name"] = cls.__name__ config['in_channels'] = in_channels config['out_channels'] = out_channels config['sample_size'] = sample_size # training resolution config['num_views'] = num_views config['cd_attention_last'] = cd_attention_last config['cd_attention_mid'] = cd_attention_mid config['multiview_attention'] = multiview_attention config['sparse_mv_attention'] = sparse_mv_attention config['mvcd_attention'] = mvcd_attention config["down_block_types"] = [ "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D" ] 
config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" config["up_block_types"] = [ "UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D" ] #config['class_embed_type'] = 'projection' if camera_embedding_type == 'e_de_da_sincos': config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 else: raise NotImplementedError # load model model_file = None if from_flax: raise NotImplementedError else: if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) except IOError as e: if not allow_pickle: raise e pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) import copy state_dict_v0 = load_state_dict(model_file, variant=variant) state_dict = copy.deepcopy(state_dict_v0) # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid for key in state_dict_v0: if 'attn_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp) if 'norm_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp) if 'attn_joint_twice.' 
in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp) if 'norm_joint_twice.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp) model._convert_deprecated_attention_blocks(state_dict) conv_in_weight = state_dict['conv_in.weight'] conv_out_weight = state_dict['conv_out.weight'] model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( model, state_dict, model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=True, ) if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): # initialize from the original SD structure model.conv_in.weight.data[:,:4] = conv_in_weight # whether to place all zero to new layers? if zero_init_conv_in: model.conv_in.weight.data[:,4:] = 0. if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): # initialize from the original SD structure model.conv_out.weight.data[-4:, ] = conv_out_weight # model.conv_out.weight.data[:,:4] = conv_out_weight # if out_channels == 8: # copy for the last 4 channels # model.conv_out.weight.data[:, 4:] = conv_out_weight if zero_init_camera_projection: for p in model.class_embedding.parameters(): torch.nn.init.zeros_(p) loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
) elif torch_dtype is not None: model = model.to(torch_dtype) model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: return model, loading_info return model @classmethod def _load_pretrained_model_2d( cls, model, state_dict, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes=False, ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() loaded_keys = list(state_dict.keys()) expected_keys = list(model_state_dict.keys()) original_loaded_keys = loaded_keys missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) # Make sure we are able to load base models as well as derived models (with heads) model_to_load = model def _find_mismatched_keys( state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: for checkpoint_key in loaded_keys: model_key = checkpoint_key if ( model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape ): mismatched_keys.append( (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) ) del state_dict[checkpoint_key] return mismatched_keys if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) if "size mismatch" in error_msg: error_msg += ( "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" " BertForPreTraining model).\n- This IS NOT expected if you are initializing" f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" " identical (initializing a BertForSequenceClassification model from a" " BertForSequenceClassification model)." ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" " without further training." 
) if len(mismatched_keys) > 0: mismatched_warning = "\n".join( [ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys ] ) logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" " able to use it for predictions and inference." ) return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs ================================================ FILE: core/models/unet_mv2d_condition_depth_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import os import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput from diffusers.models.activations import get_activation from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, UpBlock2D, ) from diffusers.utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, logging, ) from diffusers import __version__ from .unet_mv2d_blocks import ( CrossAttnDownBlockMV2D, CrossAttnUpBlockMV2D, UNetMidBlockMV2DCrossAttn, get_down_block, get_up_block, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name class IdentityMLP(nn.Module): def __init__(self, size): super().__init__() self.linear = nn.Linear(size, size) self.init_identity() def forward(self, x): return self.linear(x) def init_identity(self): # Initialize the weights to an identity matrix and biases to zero identity_matrix = torch.eye(self.linear.in_features) self.linear.weight.data.copy_(identity_matrix) self.linear.bias.data.zero_() @dataclass class UNetMV2DConditionOutput(BaseOutput): """ The output of [`UNetMV2DConditionModel`].
Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: torch.FloatTensor = None class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`): Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlockMV2DCrossAttn`, or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`): The tuple of upsample blocks to use. only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers are skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to `None`): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads.
If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. 
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False` otherwise. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, 
addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, num_views: int = 1, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. 
num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use # case, `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1) self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use # case, `addition_embed_type == "text_image"` (Kandinsky 2.1) self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = 
get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, 
cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) # custom MV2D attention block elif mid_block_type == "UNetMidBlockMV2DCrossAttn": self.mid_block = UNetMidBlockMV2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up ### FIXME #up_cross_attention_dim = (None, None, None, None) reversed_block_out_channels = list(reversed(block_out_channels)) 
reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, 
mvcd_attention=mvcd_attention ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model, indexed by weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. 
""" count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ray_embedding: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNetMV2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. 
                Mask will be converted into a bias, which adds large negative values to the attention scores
                corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNetMV2DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
                that are passed along to the UNet blocks.

        Returns:
            [`UNetMV2DConditionOutput`] or `tuple`:
                If `return_dict` is True, a [`UNetMV2DConditionOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g.
xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = 
            time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            #
Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. 
mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ray_embedding = ray_embedding, ) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ray_embedding = ray_embedding, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. 
post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNetMV2DConditionOutput(sample=sample) @classmethod def from_pretrained_2d( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 4, sample_size: int = 64, zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False, in_channels: int = 4, out_channels: int = 13, **kwargs ): r""" Instantiate a pretrained PyTorch model from a pretrained model configuration. The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To train the model, set it back in training mode with `model.train()`. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`~ModelMixin.save_pretrained`]. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. 
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if `device_map` contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. variant (`str`, *optional*): Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. 
You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Example: ```py from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True if device_map is not None and not is_accelerate_available(): raise NotImplementedError( "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" " `device_map=None`. You can install accelerate with `pip install accelerate`." ) # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } # load config config, unused_kwargs, commit_hash = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, device_map=device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) # modify config config["_class_name"] = cls.__name__ config['in_channels'] = in_channels config['out_channels'] = out_channels config['sample_size'] = sample_size # training resolution config['num_views'] = num_views config['cd_attention_last'] = cd_attention_last config['cd_attention_mid'] = cd_attention_mid config['multiview_attention'] = multiview_attention config['sparse_mv_attention'] = sparse_mv_attention config['mvcd_attention'] = mvcd_attention config["down_block_types"] = [ "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D" ] config['mid_block_type'] = 
"UNetMidBlockMV2DCrossAttn" config["up_block_types"] = [ "UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D" ] #config['class_embed_type'] = 'projection' if camera_embedding_type == 'e_de_da_sincos': config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 else: raise NotImplementedError # load model model_file = None if from_flax: raise NotImplementedError else: if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) except IOError as e: if not allow_pickle: raise e pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) import copy state_dict_v0 = load_state_dict(model_file, variant=variant) state_dict = copy.deepcopy(state_dict_v0) # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid for key in state_dict_v0: if 'attn_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp) if 'norm_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp) if 'attn_joint_twice.' 
in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp) if 'norm_joint_twice.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp) model._convert_deprecated_attention_blocks(state_dict) conv_in_weight = state_dict['conv_in.weight'] conv_out_weight = state_dict['conv_out.weight'] conv_out_bias = state_dict['conv_out.bias'] model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( model, state_dict, model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=True, ) # if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): # # initialize from the original SD structure # model.conv_in.weight.data[:,:4] = conv_in_weight # # whether to place all zero to new layers? # if zero_init_conv_in: # model.conv_in.weight.data[:,4:] = 0. if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): # initialize from the original SD structure model.conv_out.weight.data[-4:, ] = conv_out_weight model.conv_out.bias.data[-4:] = conv_out_bias # model.conv_out.weight.data[:,:4] = conv_out_weight # if out_channels == 8: # copy for the last 4 channels # model.conv_out.weight.data[:, 4:] = conv_out_weight if zero_init_camera_projection: for p in model.class_embedding.parameters(): torch.nn.init.zeros_(p) loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
) elif torch_dtype is not None: model = model.to(torch_dtype) model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: return model, loading_info return model @classmethod def _load_pretrained_model_2d( cls, model, state_dict, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes=False, ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() loaded_keys = list(state_dict.keys()) expected_keys = list(model_state_dict.keys()) original_loaded_keys = loaded_keys missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) # Make sure we are able to load base models as well as derived models (with heads) model_to_load = model def _find_mismatched_keys( state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: for checkpoint_key in loaded_keys: model_key = checkpoint_key if ( model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape ): mismatched_keys.append( (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) ) if 'proj_in' in checkpoint_key: state_dict[checkpoint_key] = torch.cat([state_dict[checkpoint_key], model_state_dict[checkpoint_key][:, -6:]], dim=1) else: del state_dict[checkpoint_key] return mismatched_keys if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) if "size mismatch" in error_msg: error_msg += ( "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" " BertForPreTraining model).\n- This IS NOT expected if you are initializing" f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" " identical (initializing a BertForSequenceClassification model from a" " BertForSequenceClassification model)." ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" " without further training." 
) if len(mismatched_keys) > 0: mismatched_warning = "\n".join( [ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys ] ) logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" " able to use it for predictions and inference." ) return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs ================================================ FILE: core/models/unet_mv2d_condition_depth_diffusion_test.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import os import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput, logging from diffusers.models.activations import get_activation from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, UpBlock2D, ) from diffusers.utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, logging, ) from diffusers import __version__ from .unet_mv2d_blocks import ( CrossAttnDownBlockMV2D, CrossAttnUpBlockMV2D, UNetMidBlockMV2DCrossAttn, get_down_block, get_up_block, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name class IdentityMLP(nn.Module): def __init__(self, size): super(IdentityMLP, self).__init__() self.linear = nn.Linear(size, size) self.init_identity() def forward(self, x): return self.linear(x) def init_identity(self): # Initialize the weights to an identity matrix and biases to zero identity_matrix = torch.eye(self.linear.in_features) self.linear.weight.data.copy_(identity_matrix) self.linear.bias.data.zero_() @dataclass class UNetMV2DConditionOutput(BaseOutput): """ The output of [`UNet2DConditionModel`]. 
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None


class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`):
            Block type for middle of UNet, it can be one of `UNetMidBlockMV2DCrossAttn`, `UNetMidBlock2DCrossAttn`, or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads.
            If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, 
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
        num_views: int = 1,
        cd_attention_last: bool = False,
        cd_attention_mid: bool = False,
        multiview_attention: bool = True,
        sparse_mv_attention: bool = False,
        mvcd_attention: bool = False
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
            # To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
            )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act =
get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, 
cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) # custom MV2D attention block elif mid_block_type == "UNetMidBlockMV2DCrossAttn": self.mid_block = UNetMidBlockMV2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, mvcd_attention=mvcd_attention ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up ### FIXME #up_cross_attention_dim = (None, None, None, None) reversed_block_out_channels = list(reversed(block_out_channels)) 
reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, num_views=num_views, cd_attention_last=cd_attention_last, cd_attention_mid=cd_attention_mid, multiview_attention=multiview_attention, sparse_mv_attention=sparse_mv_attention, 
                mvcd_attention=mvcd_attention
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )

            self.conv_act = get_activation(act_fn)

        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
""" count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ray_embedding: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNetMV2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`UNetMV2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`UNetMV2DConditionOutput`] or `tuple`: If `return_dict` is True, a [`UNetMV2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be at least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g.
xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = 
time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": #
Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4.
mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ray_embedding = ray_embedding, ) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ray_embedding = ray_embedding, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. 
post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNetMV2DConditionOutput(sample=sample) @classmethod def from_pretrained_2d( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 1, sample_size: int = 64, zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, cd_attention_mid: bool = False, multiview_attention: bool = True, sparse_mv_attention: bool = False, mvcd_attention: bool = False, in_channels: int = 4, out_channels: int = 13, **kwargs ): r""" Instantiate a pretrained PyTorch model from a pretrained model configuration. The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To train the model, set it back in training mode with `model.train()`. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`~ModelMixin.save_pretrained`]. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. 
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary mapping each device identifier to its maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if `device_map` contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. variant (`str`, *optional*): Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in with `huggingface-cli login`.
You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Example: ```py from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors`." ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True if device_map is not None and not is_accelerate_available(): raise NotImplementedError( "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" " `device_map=None`. You can install accelerate with `pip install accelerate`." ) # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } # load config config, unused_kwargs, commit_hash = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, device_map=device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) # modify config config["_class_name"] = cls.__name__ config['in_channels'] = in_channels config['out_channels'] = out_channels config['sample_size'] = sample_size # training resolution config['num_views'] = num_views config['cd_attention_last'] = cd_attention_last config['cd_attention_mid'] = cd_attention_mid config['multiview_attention'] = multiview_attention config['sparse_mv_attention'] = sparse_mv_attention config['mvcd_attention'] = mvcd_attention config["down_block_types"] = [ "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D" ] config['mid_block_type'] =
"UNetMidBlockMV2DCrossAttn" config["up_block_types"] = [ "UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D" ] #config['class_embed_type'] = 'projection' if camera_embedding_type == 'e_de_da_sincos': config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 else: raise NotImplementedError # load model model_file = None if from_flax: raise NotImplementedError else: if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) except IOError as e: if not allow_pickle: raise e pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) import copy state_dict_v0 = load_state_dict(model_file, variant=variant) state_dict = copy.deepcopy(state_dict_v0) # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid for key in state_dict_v0: if 'attn_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp) if 'norm_joint.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp) if 'attn_joint_twice.' 
in key: tmp = copy.deepcopy(key) state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp) if 'norm_joint_twice.' in key: tmp = copy.deepcopy(key) state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp) model._convert_deprecated_attention_blocks(state_dict) conv_in_weight = state_dict['conv_in.weight'] conv_out_weight = state_dict['conv_out.weight'] conv_out_bias = state_dict['conv_out.bias'] model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( model, state_dict, model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=True, ) # if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): # # initialize from the original SD structure # model.conv_in.weight.data[:,:4] = conv_in_weight # # whether to place all zero to new layers? # if zero_init_conv_in: # model.conv_in.weight.data[:,4:] = 0. if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): # initialize from the original SD structure model.conv_out.weight.data[-4:, ] = conv_out_weight model.conv_out.bias.data[-4:] = conv_out_bias # model.conv_out.weight.data[:,:4] = conv_out_weight # if out_channels == 8: # copy for the last 4 channels # model.conv_out.weight.data[:, 4:] = conv_out_weight if zero_init_camera_projection: for p in model.class_embedding.parameters(): torch.nn.init.zeros_(p) loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
) elif torch_dtype is not None: model = model.to(torch_dtype) model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: return model, loading_info return model @classmethod def _load_pretrained_model_2d( cls, model, state_dict, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes=False, ): # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() loaded_keys = list(state_dict.keys()) expected_keys = list(model_state_dict.keys()) original_loaded_keys = loaded_keys missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) # Make sure we are able to load base models as well as derived models (with heads) model_to_load = model def _find_mismatched_keys( state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: for checkpoint_key in loaded_keys: model_key = checkpoint_key if ( model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape ): mismatched_keys.append( (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) ) if 'proj_in' in checkpoint_key: state_dict[checkpoint_key] = torch.cat([state_dict[checkpoint_key], model_state_dict[checkpoint_key][:, -6:]], dim=1) else: del state_dict[checkpoint_key] return mismatched_keys if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, ) error_msgs = _load_state_dict_into_model(model_to_load, state_dict) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) if "size mismatch" in error_msg: error_msg += ( "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" " BertForPreTraining model).\n- This IS NOT expected if you are initializing" f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" " identical (initializing a BertForSequenceClassification model from a" " BertForSequenceClassification model)." ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" " without further training." 
) if len(mismatched_keys) > 0: mismatched_warning = "\n".join( [ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys ] ) logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" " able to use it for predictions and inference." ) return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs ================================================ FILE: core/models_LGM_compos_diffusion.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import kiui from kiui.lpips import LPIPS from core.unet_LGM_compos import UNet from core.options_latents_diffusion import Options from core.gs import GaussianRenderer from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer from typing import Optional import random import torchvision.transforms.functional as TF IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) class LGM(nn.Module): def __init__( self, opt: Options, ): super().__init__() self.opt = opt # unet self.unet = UNet( 9, 14, down_channels=self.opt.down_channels, down_attention=self.opt.down_attention, mid_attention=self.opt.mid_attention, up_channels=self.opt.up_channels, up_attention=self.opt.up_attention, ) # last conv self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again # Gaussian Renderer self.gs = GaussianRenderer(opt) # activations... 
self.pos_act = lambda x: x.clamp(-1, 1) self.scale_act = lambda x: 0.1 * F.softplus(x) self.opacity_act = lambda x: torch.sigmoid(x) self.rot_act = lambda x: F.normalize(x, dim=-1) self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again # LPIPS loss if self.opt.lambda_lpips > 0: self.lpips_loss = LPIPS(net='vgg') self.lpips_loss.requires_grad_(False) model_key = opt.pretrained_model_name_or_path self.unet2 = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", low_cpu_mem_usage=False,device_map=None,ignore_mismatched_sizes=True) self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder") self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") self.scheduler = DDPMScheduler.from_pretrained(model_key, subfolder="scheduler") self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.opt.weight_dtype) self.vae.requires_grad_(False) self.unet2.requires_grad_(False) #self.tokenizer.requires_grad_(False) self.text_encoder.requires_grad_(False) def state_dict(self, **kwargs): # remove lpips_loss state_dict = super().state_dict(**kwargs) for k in list(state_dict.keys()): if 'lpips_loss' in k: del state_dict[k] return state_dict def prepare_default_rays(self, device, elevation=0): from kiui.cam import orbit_camera from core.utils import get_rays cam_poses = np.stack([ orbit_camera(elevation, 0, radius=self.opt.cam_radius), orbit_camera(elevation, 90, radius=self.opt.cam_radius), orbit_camera(elevation, 180, radius=self.opt.cam_radius), orbit_camera(elevation, 270, radius=self.opt.cam_radius), ], axis=0) # [4, 4, 4] cam_poses = torch.from_numpy(cam_poses) rays_embeddings = [] for i in range(cam_poses.shape[0]): rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] rays_embeddings.append(rays_plucker) ## visualize 
rays for plotting figure # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True) rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] return rays_embeddings def forward_gaussians(self, images, encoder_hidden_states, data): # images: [B, 4, 9, H, W] # return: Gaussians: [B, dim_t] B, V, C, H, W = images.shape images = images.view(B*V, C, H, W) timestep = data["timesteps"].flatten(0, 1) pred_noise, blocks_sample, temb= self.unet2(images, timestep, encoder_hidden_states, return_dict=False) pred_x0 = self.pred_x0(pred_noise, timestep, images) images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5 images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False) images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].flatten(0, 1).to(self.opt.weight_dtype) ], dim=1) x = self.unet(images_256, blocks_sample, temb) # [B*4, 14, h, w] x = self.conv(x) # [B*4, 14, h, w] x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size) ## visualize multi-view gaussian features for plotting figure # tmp_alpha = self.opacity_act(x[0, :, 3:4]) # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha) # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5 # kiui.vis.plot_image(tmp_img_rgb, save=True) # kiui.vis.plot_image(tmp_img_pos, save=True) x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14) pos = self.pos_act(x[..., 0:3]) # [B, N, 3] opacity = self.opacity_act(x[..., 3:4]) scale = self.scale_act(x[..., 4:7]) rotation = self.rot_act(x[..., 7:11]) rgbs = self.rgb_act(x[..., 11:]) gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14] return gaussians, images_512 def pred_x0( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta: float=0.0, verbose=False, ): """ predict the sampe the next 
step in the denoise process. """ alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x.device) alpha_prod_t = alphas_cumprod [timestep] B = alpha_prod_t.shape[0] alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1) beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 return pred_x0 def encode_prompt( self, prompt, device, prompt_embeds: Optional[torch.FloatTensor] = None, ): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape return prompt_embeds def compute_snr(self, timesteps): """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 """ 
        # SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]); computed below as (alpha / sigma) ** 2 with alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t])
alphas_cumprod = self.scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 # Expand the tensors. # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) # Compute SNR. snr = (alpha / sigma) ** 2 return snr def forward(self, data, step_ratio=1): # data: output of the dataloader # return: loss results = {} loss = 0 start_idx = None images = data['input'].to(self.opt.weight_dtype) # [B, 4, 9, h, W], input features num_views = images.shape[1] #ray_embedding = images[:, :, 4:] latents = images.flatten(0,1) latent = latents[:,:4] bsz, c, h, w = latent.shape # timesteps timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (bsz // num_views,), device=images.device) timesteps_pred = timesteps.repeat_interleave(self.opt.num_views) timesteps = timesteps.repeat_interleave(num_views) timesteps = timesteps.long() if(random.random() < 0.7): start_idx = torch.randint(0,4, (1,)).item() timesteps[start_idx ::num_views] = 0 timesteps_pred[start_idx ::self.opt.num_views] = 0 if(random.random() < 0.7): prompt = data["prompt"] # prompt = [prompt[i][j] for j in range(len(prompt[0])) for i in range(len(prompt))] # encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype) prompt = [prompt[0][i] for i in 
range(len(prompt[0]))] #print(prompt) encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype) encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1) encoder_hidden_states = encoder_hidden_states.flatten(0,1) else: prompt = ['']*images.shape[0] encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype) encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1) encoder_hidden_states = encoder_hidden_states.flatten(0,1) noise = torch.randn_like(latent).to(device=images.device) noisy_latents = self.scheduler.add_noise(latent, noise, timesteps).to(device=images.device) data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w) data['timesteps'] = timesteps.reshape(bsz // num_views, num_views) snr = self.compute_snr(timesteps_pred) mse_loss_weights = torch.stack([snr, self.opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] # use the first view to predict gaussians images = data['noisy_latents'] gaussians, noise_images = self.forward_gaussians(images, encoder_hidden_states, data) # [B, N, 14] results['gaussians'] = gaussians # always use white bg bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) # use the other views for rendering and supervision results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color) pred_images = results['image'].to(self.opt.weight_dtype) # [B, V, C, output_size, output_size] pred_alphas = results['alpha'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size] results['images_pred'] = pred_images results['alphas_pred'] = pred_alphas gt_images = data['images2_output'].to(self.opt.weight_dtype) # [B, V, 3, output_size, output_size], ground-truth novel views gt_masks = data['masks_output'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size], ground-truth masks gt_images = gt_images * gt_masks + 
bg_color.view(1, 1, 3, 1, 1).to(self.opt.weight_dtype) * (1 - gt_masks) loss_mse_image = F.mse_loss(pred_images.flatten(0,1), gt_images.flatten(0,1), reduction="none") loss_mse_alpha = F.mse_loss(pred_alphas.flatten(0,1), gt_masks.flatten(0,1), reduction="none") loss_mse_image = (loss_mse_image.mean(dim=list(range(1, len(loss_mse_image.shape)))) * mse_loss_weights).mean() loss_mse_alpha = (loss_mse_alpha.mean(dim=list(range(1, len(loss_mse_alpha.shape)))) * mse_loss_weights).mean() results['loss_mse_image'] = loss_mse_image results['loss_mse_alpha'] = loss_mse_alpha loss_mse = loss_mse_image + loss_mse_alpha results['loss_mse'] = loss_mse loss = loss + loss_mse results['gt_noise'] = noise_images.reshape(bsz // num_views, num_views, 3, 512, 512) if self.opt.lambda_lpips > 0 and step_ratio > 0: loss_lpips = self.lpips_loss( # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, # downsampled to at most 256 to reduce memory cost F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), ) lpips_loss_weights = torch.ones_like(mse_loss_weights) if start_idx is not None: lpips_loss_weights[start_idx::self.opt.num_views] = 5.0 loss_lpips = (loss_lpips.mean(dim=list(range(1, len(loss_lpips.shape)))) * lpips_loss_weights).mean() results['loss_lpips'] = loss_lpips #loss = loss + self.opt.lambda_lpips * (step_ratio-0.25) * loss_lpips loss = loss + self.opt.lambda_lpips * loss_lpips results['loss'] = loss # metric with torch.no_grad(): psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) results['psnr'] = psnr return results ================================================ FILE: 
core/models_LGM_compos_diffusion_validate_inversion_2_masa.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import kiui from kiui.lpips import LPIPS from core.unet_LGM_compos import UNet from core.options_latents_diffusion import Options from core.gs import GaussianRenderer from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DDIMScheduler from transformers import CLIPTextModel, CLIPTokenizer from typing import Optional import random import torchvision.transforms.functional as TF import tqdm from core.control import ControlNetPipeline from core.masactrl import MutualSelfAttention3DControl from core.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) class LGM(nn.Module): def __init__( self, opt: Options, ): super().__init__() self.opt = opt # unet self.unet = UNet( 9, 14, down_channels=self.opt.down_channels, down_attention=self.opt.down_attention, mid_attention=self.opt.mid_attention, up_channels=self.opt.up_channels, up_attention=self.opt.up_attention, ).to(self.opt.weight_dtype) # last conv self.conv = nn.Conv2d(14, 14, kernel_size=1).to(self.opt.weight_dtype) # NOTE: maybe remove it if train again # Gaussian Renderer self.gs = GaussianRenderer(opt) # activations... 
self.pos_act = lambda x: x.clamp(-1, 1) self.scale_act = lambda x: 0.1 * F.softplus(x) self.opacity_act = lambda x: torch.sigmoid(x) self.rot_act = lambda x: F.normalize(x, dim=-1) self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again # LPIPS loss if self.opt.lambda_lpips > 0: self.lpips_loss = LPIPS(net='vgg') self.lpips_loss.requires_grad_(False) model_key = opt.pretrained_model_name_or_path self.unet2 = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", low_cpu_mem_usage=False,device_map=None,ignore_mismatched_sizes=True).to(self.opt.weight_dtype) self.unet3 = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", low_cpu_mem_usage=False,device_map=None,ignore_mismatched_sizes=True).to(self.opt.weight_dtype) self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.opt.weight_dtype) self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") self.scheduler = DDPMScheduler.from_pretrained(model_key, subfolder="scheduler") self.scheduler2 = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") self.test_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) #self.pipe = MasaCtrlPipeline.from_pretrained(model_key, scheduler=self.test_scheduler) self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.opt.weight_dtype) self.vae.requires_grad_(False) self.unet2.requires_grad_(False) self.unet3.requires_grad_(False) #self.tokenizer.requires_grad_(False) self.text_encoder.requires_grad_(False) self.steps = 2 self.layer = 10 #self.masa_editor = MutualSelfAttention3DControl(step, layer, total_steps=30) self.base_editor = AttentionBase() def state_dict(self, **kwargs): # remove lpips_loss state_dict = super().state_dict(**kwargs) for k in list(state_dict.keys()): if 'lpips_loss' in k: del state_dict[k] return state_dict def 
prepare_default_rays(self, device, elevation=0, proj_matrix=None): from kiui.cam import orbit_camera from core.utils import get_rays cam_poses = np.stack([ orbit_camera(elevation, 0, radius=self.opt.cam_radius, opengl=True), orbit_camera(elevation, 90, radius=self.opt.cam_radius, opengl=True), orbit_camera(elevation, 180, radius=self.opt.cam_radius, opengl=True), orbit_camera(elevation, 270, radius=self.opt.cam_radius, opengl=True), ], axis=0) # [4, 4, 4] cam_poses = torch.from_numpy(cam_poses) rays_embeddings = [] for i in range(cam_poses.shape[0]): rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_ray_size, self.opt.input_ray_size, self.opt.fovy) # [h, w, 3] rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] rays_embeddings.append(rays_plucker) ## visualize rays for plotting figure # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True) rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] cam_poses[:, :3, 1:3] *= -1 cam_poses = cam_poses.to(device) cam_view = torch.inverse(cam_poses).transpose(1, 2) cam_view_proj = cam_view @ proj_matrix cam_pos = - cam_poses[:, :3, 3] return rays_embeddings, cam_view, cam_view_proj, cam_pos def prepare_default_rays_zero123(self, device, elevation=0, proj_matrix=None): from kiui.cam import orbit_camera from core.utils import get_rays cam_poses = np.stack([ orbit_camera(0, 0, radius=self.opt.cam_radius, opengl=True), orbit_camera(-10, 90, radius=self.opt.cam_radius, opengl=True), orbit_camera(-10, 210, radius=self.opt.cam_radius, opengl=True), orbit_camera(20, 270, radius=self.opt.cam_radius, opengl=True), ], axis=0) # [4, 4, 4] # cam_poses = np.stack([ # orbit_camera(0, 0, radius=self.opt.cam_radius, opengl=True), # orbit_camera(0, 120, radius=self.opt.cam_radius, opengl=True), # orbit_camera(0, 240, radius=self.opt.cam_radius, opengl=True), # orbit_camera(-30, 300, radius=self.opt.cam_radius, opengl=True), # ], axis=0) # [4, 4, 
4]
        cam_poses = torch.from_numpy(cam_poses)

        rays_embeddings = []
        for i in range(cam_poses.shape[0]):
            rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_ray_size, self.opt.input_ray_size, self.opt.fovy) # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
            rays_embeddings.append(rays_plucker)

            ## visualize rays for plotting figure
            # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)

        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]

        cam_poses[:, :3, 1:3] *= -1
        cam_poses = cam_poses.to(device)
        cam_view = torch.inverse(cam_poses).transpose(1, 2)
        cam_view_proj = cam_view @ proj_matrix
        cam_pos = - cam_poses[:, :3, 3]

        return rays_embeddings, cam_view, cam_view_proj, cam_pos

    def unet_step(
        self,
        model_output: torch.FloatTensor,
        timestep,
        x: torch.FloatTensor,
        eta: float=0.0,
        verbose=False,
    ):
        """
        Take one deterministic DDIM denoising step: compute the previous sample from the predicted noise.
        """
        prev_timestep = timestep - self.test_scheduler.config.num_train_timesteps // self.test_scheduler.num_inference_steps
        alpha_prod_t = self.test_scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.test_scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.test_scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
        return x_prev

    def forward_gaussians(self, images, encoder_hidden_states, data, uncon_encoder_hidden_states=None):
        # images: [B, 4, 9, H, W]
        # return: Gaussians: [B, dim_t]

        B, V, C, H, W = images.shape
        images = images.view(B*V, C, H, W)
        timestep = data["timesteps"].flatten(0, 1)
        pred_noise, blocks_sample, temb= self.unet2(images, timestep, encoder_hidden_states, return_dict=False)
        if uncon_encoder_hidden_states is not None:
            uncon_pred_noise, _, _= self.unet3(images, timestep,
uncon_encoder_hidden_states, return_dict=False)
            # classifier-free guidance with a fixed scale of 3
            pred_noise = uncon_pred_noise + 3 * (pred_noise - uncon_pred_noise)
        # print(3.5)

        if pred_noise.shape[0] == 5:
            # 4 views plus one extra latent, which is stepped separately and returned as masa_latent
            pred_x0 = self.pred_x0(pred_noise[:4], timestep[:4], images[:4])
            masa_latent = self.unet_step(pred_noise[4:,], timestep[4].item(), images[4:])
            temb = temb[:4]
            blocks_sample = [i[:4] for i in blocks_sample]
        else:
            pred_x0 = self.pred_x0(pred_noise, timestep, images)
            # no extra latent in the batch; define the name so the return below does not raise NameError
            masa_latent = None

        images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5
        images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False)
        images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].to(self.opt.weight_dtype) ], dim=1)

        x = self.unet(images_256, blocks_sample, temb) # [B*4, 14, h, w]
        x = self.conv(x) # [B*4, 14, h, w]

        x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)

        ## visualize multi-view gaussian features for plotting figure
        # tmp_alpha = self.opacity_act(x[0, :, 3:4])
        # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)
        # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5
        # kiui.vis.plot_image(tmp_img_rgb, save=True)
        # kiui.vis.plot_image(tmp_img_pos, save=True)

        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)

        pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:7])
        rotation = self.rot_act(x[..., 7:11])
        rgbs = self.rgb_act(x[..., 11:])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]

        return gaussians, images_512, masa_latent

    def pred_x0(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta: float=0.0,
        verbose=False,
    ):
        """
        Predict the clean sample x0 from the predicted noise at the given timesteps.
""" alphas_cumprod = self.test_scheduler.alphas_cumprod.to(device=x.device) alpha_prod_t = alphas_cumprod [timestep] B = alpha_prod_t.shape[0] alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1) beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 return pred_x0 def step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta: float=0.0, verbose=False, ): """ predict the sampe the next step in the denoise process. """ prev_timestep = timestep - self.test_scheduler.config.num_train_timesteps // self.test_scheduler.num_inference_steps prev_timestep[timestep==0] = 0 alphas_cumprod = self.test_scheduler.alphas_cumprod.to(device=x.device) alpha_prod_t = alphas_cumprod [timestep] #alpha_prod_t_prev = self.test_scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.test_scheduler.final_alpha_cumprod alpha_prod_t_prev = torch.where(prev_timestep >0, self.test_scheduler.alphas_cumprod[prev_timestep], self.test_scheduler.final_alpha_cumprod).to(device=x.device) B = alpha_prod_t.shape[0] alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1) alpha_prod_t_prev = alpha_prod_t_prev.view(B, 1, 1, 1) beta_prod_t = 1 - alpha_prod_t #pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_noise = (x - alpha_prod_t**0.5 * model_output) / beta_prod_t**0.5 pred_dir = (1 - alpha_prod_t_prev)**0.5 * pred_noise #x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir x_prev = alpha_prod_t_prev**0.5 * model_output + pred_dir return x_prev def encode_prompt( self, prompt, device, prompt_embeds: Optional[torch.FloatTensor] = None, ): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = 
text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape return prompt_embeds def compute_snr(self, timesteps): """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 """ alphas_cumprod = self.scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 # Expand the tensors. 
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) # Compute SNR. snr = (alpha / sigma) ** 2 return snr def forward(self, data, step_ratio=1): # data: output of the dataloader # return: loss results = {} loss = 0 images = data['input'].to(self.opt.weight_dtype) # [B, 4, 9, h, W], input features num_views = images.shape[1] #ray_embedding = images[:, :, 4:] latents = images.flatten(0,1) latent = latents[:,:4] bsz, c, h, w = latent.shape # timesteps timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (bsz // num_views,), device=images.device) timesteps_pred = timesteps.repeat_interleave(self.opt.num_views) timesteps = timesteps.repeat_interleave(num_views) timesteps = timesteps.long() if(random.random() < 0.7): timesteps[::num_views] = 0 timesteps_pred[::self.opt.num_views] = 0 if(random.random() < 0.7): prompt = data["prompt"] # prompt = [prompt[i][j] for j in range(len(prompt[0])) for i in range(len(prompt))] # encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype) prompt = [prompt[0][i] for i in range(len(prompt[0]))] #print(prompt) encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype) encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1) encoder_hidden_states = 
encoder_hidden_states.flatten(0,1)
        else:
            prompt = ['']*images.shape[0]
            encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)
            encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)
            encoder_hidden_states = encoder_hidden_states.flatten(0,1)

        noise = torch.randn_like(latent).to(device=images.device)
        noisy_latents = self.scheduler.add_noise(latent, noise, timesteps).to(device=images.device)

        data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w)
        data['timesteps'] = timesteps.reshape(bsz // num_views, num_views)

        snr = self.compute_snr(timesteps_pred)
        mse_loss_weights = torch.stack([snr, self.opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0]

        # use the first view to predict gaussians
        images = data['noisy_latents']
        # forward_gaussians in this file returns (gaussians, images_512, masa_latent); the third value is unused here
        gaussians, noise_images, _ = self.forward_gaussians(images, encoder_hidden_states, data) # [B, N, 14]

        results['gaussians'] = gaussians

        # always use white bg
        bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)

        # use the other views for rendering and supervision
        results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
        pred_images = results['image'].to(self.opt.weight_dtype) # [B, V, C, output_size, output_size]
        pred_alphas = results['alpha'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size]

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas

        gt_images = data['images2_output'].to(self.opt.weight_dtype) # [B, V, 3, output_size, output_size], ground-truth novel views
        gt_masks = data['masks_output'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size], ground-truth masks

        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1).to(self.opt.weight_dtype) * (1 - gt_masks)

        loss_mse_image = F.mse_loss(pred_images.flatten(0,1), gt_images.flatten(0,1), reduction="none")
        loss_mse_alpha = F.mse_loss(pred_alphas.flatten(0,1), gt_masks.flatten(0,1),
reduction="none") loss_mse_image = (loss_mse_image.mean(dim=list(range(1, len(loss_mse_image.shape)))) * mse_loss_weights).mean() loss_mse_alpha = (loss_mse_alpha.mean(dim=list(range(1, len(loss_mse_alpha.shape)))) * mse_loss_weights).mean() results['loss_mse_image'] = loss_mse_image results['loss_mse_alpha'] = loss_mse_alpha loss_mse = loss_mse_image + loss_mse_alpha results['loss_mse'] = loss_mse loss = loss + loss_mse results['gt_noise'] = noise_images.reshape(bsz // num_views, num_views, 3, 512, 512) if self.opt.lambda_lpips > 0 and step_ratio > 0: loss_lpips = self.lpips_loss( # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, # downsampled to at most 256 to reduce memory cost F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), ) lpips_loss_weights = torch.ones_like(mse_loss_weights) if timesteps[0] == 0: lpips_loss_weights[::self.opt.num_views] = 5.0 loss_lpips = (loss_lpips.mean(dim=list(range(1, len(loss_lpips.shape)))) * lpips_loss_weights).mean() results['loss_lpips'] = loss_lpips #loss = loss + self.opt.lambda_lpips * (step_ratio-0.25) * loss_lpips loss = loss + self.opt.lambda_lpips * loss_lpips results['loss'] = loss # metric with torch.no_grad(): psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) results['psnr'] = psnr return results def next_step( self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0., verbose=False ): """ Inverse sampling for DDIM Inversion """ if verbose: print("timestep: ", timestep) next_step = timestep timestep = min(timestep - self.test_scheduler.config.num_train_timesteps // self.test_scheduler.num_inference_steps, 999) alpha_prod_t = 
self.test_scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.test_scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.test_scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    # def invert(self, image, encoder_hidden_states):
    #     noisy_latent = image
    #     for i, t in enumerate(reversed(self.test_scheduler.timesteps)):
    #         model_inputs = noisy_latent
    #         noise_pred = self.unet2(model_inputs, t, encoder_hidden_states=encoder_hidden_states).sample
    #         noisy_latent, pred_x0 = self.next_step(noise_pred, t, noisy_latent)
    #         a = (self.vae.decode(pred_x0.detach()/ 0.18215).sample +1)*0.5
    #         b = a.clamp(0,1).float().reshape(8, 4, 3, 512, 512).detach().to(torch.float).cpu().numpy()
    #         c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3)
    #         kiui.write_image(f'{i}_2.jpg', c1)
    #     return noisy_latent

    @torch.no_grad()
    def image2latent(self, image):
        #DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # if type(image) is Image:
        #     image = np.array(image)
        #     image = torch.from_numpy(image).float() / 127.5 - 1
        #     image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        # input image density range [-1, 1]
        latents = self.vae.encode(image)['latent_dist'].mean
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def invert(
        self,
        image: torch.Tensor,
        prompt="",
        # num_inference_steps=50,
        # guidance_scale=7.5,
        # eta=0.0,
        # return_intermediates=False,
        **kwds):
        """
        Invert a real image into a noise map with deterministic DDIM inversion.
        """
        DEVICE = image.device
        batch_size = image.shape[0]
        # if isinstance(prompt, list):
        #     if batch_size == 1:
        #         image = image.expand(len(prompt), -1, -1, -1)
        # elif isinstance(prompt, str):
        #     if batch_size > 1:
        #         prompt = [prompt] * batch_size
        prompt = [prompt] * batch_size

        # text embeddings
        text_input =
self.tokenizer( prompt, padding="max_length", max_length=77, return_tensors="pt" ) text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] print("input text embeddings :", text_embeddings.shape) # define initial latents latents = self.image2latent(image) start_latents = latents # print(latents) # exit() # unconditional embedding for classifier free guidance # if guidance_scale > 1.: # max_length = text_input.input_ids.shape[-1] # unconditional_input = self.tokenizer( # [""] * batch_size, # padding="max_length", # max_length=77, # return_tensors="pt" # ) # unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] # text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0) # print("latents shape: ", latents.shape) # interative sampling #self.scheduler.set_timesteps(num_inference_steps) print("Valid timesteps: ", reversed(self.test_scheduler.timesteps)) # print("attributes: ", self.scheduler.__dict__) latents_list = [latents] pred_x0_list = [latents] for i, t in enumerate(reversed(self.test_scheduler.timesteps)): # if guidance_scale > 1.: # model_inputs = torch.cat([latents] * 2) model_inputs = latents # predict the noise noise_pred = self.unet3(model_inputs, t, encoder_hidden_states=text_embeddings).sample # if guidance_scale > 1.: # noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) # noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) # compute the previous noise sample x_t-1 -> x_t latents, pred_x0 = self.next_step(noise_pred, t, latents) # a = (self.vae.decode(pred_x0.detach()/ 0.18215).sample +1)*0.5 # b = a.clamp(0,1).float().reshape(8, 4, 3, 512, 512).detach().to(torch.float).cpu().numpy() # c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3) a = (self.vae.decode(pred_x0[:1].detach()/ 0.18215).sample +1)*0.5 b = a.clamp(0,1).float().reshape(1, 3, 512, 512).detach().to(torch.float).cpu().numpy() c1 = b.transpose(0, 2, 3, 1) 
            # debug output: save the decoded x0 estimate of the first sample at inversion step i
            kiui.write_image(f'{self.opt.workspace}/{i}_7.jpg', c1)
            latents_list.append(latents)
            pred_x0_list.append(pred_x0)

        # if return_intermediates:
        #     # return the intermediate latents during inversion
        #     # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
        #     return latents, latents_list
        return latents

    def validate(self, data, num_inference_steps=30, single_image=True):
        results = {}
        self.test_scheduler.set_timesteps(num_inference_steps)
        self.opt.weight_dtype = torch.bfloat16
        data['input'] = self.vae.encode(data['images2_output']*2 -1).latent_dist.mode().detach() *0.18215
        data['input'] = data['input'].unsqueeze(0)
        images = data['input'].to(self.opt.weight_dtype) # [B, 4, 9, h, W], input features
        self.masa_editor = MutualSelfAttention3DControl(self.steps, self.layer, total_steps=num_inference_steps)
        self.masa_editor.reset()
        regiter_attention_editor_diffusers(self.unet2, self.masa_editor)
        #self.test_scheduler = self.test_scheduler.to(images.device)
        num_views = images.shape[1]
        #ray_embedding = images[:, :, 4:]
        latents = images.flatten(0,1)
        latent = latents[:,:4]
        bsz, c, h, w = latent.shape
        gt_images = data['images2_output'].to(self.opt.weight_dtype)
        noise = torch.randn_like(latent).to(device=images.device)
        data['noisy_latents'] = noise.reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype)

        # unconditional (empty prompt) text embeddings, repeated per view
        prompt = ['']*images.shape[0]
        uncon_encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)
        uncon_encoder_hidden_states = uncon_encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)
        uncon_encoder_hidden_states = uncon_encoder_hidden_states.flatten(0,1)

        # conditional text embeddings; rows from index 4 onward fall back to the unconditional ones
        prompt = data["prompt"]*(images.shape[0])
        encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)
        encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)
        encoder_hidden_states = encoder_hidden_states.flatten(0,1)
        encoder_hidden_states[4:] = uncon_encoder_hidden_states[4:]

        img_latent =
self.vae.encode(gt_images*2 -1).latent_dist.mode().detach() *0.18215 img = (self.vae.decode(img_latent.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5 #img = gt_images data['noisy_latents'] = self.invert(img*2-1).reshape(bsz // num_views, num_views, c, h, w) # timesteps # timesteps = torch.ones((bsz // num_views,), device=images.device)* 481 # timesteps_pred = timesteps.repeat_interleave(self.opt.num_views) # timesteps = timesteps.repeat_interleave(num_views) # timesteps = timesteps.long() # # timesteps[::num_views] = 0 # # timesteps_pred[::self.opt.num_views] = 0 # # add noise # noise = torch.randn_like(latent).to(device=images.device) # noisy_latents = self.test_scheduler.add_noise(latent, noise, timesteps).to(device=images.device) # data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w) # data['timesteps'] = timesteps.reshape(bsz // num_views, num_views) if single_image is True: data['noisy_latents'][:, :1] = images[:, :1] for i, t in enumerate(self.test_scheduler.timesteps): print(i, t) # timesteps = torch.ones((bsz // num_views,), device=images.device)* t # timesteps_pred = timesteps.repeat_interleave(self.opt.num_views) # timesteps = timesteps.repeat_interleave(num_views) # timesteps = timesteps.long() # # timesteps[::num_views] = 0 # # timesteps_pred[::self.opt.num_views] = 0 # # add noise # noise = torch.randn_like(latent).to(device=images.device) # noisy_latents = self.test_scheduler.add_noise(latent, noise, timesteps).to(device=images.device) # data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w) timesteps = t.repeat(bsz // num_views) #timesteps_pred = timesteps.repeat_interleave(self.opt.num_views) timesteps = timesteps.repeat_interleave(num_views) timesteps = timesteps.long() #if(random.random() < 0.9): if single_image is True: timesteps[::num_views] = 0 #timesteps_pred[::self.opt.num_views] = 0 # add noise # noise = torch.randn_like(latent).to(device=images.device) # noisy_latents = 
self.scheduler.add_noise(latent, noise, timesteps).to(device=images.device) # data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w) data['timesteps'] = timesteps.reshape(bsz // num_views, num_views).to(device=images.device) timesteps_cpu = timesteps.reshape(bsz // num_views, num_views) ### FIXME #timesteps_pred = torch.cat([data["timesteps"], 300 * torch.ones(self.opt.batch_size, self.opt.num_views-data['timesteps'].shape[1]).long().to(timesteps.device)],dim=1).flatten(0,1) #snr = self.compute_snr(timesteps_pred) #mse_loss_weights = torch.stack([snr, opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] / snr #mse_loss_weights = torch.stack([snr, self.opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] # use the first view to predict gaussians # prompt = ['']*images.shape[0] # uncon_encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype) # uncon_encoder_hidden_states = uncon_encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1) # uncon_encoder_hidden_states = uncon_encoder_hidden_states.flatten(0,1) # uncon_encoder_hidden_states = None images = data['noisy_latents'] # img = (self.vae.decode(images.flatten(0,1).to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5 # b = img.unsqueeze(0).clamp(0,1).detach().to(torch.float).cpu().numpy() # c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3) # kiui.write_image(f'{self.opt.workspace}/{i}_1_noise.jpg', c1) gaussians, noise_images, masa_latent = self.forward_gaussians(images, encoder_hidden_states, data, uncon_encoder_hidden_states) # [B, N, 14] results['gaussians'] = gaussians bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) # use the other views for rendering and supervision results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color, scale_modifier=1) # pred_images = results['image'] # [B, V, C, output_size, output_size] 
pred_alphas = results['alpha'].to(self.opt.weight_dtype) pred_images = results['image'].to(self.opt.weight_dtype) #pred_images = pred_images + self.white_latent.to(pred_images.device)*(1-pred_alphas) #data['noisy_latents'] = self.step((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, timesteps_cpu.flatten(0, 1), data['noisy_latents'].flatten(0, 1)).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype) data['noisy_latents'] = torch.cat([self.step((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, timesteps_cpu[:,:4].flatten(0, 1), data['noisy_latents'][:,:4].flatten(0, 1)), masa_latent]).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype) # if t > 0: # #data['noisy_latents'] = self.step((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, timesteps_cpu.flatten(0, 1), noise).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype) # data['noisy_latents'] = self.scheduler.add_noise((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, noise, timesteps-1).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype) if single_image is True: data['noisy_latents'][:, :1] = images[:, :1] #a = (self.vae.decode(pred_images.detach().to(dtype=torch.bfloat16).flatten(0,1)/ 0.18215).sample +1)*0.5 b = pred_images.detach().to(torch.float).cpu().numpy() c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3) kiui.write_image(f'{self.opt.workspace}/{i}_2.jpg', c1) #a = (self.vae.decode(data['noisy_latents'].detach().flatten(0,1)/ 0.18215).sample +1)*0.5 b = noise_images.clamp(0,1).float().reshape(1, 4, 3, 512, 512).detach().to(torch.float).cpu().numpy() c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3) kiui.write_image(f'{self.opt.workspace}/{i}_2_noise.jpg', c1) return results, gaussians 
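
# A minimal sketch, not part of the model class above, of the deterministic DDIM
# inversion update that `next_step` applies: predict x0 from the current sample and
# the predicted noise, then re-noise it to the next (noisier) level.
# `ddim_next_step` and its scalar arguments are illustrative assumptions; the same
# arithmetic works element-wise on tensors.

```python
def ddim_next_step(x_t, eps, alpha_prod_t, alpha_prod_t_next):
    """One deterministic DDIM inversion step (eta = 0).

    x_t               : current sample
    eps               : noise predicted by the model at the current step
    alpha_prod_t      : cumulative alpha product at the current step
    alpha_prod_t_next : cumulative alpha product at the next, noisier step
    """
    beta_prod_t = 1 - alpha_prod_t
    # Estimate of the clean sample implied by x_t and eps.
    pred_x0 = (x_t - beta_prod_t ** 0.5 * eps) / alpha_prod_t ** 0.5
    # Direction pointing towards the next noise level.
    pred_dir = (1 - alpha_prod_t_next) ** 0.5 * eps
    x_next = alpha_prod_t_next ** 0.5 * pred_x0 + pred_dir
    return x_next, pred_x0


if __name__ == "__main__":
    # Hypothetical scalar example: build x_t from a known x0 and noise,
    # then check that the step recovers x0 and moves to the next noise level.
    x0, eps = 0.3, -1.2
    a_t, a_next = 0.9, 0.5
    x_t = a_t ** 0.5 * x0 + (1 - a_t) ** 0.5 * eps
    x_next, pred_x0 = ddim_next_step(x_t, eps, a_t, a_next)
    print(pred_x0, x_next)
```

# When the predicted noise is exact, pred_x0 equals the clean sample, so the step is
# just a re-noising of x0 at the next level; in practice the noise prediction is
# approximate and the inversion error accumulates over steps.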
================================================ FILE: core/options_latents_diffusion.py ================================================ import tyro from dataclasses import dataclass from typing import Tuple, Literal, Dict, Optional @dataclass class Options: ### model # Unet image input size gradient_checkpointing: bool = False enable_xformers_memory_efficient_attention: bool = False pretrained_model_name_or_path: str = "/remote-home1/yeyang/aigc/model/stable-diffusion-v1-5" input_size: int = 256 input_ray_size: int = 256 # Unet definition down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) mid_attention: bool = True up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) up_attention: Tuple[bool, ...] = (True, True, True, False) # Unet output size, dependent on the input_size and U-Net structure! splat_size: int = 64 # gaussian render size output_size: int = 256 ### dataset # data mode (only support s3 now) data_mode: Literal['s3'] = 's3' data_path: str = '/remote-home1/yeyang/aigc/dataset2' json_path:str = '/remote-home1/yeyang/aigc/dataset1' # fovy of the dataset fovy: float = 39.6 # camera near plane znear: float = 0.01 # camera far plane zfar: float = 1000 # number of all views (input + output) num_views: int = 12 # number of views num_input_views: int = 4 # camera radius cam_radius: float = 1.5 # to better use [-1, 1]^3 space # num workers num_workers: int = 16 snr_gamma: int = 5 ### training # workspace workspace: str = './workspace' workspace1: Optional[str] = None # resume resume: Optional[str] = None # batch size (per-GPU) batch_size: int = 8 # gradient accumulation gradient_accumulation_steps: int = 1 # training epochs num_epochs: int = 30 # lpips loss weight lambda_lpips: float = 1.0 ##TZY # gradient clip gradient_clip: float = 1.0 # mixed precision mixed_precision: str = 'bf16' # learning rate lr: float = 5e-5 # augmentation prob for grid distortion 
prob_grid_distortion: float = 0.5 # augmentation prob for camera jitter prob_cam_jitter: float = 0.5 ### testing # test image path test_path: Optional[str] = None ### misc # nvdiffrast backend setting force_cuda_rast: bool = False # render fancy video with gaussian scaling effect fancy_video: bool = False checkpoints_total_limit: int = 3 # all the default settings config_defaults: Dict[str, Options] = {} config_doc: Dict[str, str] = {} # config_doc['lrm'] = 'the default settings for LGM' # config_defaults['lrm'] = Options() config_doc['small'] = 'small model with lower resolution Gaussians' config_defaults['small'] = Options( input_size=256, splat_size=64, output_size=256, batch_size=8, gradient_accumulation_steps=1, mixed_precision='bf16', ) config_doc['big'] = 'big model with higher resolution Gaussians' config_defaults['big'] = Options( input_size=64, up_channels=(1024, 1024, 512, 256, 128), # one more decoder up_attention=(True, True, True, False, False), splat_size=32, output_size=64, # render & supervise Gaussians at a higher resolution. batch_size=96, num_views=10, gradient_accumulation_steps=1, mixed_precision='bf16', ) config_doc['big_latent'] = 'big model with higher resolution Gaussians' config_defaults['big_latent'] = Options( input_size=64, down_channels=(256, 512, 1024, 1024), down_attention=(True, True, True, False), up_channels=(1024, 1024, 512, 256), up_attention=(False, True, True, True), splat_size=64, output_size=64, # render & supervise Gaussians at a higher resolution. 
batch_size = 2, # 2 num_views= 8, gradient_accumulation_steps= 6, # 16 mixed_precision='bf16', ) config_doc['big_latent_sd'] = 'big model with higher resolution Gaussians' config_defaults['big_latent_sd'] = Options( gradient_checkpointing = True, enable_xformers_memory_efficient_attention = True, lr = 1e-4, #lambda_lpips = 0.5, lambda_lpips = 2, input_size=64, down_channels=(256, 512, 1024, 1024), down_attention=(True, True, True, False), up_channels=(1024, 1024, 512, 256), up_attention=(False, True, True, True), splat_size=64, output_size=64, # render & supervise Gaussians at a higher resolution. batch_size = 2, # 2 num_views= 8, gradient_accumulation_steps= 6, # 16 mixed_precision='bf16', ) config_defaults['big_latent_sd_diffusion'] = Options( gradient_checkpointing = True, enable_xformers_memory_efficient_attention = True, lr = 1e-4, lambda_lpips = 0.5, #lambda_lpips = 2, input_size=64, down_channels=(256, 512, 1024, 1024), down_attention=(True, True, True, False), up_channels=(1024, 1024, 512, 256), up_attention=(False, True, True, True), splat_size=64, output_size=64, # render & supervise Gaussians at a higher resolution. batch_size = 2, # 2 num_views= 8, gradient_accumulation_steps= 2, # 16 mixed_precision='bf16', num_epochs = 50, ) config_defaults['big_latent_sd_diffusion_insert'] = Options( gradient_checkpointing = True, enable_xformers_memory_efficient_attention = True, lr = 1e-4, lambda_lpips = 0.5, #lambda_lpips = 2, input_size=64, #resume= "/remote-home1/yeyang/aigc/models/models--ashawkey--LGM/snapshots/1c28a2fd3bb1982414f722503ae862bdbb82636c/model_fp16_fixrot.safetensors", resume= 'workspace_1e-4_latent_diffusion_unet_LGM_insert3/model.safetensors', up_channels=(1024, 1024, 512, 256, 128), up_attention=(True, True, True, False, False), splat_size=128, output_size= 512, # render & supervise Gaussians at a higher resolution. 
batch_size = 8, # 2 num_views= 8, gradient_accumulation_steps= 1, # 16 mixed_precision='bf16', ) config_defaults['big_latent_sd_diffusion_compose'] = Options( gradient_checkpointing = True, enable_xformers_memory_efficient_attention = True, lr = 1e-4, lambda_lpips = 0.5, #lambda_lpips = 2, input_size=64, resume= "/remote-home1/yeyang/aigc/models/models--ashawkey--LGM/snapshots/1c28a2fd3bb1982414f722503ae862bdbb82636c/model_fp16_fixrot.safetensors", #resume= 'workspace_1e-4_latent_diffusion_unet_LGM_compose_text/model.safetensors', up_channels=(1024, 1024, 512, 256, 128), up_attention=(True, True, True, False, False), splat_size=128, output_size= 512, # render & supervise Gaussians at a higher resolution. batch_size = 8, # 2 num_views= 8, gradient_accumulation_steps= 1, # 16 mixed_precision='bf16', ) config_doc['big_latent_lpips'] = 'big model with higher resolution Gaussians' config_defaults['big_latent_lpips'] = Options( input_size=64, down_channels=(256, 512, 1024, 1024), down_attention=(True, True, True, False), up_channels=(1024, 1024, 512, 256), up_attention=(False, True, True, True), splat_size=64, output_size=64, # render & supervise Gaussians at a higher resolution. batch_size=6, # 2 num_views=10, gradient_accumulation_steps=16, # 16 mixed_precision='bf16', ) # config_doc['big_latent'] = 'big model with higher resolution Gaussians' # config_defaults['big_latent'] = Options( # input_size=64, # down_channels=(256, 512, 1024, 1024), # down_attention=(True, True, True, False), # up_channels=(1024, 1024, 512, 256), # up_attention=(False, True, True, True), # splat_size=64, # output_size=64, # render & supervise Gaussians at a higher resolution. 
#     batch_size=15, # 2
#     num_views=10,
#     gradient_accumulation_steps=4, # 16
#     mixed_precision='bf16',
# )

config_doc['tiny'] = 'tiny model for ablation'
config_defaults['tiny'] = Options(
    input_size=256,
    down_channels=(32, 64, 128, 256, 512),
    down_attention=(False, False, False, False, True),
    up_channels=(512, 256, 128),
    up_attention=(True, False, False, False),
    splat_size=64,
    output_size=256,
    batch_size=16,
    num_views=8,
    gradient_accumulation_steps=1,
    mixed_precision='bf16',
)

AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)

================================================
FILE: core/provider_Gobjaverse_latent_diffusion_insert.py
================================================
import os
import cv2
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import json
import kiui
# Options and AllConfigs both live in core/options_latents_diffusion.py;
# this repository has no core/options.py module.
from core.options_latents_diffusion import Options, AllConfigs
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
import tyro
# import debugpy; debugpy.connect(("localhost", 5677))

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class GobjaverseDataset(Dataset):

    def _warn(self):
        raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! 
(search keyword TODO)') def __init__(self, opt: Options, training=True): self.opt = opt self.training = training # TODO: remove this barrier # self._warn() # TODO: load the list of objects for training self.items = [] with open('/remote-home1/yeyang/aigc/gobj_merged.json', 'r') as f: self.items = json.load(f) with open('/remote-home1/yeyang/aigc/text_captions_cap3d.json', 'r') as cap: self.captions = json.load(cap) # naive split if self.training: self.items = self.items[:-self.opt.batch_size] else: self.items = self.items[-self.opt.batch_size:] #self.items = self.items[:self.opt.batch_size] # default camera intrinsics self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) self.proj_matrix[0, 0] = 1 / self.tan_half_fov self.proj_matrix[1, 1] = 1 / self.tan_half_fov self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) self.proj_matrix[2, 3] = 1 def __len__(self): return len(self.items) #return 250 def __getitem__(self, idx): uid = self.items[idx] results = {} results["prompt"] = [self.captions[uid]] *self.opt.num_input_views # load num_views images images = [] images2 = [] masks = [] cam_poses = [] vid_cnt = 0 # TODO: choose views, based on your rendering settings if self.training: # input views are in (36, 72), other views are randomly selected #input = np.random.permutation(np.arange(27, 39))[:self.opt.num_input_views].tolist() input_1 = np.random.permutation(np.arange(27, 30))[:1].tolist() input_2 = np.random.permutation(np.arange(30, 33))[:1].tolist() input_3 = np.random.permutation(np.arange(33, 36))[:1].tolist() input_4 = np.random.permutation(np.arange(36, 39))[:1].tolist() render = np.random.permutation(np.append(np.arange(1, 25), np.arange(27, 39))).tolist() #vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + 
            #     np.random.permutation(100).tolist()
            vids = input_1 + input_2 + input_3 + input_4 + render
        else:
            # fixed views
            vids = np.arange(27, 39, 4).tolist() + np.arange(1, 39).tolist()
            #vids = [27, 30, 33, 36] + np.random.permutation(np.append(np.arange(1, 25), np.arange(27, 39))).tolist()
            #vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()

        for vid in vids:
            #if not os.path.exists(os.path.join(self.opt.data_path, uid, f'{vid:05d}', f'{vid:05d}.pt')):
            #uid = "1/15039"
            image_path = os.path.join(self.opt.data_path, uid, f'{vid:05d}', f'{vid:05d}.pt')
            #mask_path = os.path.join(self.opt.data_path, uid, f'{vid:05d}', f'{vid:05d}_mask.pt')
            camera_path = os.path.join(self.opt.json_path, uid, f'{vid:05d}', f'{vid:05d}.json')
            image2_path = os.path.join(self.opt.json_path, uid, f'{vid:05d}', f'{vid:05d}.png')

            try:
                # TODO: load data (modify self.client here)
                image2 = torch.from_numpy(cv2.imread(image2_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255)
                image = torch.load(image_path)
                #mask = torch.load(mask_path)
                with open(camera_path, 'r', encoding='utf8') as f:
                    meta = json.load(f)
            except Exception as e:
                # skip views that fail to load; missing views are padded later
                print(f'[WARN] dataset {uid} {vid}: {e}')
                continue

            # TODO: you may have a different camera system
            # blender world + opencv cam --> opengl world & cam
            c2w = np.eye(4)
            c2w[:3, 0] = np.array(meta['x'])
            c2w[:3, 1] = np.array(meta['y'])
            c2w[:3, 2] = np.array(meta['z'])
            c2w[:3, 3] = np.array(meta['origin'])
            c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)

            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1 # invert up and forward direction

            # scale up radius to fully use the [-1, 1]^3 space!
#c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale image2 = image2.permute(2, 0, 1) # [4, 512, 512] mask2 = image2[3:4] # [1, 512, 512] image2 = image2[:3] * mask2 + (1 - mask2) # [3, 512, 512], to white bg image2 = image2[[2,1,0]].contiguous() # bgr to rgb images.append(image.squeeze(0).float()* 0.18215) images2.append(image2) masks.append(mask2.squeeze(0)) #masks.append(mask.squeeze(0).squeeze(0).to(image.dtype)) cam_poses.append(c2w) vid_cnt += 1 if vid_cnt == self.opt.num_views: break if vid_cnt < self.opt.num_views: print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!') n = self.opt.num_views - vid_cnt images = images + [images[-1]] * n images2 = images2 + [images2[-1]] * n masks = masks + [masks[-1]] * n cam_poses = cam_poses + [cam_poses[-1]] * n images = torch.stack(images, dim=0) # [V, C, H, W] images2 = torch.stack(images2, dim=0) # [V, C, H, W] masks = torch.stack(masks, dim=0) # [V, H, W] # images = torch.randn(self.opt.num_views, 4, 64, 64).to(images.device) # masks = torch.randn(self.opt.num_views, 64, 64).to(masks.device) cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4] radius = torch.norm(cam_poses[0, :3, 3]) cam_poses[:, :3, 3] *= self.opt.cam_radius / radius # normalized camera feats as in paper (transform the first pose to a fixed position) transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0]) cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4] images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W] cam_poses_input = cam_poses[:self.opt.num_input_views].clone() # data augmentation # if self.training: # # apply random grid distortion to simulate 3D inconsistency # if random.random() < self.opt.prob_grid_distortion: # images_input[1:] = grid_distortion(images_input[1:]) # # apply 
camera jittering (only to input!) # if random.random() < self.opt.prob_cam_jitter: # cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:]) # images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) # resize render ground-truth images, range still in [0, 1] results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size] results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(512, 512), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size] results['images2_output'] = F.interpolate(images2, size=(512, 512), mode='bilinear', align_corners=False) # [V, C, output_size, output_size] # build rays for input views rays_embeddings = [] for i in range(self.opt.num_input_views): rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_ray_size, self.opt.input_ray_size, self.opt.fovy) # [h, w, 3] rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] rays_embeddings.append(rays_plucker) rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w] #final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W] #results['input'] = final_input results['input'] = images_input results['ray'] = rays_embeddings # opengl to colmap camera for gaussian renderer cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction # cameras needed by gaussian rasterizer cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] cam_pos = - cam_poses[:, :3, 3] # [V, 3] results['cam_view'] = cam_view results['cam_view_proj'] = cam_view_proj results['cam_pos'] = cam_pos return results if __name__=="__main__": opt = tyro.cli(AllConfigs) GobjaverseDataset(opt, training=True) ================================================ FILE: core/unet_LGM_compos.py 
================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Tuple, Literal from functools import partial from core.attention import MemEffAttention, MemEffCrossAttention class MVAttention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, groups: int = 32, eps: float = 1e-5, residual: bool = True, skip_scale: float = 1, num_frames: int = 4, # WARN: hardcoded! ): super().__init__() self.residual = residual self.skip_scale = skip_scale self.num_frames = num_frames self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) def forward(self, x): # x: [B*V, C, H, W] BV, C, H, W = x.shape B = BV // self.num_frames # assert BV % self.num_frames == 0 res = x x = self.norm(x) x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C) x = self.attn(x) x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W) if self.residual: x = (x + res) * self.skip_scale return x class UnetAttention(nn.Module): def __init__( self, dim: int, dim_kv: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, groups: int = 32, eps: float = 1e-5, residual: bool = True, #skip_scale: float = 1, num_frames: int = 4, # WARN: hardcoded! 
): super().__init__() self.residual = residual self.skip_scale = 1 self.num_frames = num_frames self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) self.attn = MemEffCrossAttention(dim, dim, dim_kv, dim_kv, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) self.post_init() def post_init(self): nn.init.zeros_(self.attn.proj.weight.data) nn.init.zeros_(self.attn.proj.bias.data) def forward(self, x, unet_x): # x: [B*V, C, H, W] BV, C, H, W = x.shape #B = BV // self.num_frames # assert BV % self.num_frames == 0 res = x x = self.norm(x) x = x.permute(0, 2, 3, 1).reshape(BV, -1, C) unet_x = unet_x.permute(0, 2, 3, 1).reshape(BV, H*W, -1) x = self.attn(x, unet_x, unet_x) x = x.reshape(BV, H, W, C).permute(0, 3, 1, 2) if self.residual: x = (x + res) * self.skip_scale return x class ResnetBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, resample: Literal['default', 'up', 'down'] = 'default', groups: int = 32, eps: float = 1e-5, skip_scale: float = 1, # multiplied to output temb_channels: int = 1280, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.skip_scale = skip_scale self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.act = F.silu self.resample = None if resample == 'up': self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") elif resample == 'down': self.resample = nn.AvgPool2d(kernel_size=2, stride=2) self.shortcut = nn.Identity() if self.in_channels != self.out_channels: self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True) self.time_emb_proj = nn.Linear(temb_channels, out_channels) self.nolinearity = F.silu def 
post_init(self): nn.init.zeros_(self.time_emb_proj.weight.data) nn.init.zeros_(self.time_emb_proj.bias.data) def forward(self, x, temb=None): res = x x = self.norm1(x) x = self.act(x) if self.resample: res = self.resample(res) x = self.resample(x) x = self.conv1(x) if temb is not None: temb = self.nolinearity(temb) temb = self.time_emb_proj(temb)[:, :, None, None] x = x + temb x = self.norm2(x) x = self.act(x) x = self.conv2(x) x = (x + self.shortcut(res)) * self.skip_scale return x class DownBlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, unet_out_channels: int, unet_out_next_channels: int, num_layers: int = 1, downsample: bool = True, attention: bool = True, unet_attention: bool = False, attention_heads: int = 16, skip_scale: float = 1, ): super().__init__() nets = [] attns = [] unet_attns = [] self.unet_attention = unet_attention for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale)) if attention: attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) else: attns.append(None) if unet_attention: unet_attns.append(UnetAttention(out_channels, unet_out_channels)) else: unet_attns.append(None) if unet_attention and downsample: self.down_unet_attns = UnetAttention(out_channels, unet_out_next_channels) self.nets = nn.ModuleList(nets) self.attns = nn.ModuleList(attns) self.unet_attns = nn.ModuleList(unet_attns) self.downsample = None if downsample: self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) def forward(self, x, unet_xs=None, temb=None): xs = [] for attn, unet_attn, net in zip(self.attns, self.unet_attns, self.nets): x = net(x, temb) if attn: x = attn(x) if unet_attn: unet_x = unet_xs[0] unet_xs = unet_xs[1:] x = unet_attn(x, unet_x) xs.append(x) if self.downsample: x = self.downsample(x) if unet_attn: unet_x = unet_xs[0] unet_xs = unet_xs[1:] x = self.down_unet_attns(x, 
unet_x) xs.append(x) return x, xs class MidBlock(nn.Module): def __init__( self, in_channels: int, num_layers: int = 1, attention: bool = True, attention_heads: int = 16, skip_scale: float = 1, ): super().__init__() nets = [] attns = [] # first layer nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) # more layers for i in range(num_layers): nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) if attention: attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale)) else: attns.append(None) self.nets = nn.ModuleList(nets) self.attns = nn.ModuleList(attns) def forward(self, x, temb=None): x = self.nets[0](x, temb) for attn, net in zip(self.attns, self.nets[1:]): if attn: x = attn(x) x = net(x, temb) return x class UpBlock(nn.Module): def __init__( self, in_channels: int, prev_out_channels: int, out_channels: int, num_layers: int = 1, upsample: bool = True, attention: bool = True, attention_heads: int = 16, skip_scale: float = 1, ): super().__init__() nets = [] attns = [] for i in range(num_layers): cin = in_channels if i == 0 else out_channels cskip = prev_out_channels if (i == num_layers - 1) else out_channels nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale)) if attention: attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) else: attns.append(None) self.nets = nn.ModuleList(nets) self.attns = nn.ModuleList(attns) self.upsample = None if upsample: self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x, xs, temb=None): for attn, net in zip(self.attns, self.nets): res_x = xs[-1] xs = xs[:-1] x = torch.cat([x, res_x], dim=1) x = net(x, temb) if attn: x = attn(x) if self.upsample: x = F.interpolate(x, scale_factor=2.0, mode='nearest') x = self.upsample(x) return x # it could be asymmetric! 
class UNet(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_unet_channels: Tuple[int, ...] = (320, 320, 320, 640, 1280, 1280),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True),
        down_unet_attention : Tuple[bool, ...] = (False, False, True, True, True, True),
        mid_attention: bool = True,
        #mid_unet_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        #up_unet_channels: Tuple[int, ...] = (1280, 1280, 640, 320, 320),
        up_attention: Tuple[bool, ...] = (True, True, False),
        #up_unet_attention: Tuple[bool, ...] = (True, True, True, True, False),
        #up_last_unet_attention: Tuple[bool, ...] = (False, False, False, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
    ):
        super().__init__()

        # first
        self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        down_blocks = []
        cout = down_channels[0]
        for i in range(len(down_channels)):
            cin = cout
            cout = down_channels[i]
            unet_cout = down_unet_channels[i]
            unet_next_cout = down_unet_channels[i+1] if i != len(down_channels) - 1 else down_unet_channels[i]

            down_blocks.append(DownBlock(
                cin, cout, unet_cout, unet_next_cout,
                num_layers=layers_per_block,
                downsample=(i != len(down_channels) - 1), # not final layer
                attention=down_attention[i],
                unet_attention = down_unet_attention[i],
                skip_scale=skip_scale,
            ))
        self.down_blocks = nn.ModuleList(down_blocks)

        # mid
        self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)

        # up
        up_blocks = []
        cout = up_channels[0]
        for i in range(len(up_channels)):
            cin = cout
            cout = up_channels[i]
            cskip = down_channels[max(-2 - i, -len(down_channels))] # for asymmetric down/up depths
            #unet_cout = up_unet_channels[i]
            up_blocks.append(UpBlock(
                cin, cskip, cout,
                num_layers=layers_per_block + 1, # one more layer for up
                upsample=(i != len(up_channels) - 1), # not final layer
                attention=up_attention[i],
#unet_attention = up_unet_attention[i], #last_unet_attention = up_last_unet_attention[i], skip_scale=skip_scale, )) self.up_blocks = nn.ModuleList(up_blocks) # last self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5) self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x, unet_xss=None, temb=None): # x: [B, Cin, H, W] # first x = self.conv_in(x) # down xss = [x] unet_xss = unet_xss[::-1] for block in self.down_blocks: if block.unet_attention == True: length = len(block.unet_attns) + 1 if block.downsample else len(block.unet_attns) unet_xs = unet_xss[:length] unet_xss = unet_xss[length:] x, xs= block(x, unet_xs, temb) else: x, xs = block(x, temb) xss.extend(xs) # mid # if self.mid_block.unet_attention == True: # unet_xs = unet_xss[0] # unet_xss = unet_xss[1:] # x = self.mid_block(x, unet_xs, temb) #else: x = self.mid_block(x, temb) # up for block in self.up_blocks: xs = xss[-len(block.nets):] xss = xss[:-len(block.nets)] # if block.unet_attention == True: # length = len(block.unet_attns) + 1 if block.upsample else len(block.unet_attns) # unet_xs = unet_xss[:length] # unet_xss = unet_xss[length:] # x = block(x, xs, unet_xs, temb) #else: x = block(x, xs, temb) # last x = self.norm_out(x) x = F.silu(x) x = self.conv_out(x) # [B, Cout, H', W'] return x ================================================ FILE: core/utils.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import roma from kiui.op import safe_normalize def get_rays(pose, h, w, fovy, opengl=True): x, y = torch.meshgrid( torch.arange(w, device=pose.device), torch.arange(h, device=pose.device), indexing="xy", ) x = x.flatten() y = y.flatten() cx = w * 0.5 cy = h * 0.5 focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) camera_dirs = F.pad( torch.stack( [ (x - cx + 0.5) / focal, (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), ], dim=-1, 
), (0, 1), value=(-1.0 if opengl else 1.0), ) # [hw, 3] rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] rays_o = rays_o.view(h, w, 3) rays_d = safe_normalize(rays_d).view(h, w, 3) return rays_o, rays_d def orbit_camera_jitter(poses, strength=0.1): # poses: [B, 4, 4], assume orbit camera in opengl format # random orbital rotate B = poses.shape[0] rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) R = rot @ poses[:, :3, :3] T = rot @ poses[:, :3, 3:] new_poses = poses.clone() new_poses[:, :3, :3] = R new_poses[:, :3, 3:] = T return new_poses def grid_distortion(images, strength=0.5): # images: [B, C, H, W] # num_steps: int, grid resolution for distortion # strength: float in [0, 1], strength of distortion B, C, H, W = images.shape num_steps = np.random.randint(8, 17) grid_steps = torch.linspace(-1, 1, num_steps) # have to loop batch... 
grids = [] for b in range(B): # construct displacement x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb x_steps = (x_steps * W).long() # [num_steps] x_steps[0] = 0 x_steps[-1] = W xs = [] for i in range(num_steps - 1): xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) xs = torch.cat(xs, dim=0) # [W] y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb y_steps = (y_steps * H).long() # [num_steps] y_steps[0] = 0 y_steps[-1] = H ys = [] for i in range(num_steps - 1): ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) ys = torch.cat(ys, dim=0) # [H] # construct grid grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] grids.append(grid) grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] # grid sample images = F.grid_sample(images, grids, align_corners=False) return images ================================================ FILE: infer_ours_masa.py ================================================ import os import tyro import glob import imageio import numpy as np import tqdm import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.functional as TF from safetensors.torch import load_file import rembg import kiui from kiui.op import recenter from kiui.cam import orbit_camera from core.options_latents_diffusion import AllConfigs, Options from core.models_LGM_compos_diffusion_validate_inversion_2_masa import LGM import cv2 from mvdream.pipeline_mvdream import MVDreamPipeline #import debugpy; debugpy.connect(("localhost", 5999)) IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) opt = 
tyro.cli(AllConfigs) opt.weight_dtype = torch.bfloat16 # model model = LGM(opt) # resume pretrained checkpoint if opt.resume is not None: if opt.resume.endswith('safetensors'): ckpt = load_file(opt.resume, device='cpu') else: ckpt = torch.load(opt.resume, map_location='cpu') model.load_state_dict(ckpt, strict=False) print(f'[INFO] Loaded checkpoint from {opt.resume}') else: print(f'[WARN] model randomly initialized, are you sure?') # device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = model.half().to(device) model.eval() tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) proj_matrix[0, 0] = 1 / tan_half_fov proj_matrix[1, 1] = 1 / tan_half_fov proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) proj_matrix[2, 3] = 1 rays_embeddings, input_cam_view, input_cam_view_proj, input_cam_pos,= model.prepare_default_rays(device, proj_matrix=proj_matrix) # load image dream pipe = MVDreamPipeline.from_pretrained( "/remote-home1/yeyang/aigc/models/models--ashawkey--imagedream-ipmv-diffusers/snapshots/73a034178e748421506492e91790cc62d6aefef5", # remote weights torch_dtype=torch.float16, trust_remote_code=True, # local_files_only=True, ) pipe = pipe.to(device) # load rembg bg_remover = rembg.new_session() # process function def process(opt: Options, path): name = os.path.splitext(os.path.basename(path))[0] print(f'[INFO] Processing {path} --> {name}') os.makedirs(opt.workspace, exist_ok=True) input_image = kiui.read_image(path, mode='uint8') # bg removal carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4] mask = carved_image[..., -1] > 0 # carved_image = input_image # mask = carved_image[..., -1] > 0 # recenter image = recenter(carved_image, mask, border_ratio=0.2) # generate mv image = image.astype(np.float32) / 255.0 # rgba to rgb white bg if image.shape[-1] == 4: image = 
image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0) # grid = np.concatenate( # [ # np.concatenate([mv_image[0], mv_image[2]], axis=0), # np.concatenate([mv_image[1], mv_image[3]], axis=0), # ], # axis=1, # ) #kiui.write_image(os.path.join(opt.workspace, 'sparrow1.jpg'), image) # image_2 = kiui.read_image('workspace_test_LGM_ours_shoe_masa_cfg3/mv_image.jpg', mode='uint8') # top_left = image_2[:256, :256, :]/255 # top_right = image_2[:256, 256:, :]/255 # bottom_left = image_2[256:, :256, :]/255 # bottom_right = image_2[256:, 256:, :]/255 # mv_image = np.stack([top_left, top_right, bottom_left, bottom_right], axis=0) grid = np.concatenate( [ np.concatenate([mv_image[0], mv_image[2]], axis=0), np.concatenate([mv_image[1], mv_image[3]], axis=0), ], axis=1, ) kiui.write_image(os.path.join(opt.workspace, 'mv_image.jpg'), grid) #kiui.write_image(os.path.join('data_test2', 'helmet.png'), mv_image[1]) #image_2 = cv2.resize(image, (256, 256)) image_2 = cv2.resize(image, (512, 512)) #kiui.write_image(os.path.join(opt.workspace, 'dragon.png'), image_2) image_2 = torch.from_numpy(image_2).unsqueeze(0).permute(0, 3, 1, 2).float().to(device) #mv_image = np.stack([image_2, mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32 mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3]], axis=0) ref_image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).float().to(device) ref_image = F.interpolate(ref_image, size=(512, 512), mode='bilinear', align_corners=False) # generate gaussians input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] input_image = F.interpolate(input_image, size=(512, 512), mode='bilinear', align_corners=False) #input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) input_image = torch.cat([image_2, input_image, ref_image]) #input_image = torch.cat([input_image, 
rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W] with torch.no_grad(): with torch.autocast(device_type='cuda', dtype=torch.bfloat16): # pass # generate gaussians data={} data['images2_output'] = input_image data['cam_view'] = input_cam_view.unsqueeze(0) data['cam_view_proj'] = input_cam_view_proj.unsqueeze(0) data['cam_pos'] = input_cam_pos.unsqueeze(0) data['ray'] = rays_embeddings data['prompt'] = "a photo of a shoe" #gaussians = model.forward_gaussians(input_image) results, gaussians = model.validate(data) # save gaussians model.gs.save_ply(gaussians, os.path.join(opt.workspace, name + '.ply')) #gaussians = model.gs.load_ply("workspace_test_LGM_ours_5/anya_rgba.ply").unsqueeze(0) # render 360 video images = [] elevation = 0 if opt.fancy_video: azimuth = np.arange(0, 720, 4, dtype=np.int32) for azi in tqdm.tqdm(azimuth): cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction # cameras needed by gaussian rasterizer cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] cam_pos = - cam_poses[:, :3, 3] # [V, 3] scale = min(azi / 360, 1) image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image'] images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) else: azimuth = np.arange(0, 360, 2, dtype=np.int32) for azi in tqdm.tqdm(azimuth): cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction # cameras needed by gaussian rasterizer cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] cam_pos = - cam_poses[:, :3, 3] # [V, 3] image = model.gs.render(gaussians, 
cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image'] images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) imageio.imsave(os.path.join(opt.workspace, f'{azi}' + '.png'), images[-1][0]) images = np.concatenate(images, axis=0) imageio.mimwrite(os.path.join(opt.workspace, name + '.mp4'), images, fps=30) assert opt.test_path is not None if os.path.isdir(opt.test_path): file_paths = glob.glob(os.path.join(opt.test_path, "*")) else: file_paths = [opt.test_path] for path in file_paths: process(opt, path) ================================================ FILE: main_resume_compose.py ================================================ import tyro import time import random import torch from core.options_latents_diffusion import AllConfigs from core.models_LGM_compos_diffusion import LGM from accelerate import Accelerator, DistributedDataParallelKwargs from safetensors.torch import load_file from torch.utils.tensorboard import SummaryWriter import kiui from diffusers.utils.import_utils import is_xformers_available import os import shutil def main(): opt = tyro.cli(AllConfigs) writer = SummaryWriter(f'{opt.workspace}/runs') # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) print(opt.pretrained_model_name_or_path) accelerator = Accelerator( mixed_precision=opt.mixed_precision, gradient_accumulation_steps=opt.gradient_accumulation_steps, # kwargs_handlers=[ddp_kwargs], ) weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 opt.mixed_precision = accelerator.mixed_precision opt.weight_dtype = weight_dtype elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 opt.weight_dtype = weight_dtype opt.mixed_precision = accelerator.mixed_precision # model model = LGM(opt) # vae = model.vae # text_encoder = model.text_encoder # text_encoder.requires_grad_(False) # vae.requires_grad_(False) unet = 
model.unet conv = model.conv unet.requires_grad_(True) conv.requires_grad_(True) unet2 = model.unet2 if opt.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers unet2.enable_xformers_memory_efficient_attention() if opt.gradient_checkpointing: unet2.enable_gradient_checkpointing() # resume if opt.resume is not None: if opt.resume.endswith('safetensors'): ckpt = load_file(opt.resume, device='cpu') else: ckpt = torch.load(opt.resume, map_location='cpu') # tolerant load (only load matching shapes) # model.load_state_dict(ckpt, strict=False) state_dict = model.state_dict() for k, v in ckpt.items(): if k in state_dict: if state_dict[k].shape == v.shape: state_dict[k].copy_(v) else: accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.') else: accelerator.print(f'[WARN] unexpected param {k}: {v.shape}') # data if opt.data_mode == 's3': from core.provider_Gobjaverse_latent_diffusion_insert import GobjaverseDataset as Dataset else: raise NotImplementedError train_dataset = Dataset(opt, training=True) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=True, drop_last=True, ) test_dataset = Dataset(opt, training=False) test_dataloader = torch.utils.data.DataLoader( test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False, ) # if opt.gradient_checkpointing: # model.enable_gradient_checkpointing() #params = [] # for name, param in unet.named_parameters(): # #if name.startswith(tuple(('up_blocks', 'mid_block', 'conv_out'))): # params.append(param) # for name, param in conv.named_parameters(): # params.append(param) params = [] for name, param in model.named_parameters(): if name.startswith('unet.'): #print(name) params.append(param) elif not name.startswith(tuple(('unet2', 'vae', 'tokenizer', 'text_encoder', 'scheduler', 'lpips'))): #print(name) 
params.append(param) # optimizer optimizer = torch.optim.AdamW(params, lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) # optimizer optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) # scheduler (per-iteration) # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=1e-6) total_steps = opt.num_epochs * len(train_dataloader) pct_start = 3000 / total_steps scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start) # accelerate model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare( model, optimizer, train_dataloader, test_dataloader, scheduler ) # loop for epoch in range(opt.num_epochs): # train model.train() total_loss = 0 total_psnr = 0 for i, data in enumerate(train_dataloader): with accelerator.accumulate(model): optimizer.zero_grad() step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs out = model(data, step_ratio) loss = out['loss'] psnr = out['psnr'] accelerator.backward(loss) writer.add_scalar('loss', loss.item(), epoch*len(train_dataloader)+i) #writer.add_scalar('loss_mse', out['loss_mse'].item(), epoch*len(train_dataloader)+i) writer.add_scalar('loss_mse_image', out['loss_mse_image'].item(), epoch*len(train_dataloader)+i) writer.add_scalar('loss_mse_alpha', out['loss_mse_alpha'].item(), epoch*len(train_dataloader)+i) if step_ratio> 0: writer.add_scalar('loss_lpips', out['loss_lpips'].item(), epoch*len(train_dataloader)+i) writer.add_scalar('psnr', psnr.item(), epoch*len(train_dataloader)+i) writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch*len(train_dataloader)+i) # gradient clipping if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip) optimizer.step() scheduler.step() total_loss += loss.detach() total_psnr += psnr.detach() if accelerator.is_main_process: # logging if i % 100 == 0: mem_free, mem_total = 
torch.cuda.mem_get_info() print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f} loss_mse: {out['loss_mse_image']:.6f}") # save log images if i % 200 == 0: ## FIXME ## 3 ------>4 with torch.no_grad(): # gt_images = (vae.decode(data['images_output'][0, :8].detach().to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5 # gt_images = gt_images.clamp(0,1).float().unsqueeze(0).detach().cpu().numpy() # #gt_images = data['images_output'][:1].detach().cpu().numpy() # [B, V, 3, output_size, output_size] # gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] # kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images) gt_alphas = data['masks_output'].clamp(0,1).float().detach().cpu().numpy() # [B, V, 1, output_size, output_size] gt_alphas = gt_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, gt_alphas.shape[1] * gt_alphas.shape[3], 1) kiui.write_image(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', gt_alphas) # gt_images_ori = (vae.decode((data['images_output'].detach()*data['masks_output']+out['white_latent'].detach()*(1-data['masks_output']))[0, :8].to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5 # gt_images_ori = gt_images_ori.clamp(0,1).float().unsqueeze(0).detach().cpu().numpy() # gt_images_ori = gt_images_ori.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images_ori.shape[1] * gt_images_ori.shape[3], 3) # [B*output_size, V*output_size, 3] # kiui.write_image(f'{opt.workspace}/train_gt_images_ori_{epoch}_{i}.jpg', gt_images_ori) gt_noise_images = out["gt_noise"].clamp(0,1).float().detach().cpu().numpy() #gt_noise_images = gt_noise_images.transpose(0, 2, 3, 1).reshape(-1, gt_noise_images.shape[2], 3) gt_noise_images = gt_noise_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_noise_images.shape[1] * gt_noise_images.shape[3], 3) 
kiui.write_image(f'{opt.workspace}/train_gt_noise_images_{epoch}_{i}.jpg', gt_noise_images) gt_images = data['images2_output'].clamp(0,1).float().detach().cpu().numpy() # [B, V, 3, output_size, output_size] gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] # data['images_output'] = (vae.decode(data['images_output'][0, :4].to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5 # gt_images = data['images_output'].clamp(0,1).float().unsqueeze(0).detach().cpu().numpy() #gt_images = data['images_output'][:1].detach().cpu().numpy() # [B, V, 3, output_size, output_size] # gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images) # out['images_pred'] = (vae.decode(out['images_pred'][0, :8].detach().to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5 pred_images = out['images_pred'].clamp(0,1).float().detach().cpu().numpy() #pred_images = out['images_pred'].reshape(data['images_output'].shape[0],data['images_output'].shape[1], *out['images_pred'].shape[1:]).detach().cpu().numpy() #pred_images = out['images_pred'][:1].detach().cpu().numpy() # [B, V, 3, output_size, output_size] pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images) # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1) # kiui.write_image(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', pred_alphas) total_loss = accelerator.gather_for_metrics(total_loss).mean() total_psnr = accelerator.gather_for_metrics(total_psnr).mean() if accelerator.is_main_process: total_loss /= 
len(train_dataloader) total_psnr /= len(train_dataloader) accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}") # checkpoint if epoch % 1 == 0 or epoch == opt.num_epochs - 1: accelerator.wait_for_everyone() accelerator.save_model(model, opt.workspace) # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` # if opt.checkpoints_total_limit is not None: # checkpoints = os.listdir(opt.workspace) # checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] # checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints # if len(checkpoints) >= opt.checkpoints_total_limit: # num_to_remove = len(checkpoints) - opt.checkpoints_total_limit + 1 # removing_checkpoints = checkpoints[0:num_to_remove] # print( # f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" # ) # print(f"removing checkpoints: {', '.join(removing_checkpoints)}") # for removing_checkpoint in removing_checkpoints: # removing_checkpoint = os.path.join(opt.workspace, removing_checkpoint) # shutil.rmtree(removing_checkpoint) # save_path = os.path.join(opt.workspace, f"checkpoint-{epoch}") # accelerator.save_state(save_path) #print(f"Saved state to {save_path}") # eval with torch.no_grad(): model.eval() total_psnr = 0 for i, data in enumerate(test_dataloader): out = model(data) psnr = out['psnr'] total_psnr += psnr.detach() # save some images if accelerator.is_main_process: gt_images = data['images2_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images) pred_images = out['images_pred'].clamp(0,1).float().detach().cpu().numpy() # 
[B, V, 3, output_size, output_size] pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images) # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1) # kiui.write_image(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', pred_alphas) torch.cuda.empty_cache() total_psnr = accelerator.gather_for_metrics(total_psnr).mean() if accelerator.is_main_process: total_psnr /= len(test_dataloader) accelerator.print(f"[eval] epoch: {epoch} psnr: {total_psnr.item():.4f}") writer.close() if __name__ == "__main__": main() ================================================ FILE: mvdream/mv_unet.py ================================================ import math import numpy as np from inspect import isfunction from typing import Optional, Any, List import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from diffusers.configuration_utils import ConfigMixin from diffusers.models.modeling_utils import ModelMixin # require xformers! import xformers import xformers.ops from kiui.cam import orbit_camera def get_camera( num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False, ): angle_gap = azimuth_span / num_frames cameras = [] for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4] # opengl to blender if blender_coord: pose[2] *= -1 pose[[1, 2]] = pose[[2, 1]] cameras.append(pose.flatten()) if extra_view: cameras.append(np.zeros_like(cameras[0])) return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16] def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ Create sinusoidal timestep embeddings. 
:param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ if not repeat_only: half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None] * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) else: embedding = repeat(timesteps, "b -> b d", d=dim) # import pdb; pdb.set_trace() return embedding def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. 
""" if dims == 1: return nn.AvgPool1d(*args, **kwargs) elif dims == 2: return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") def default(val, d): if val is not None: return val return d() if isfunction(d) else d class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = ( nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) ) self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): return self.net(x) class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__( self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, ip_dim=0, ip_weight=1, ): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.heads = heads self.dim_head = dim_head self.ip_dim = ip_dim self.ip_weight = ip_weight if self.ip_dim > 0: self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) self.attention_op: Optional[Any] = None def forward(self, x, context=None): q = self.to_q(x) context = default(context, x) if self.ip_dim > 0: # context: [B, 77 + 16(ip), 1024] token_len = 
context.shape[1] context_ip = context[:, -self.ip_dim :, :] k_ip = self.to_k_ip(context_ip) v_ip = self.to_v_ip(context_ip) context = context[:, : (token_len - self.ip_dim), :] k = self.to_k(context) v = self.to_v(context) b, _, _ = q.shape q, k, v = map( lambda t: t.unsqueeze(3) .reshape(b, t.shape[1], self.heads, self.dim_head) .permute(0, 2, 1, 3) .reshape(b * self.heads, t.shape[1], self.dim_head) .contiguous(), (q, k, v), ) # actually compute the attention, what we cannot get enough of out = xformers.ops.memory_efficient_attention( q, k, v, attn_bias=None, op=self.attention_op ) if self.ip_dim > 0: k_ip, v_ip = map( lambda t: t.unsqueeze(3) .reshape(b, t.shape[1], self.heads, self.dim_head) .permute(0, 2, 1, 3) .reshape(b * self.heads, t.shape[1], self.dim_head) .contiguous(), (k_ip, v_ip), ) # actually compute the attention, what we cannot get enough of out_ip = xformers.ops.memory_efficient_attention( q, k_ip, v_ip, attn_bias=None, op=self.attention_op ) out = out + self.ip_weight * out_ip out = ( out.unsqueeze(0) .reshape(b, self.heads, out.shape[1], self.dim_head) .permute(0, 2, 1, 3) .reshape(b, out.shape[1], self.heads * self.dim_head) ) return self.to_out(out) class BasicTransformerBlock3D(nn.Module): def __init__( self, dim, n_heads, d_head, context_dim, dropout=0.0, gated_ff=True, ip_dim=0, ip_weight=1, ): super().__init__() self.attn1 = MemoryEfficientCrossAttention( query_dim=dim, context_dim=None, # self-attention heads=n_heads, dim_head=d_head, dropout=dropout, ) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = MemoryEfficientCrossAttention( query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, # ip only applies to cross-attention ip_dim=ip_dim, ip_weight=ip_weight, ) self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) def forward(self, x, context=None, num_frames=1): x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() x = 
self.attn1(self.norm1(x), context=None) + x x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x class SpatialTransformer3D(nn.Module): def __init__( self, in_channels, n_heads, d_head, context_dim, # cross attention input dim depth=1, dropout=0.0, ip_dim=0, ip_weight=1, ): super().__init__() if not isinstance(context_dim, list): context_dim = [context_dim] self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock3D( inner_dim, n_heads, d_head, context_dim=context_dim[d], dropout=dropout, ip_dim=ip_dim, ip_weight=ip_weight, ) for d in range(depth) ] ) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) def forward(self, x, context=None, num_frames=1): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] b, c, h, w = x.shape x_in = x x = self.norm(x) x = rearrange(x, "b c h w -> b (h w) c").contiguous() x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): x = block(x, context=context[i], num_frames=num_frames) x = self.proj_out(x) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() return x + x_in class PerceiverAttention(nn.Module): def __init__(self, *, dim, dim_head=64, heads=8): super().__init__() self.scale = dim_head ** -0.5 self.dim_head = dim_head self.heads = heads inner_dim = dim_head * heads self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) def forward(self, x, latents): """ Args: x (torch.Tensor): image features shape (b, n1, D) latents (torch.Tensor): latent 
features shape (b, n2, D) """ x = self.norm1(x) latents = self.norm2(latents) b, l, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) k, v = self.to_kv(kv_input).chunk(2, dim=-1) q, k, v = map( lambda t: t.reshape(b, t.shape[1], self.heads, -1) .transpose(1, 2) .reshape(b, self.heads, t.shape[1], -1) .contiguous(), (q, k, v), ) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v out = out.permute(0, 2, 1, 3).reshape(b, l, -1) return self.to_out(out) class Resampler(nn.Module): def __init__( self, dim=1024, depth=8, dim_head=64, heads=16, num_queries=8, embedding_dim=768, output_dim=1024, ff_mult=4, ): super().__init__() self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) self.proj_in = nn.Linear(embedding_dim, dim) self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( nn.ModuleList( [ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, dim * ff_mult, bias=False), nn.GELU(), nn.Linear(dim * ff_mult, dim, bias=False), ) ] ) ) def forward(self, x): latents = self.latents.repeat(x.size(0), 1, 1) x = self.proj_in(x) for attn, ff in self.layers: latents = attn(x, latents) + latents latents = ff(latents) + latents latents = self.proj_out(latents) return self.norm_out(latents) class CondSequential(nn.Sequential): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. 
""" def forward(self, x, emb, context=None, num_frames=1): for layer in self: if isinstance(layer, ResBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer3D): x = layer(x, context, num_frames=num_frames) else: x = layer(x) return x class Upsample(nn.Module): """ An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: self.conv = conv_nd( dims, self.channels, self.out_channels, 3, padding=padding ) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" ) else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x class Downsample(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. 
""" def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd( dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, ) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class ResBlock(nn.Module): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. 
""" def __init__( self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, up=False, down=False, ): super().__init__() self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_scale_shift_norm = use_scale_shift_norm self.in_layers = nn.Sequential( nn.GroupNorm(32, channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, 3, padding=1), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) else: self.h_upd = self.x_upd = nn.Identity() self.emb_layers = nn.Sequential( nn.SiLU(), nn.Linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, ), ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), zero_module( conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, 3, padding=1 ) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x, emb): if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = torch.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h class MultiViewUNetModel(ModelMixin, ConfigMixin): """ The full 
multi-view UNet model with attention, timestep embedding and camera embedding. :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param num_heads: the number of attention heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially increased efficiency. :param camera_dim: dimensionality of camera input. 
""" def __init__( self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, transformer_depth=1, context_dim=None, n_embed=None, num_attention_blocks=None, adm_in_channels=None, camera_dim=None, ip_dim=0, # imagedream uses ip_dim > 0 ip_weight=1.0, **kwargs, ): super().__init__() assert context_dim is not None if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert ( num_head_channels != -1 ), "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: assert ( num_heads != -1 ), "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): raise ValueError( "provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult" ) self.num_res_blocks = num_res_blocks if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all( map( lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)), ) ) print( f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " f"This option has LESS priority than attention_resolutions {attention_resolutions}, " f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " f"attention will still not be set." 
) self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.predict_codebook_ids = n_embed is not None self.ip_dim = ip_dim self.ip_weight = ip_weight if self.ip_dim > 0: self.image_embed = Resampler( dim=context_dim, depth=4, dim_head=64, heads=12, num_queries=ip_dim, # num token embedding_dim=1280, output_dim=context_dim, ff_mult=4, ) time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( nn.Linear(model_channels, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim), ) if camera_dim is not None: time_embed_dim = model_channels * 4 self.camera_embed = nn.Sequential( nn.Linear(camera_dim, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) elif self.num_classes == "continuous": # print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "sequential": assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( nn.Linear(adm_in_channels, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim), ) ) else: raise ValueError() self.input_blocks = nn.ModuleList( [ CondSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers: List[Any] = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == 
-1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if num_attention_blocks is None or nr < num_attention_blocks[level]: layers.append( SpatialTransformer3D( ch, num_heads, dim_head, context_dim=context_dim, depth=transformer_depth, ip_dim=self.ip_dim, ip_weight=self.ip_weight, ) ) self.input_blocks.append(CondSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( CondSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels self.middle_block = CondSequential( ResBlock( ch, time_embed_dim, dropout, dims=dims, use_scale_shift_norm=use_scale_shift_norm, ), SpatialTransformer3D( ch, num_heads, dim_head, context_dim=context_dim, depth=transformer_depth, ip_dim=self.ip_dim, ip_weight=self.ip_weight, ), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult, dims=dims, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if num_attention_blocks is None or i < num_attention_blocks[level]: layers.append( SpatialTransformer3D( ch, num_heads, dim_head, 
                                context_dim=context_dim,
                                depth=transformer_depth,
                                ip_dim=self.ip_dim,
                                ip_weight=self.ip_weight,
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(CondSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        y=None,
        camera=None,
        num_frames=1,
        ip=None,
        ip_img=None,
        **kwargs,
    ):
        """
        Apply the model to an input batch.
        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param num_frames: an integer indicating the number of frames for tensor reshaping.
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
        """
        assert (
            x.shape[0] % num_frames == 0
        ), "input batch size must be divisible by num_frames!"
assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y is not None assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) # Add camera embeddings if camera is not None: emb = emb + self.camera_embed(camera) # imagedream variant if self.ip_dim > 0: x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9] ip_emb = self.image_embed(ip) context = torch.cat((context, ip_emb), 1) h = x for module in self.input_blocks: h = module(h, emb, context, num_frames=num_frames) hs.append(h) h = self.middle_block(h, emb, context, num_frames=num_frames) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb, context, num_frames=num_frames) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) else: return self.out(h) ================================================ FILE: mvdream/pipeline_mvdream.py ================================================ import torch import torch.nn.functional as F import inspect import numpy as np from typing import Callable, List, Optional, Union from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor from diffusers import AutoencoderKL, DiffusionPipeline from diffusers.utils import ( deprecate, is_accelerate_available, is_accelerate_version, logging, ) from diffusers.configuration_utils import FrozenDict from diffusers.schedulers import DDIMScheduler from diffusers.utils.torch_utils import randn_tensor from mvdream.mv_unet import MultiViewUNetModel, get_camera logger = logging.get_logger(__name__) # pylint: disable=invalid-name class MVDreamPipeline(DiffusionPipeline): _optional_components = ["feature_extractor", "image_encoder"] def __init__( self, vae: AutoencoderKL, unet: MultiViewUNetModel, tokenizer: 
        CLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler: DDIMScheduler,
        # imagedream variant
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModel,
        requires_safety_checker: bool = False,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "  # type: ignore
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions.
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate( "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, unet=unet, scheduler=scheduler, tokenizer=tokenizer, text_encoder=text_encoder, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
        When called, unet, text_encoder and vae have their state dicts saved to CPU and are then moved to
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError(
                "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError(
                "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(
                cpu_offloaded_model, device, prev_module_hook=hook
            )

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance: bool,
        negative_prompt=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
             prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, an empty string is used
                for each prompt. Ignored when `do_classifier_free_guidance` is `False`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError( f"`prompt` should be either a string or a list of strings, but got {type(prompt)}." ) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer( prompt, padding="longest", return_tensors="pt" ).input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if ( hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view( bs_embed * num_images_per_prompt, seq_len, -1 ) # get unconditional embeddings for classifier 
free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if ( hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to( dtype=self.text_encoder.dtype, device=device ) negative_prompt_embeds = negative_prompt_embeds.repeat( 1, num_images_per_prompt, 1 ) negative_prompt_embeds = negative_prompt_embeds.view( batch_size * num_images_per_prompt, seq_len, -1 ) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set( inspect.signature(self.scheduler.step).parameters.keys() ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set( inspect.signature(self.scheduler.step).parameters.keys() ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, ): shape = ( batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_image(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype

        if image.dtype == np.float32:
            image = (image * 255).astype(np.uint8)

        image = self.feature_extractor(image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)

        image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        # the all-zero tensor serves as the unconditional (negative) image embedding
        return torch.zeros_like(image_embeds), image_embeds

    def encode_image_latents(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype

        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)  # [1, 3, H, W]
        image = 2 * image - 1
        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
        image = image.to(dtype=dtype)

        posterior = self.vae.encode(image).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor  # [B, C, H, W]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)

        # the all-zero tensor serves as the unconditional (negative) image latent
        return torch.zeros_like(latents), latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: str = "",
        image: Optional[np.ndarray] = None,
        height: int = 256,
        width: int = 256,
        elevation: float = 0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.0,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "numpy",  # pil, numpy, latent
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        num_frames: int = 4,
        device=torch.device("cuda:0"),
    ):
        self.unet = self.unet.to(device=device)
self.vae = self.vae.to(device=device) self.text_encoder = self.text_encoder.to(device=device) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # imagedream variant if image is not None: assert isinstance(image, np.ndarray) and image.dtype == np.float32 self.image_encoder = self.image_encoder.to(device=device) image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt) image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt) _prompt_embeds = self._encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, ) # type: ignore prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2) # Prepare latent variables actual_num_frames = num_frames if image is None else num_frames + 1 latents: torch.Tensor = self.prepare_latents( actual_num_frames * num_images_per_prompt, 4, height, width, prompt_embeds_pos.dtype, device, generator, None, ) if image is not None: camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device) else: camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device) camera = camera.repeat_interleave(num_images_per_prompt, dim=0) # Prepare extra step kwargs. 
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                multiplier = 2 if do_classifier_free_guidance else 1
                latent_model_input = torch.cat([latents] * multiplier)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                unet_inputs = {
                    'x': latent_model_input,
                    'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
                    'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
                    'num_frames': actual_num_frames,
                    'camera': torch.cat([camera] * multiplier),
                }

                if image is not None:
                    unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
                    unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos])  # no repeat

                # predict the noise residual
                noise_pred = self.unet.forward(**unet_inputs)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents: torch.Tensor = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)  # type: ignore

        # Post-processing
        if output_type == "latent":
            image = latents
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image = self.numpy_to_pil(image)
        else:  # numpy
            image = self.decode_latents(latents)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image