[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 PKU-YUAN-Lab (袁粒课题组-北大信工)\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<h2 align=\"center\"> <a href=\"https://github.com/PKU-YuanGroup/Cycle3D\">Cycle3D: High-quality and Consistent Image-to-3D Generation via\nGeneration-Reconstruction Cycle</a></h2>\n<h5 align=\"center\"> If you like our project, please give us a star ⭐ on GitHub for latest update.  </h2>\n\n<h5 align=\"center\">\n\n[![webpage](https://img.shields.io/badge/Webpage-blue)](https://PKU-YuanGroup.github.io/Cycle3D/)\n[![arXiv](https://img.shields.io/badge/Arxiv-2407.19548-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2407.19548)\n[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/PKU-YuanGroup/repaint123/blob/main/LICENSE) \n\n\n</h5>\n\n## [Project page](https://PKU-YuanGroup.github.io/Cycle3D/) | [Paper](https://arxiv.org/abs/2407.19548) | [Live Demo (Coming Soon)]()\n\n\n![image](https://github.com/user-attachments/assets/d6870ef6-4631-4fc2-a2dc-382c054afe0d)\n\n## 😮 Highlights\n\n### 🔥 Generation-Reconstruction cycle for the unified diffusion process\n-  The pre-trained 2D diffusion model trained on billions of web images can generate high-quality texture.\n-  The reconstruction model can ensure consistency across multi-views.\n-  We cyclically utilizes a 2D diffusion-based generation module and a feed-forward 3D reconstruction module during the multi-step diffusion process.\n\n\n\n\n## 🚩 **Updates**\n\nWelcome to **watch** 👀 this repository for the latest updates.\n\n✅ **[2024.7.28]** : We have released our paper, Cycle3D on [arXiv](https://arxiv.org/abs/2407.19548).\n\n✅ **[2024.7.28]** : Release [project page](https://PKU-YuanGroup.github.io/Cycle3D/).\n- [ ] Code release.\n- [ ] Online Demo.\n\n\n## 🤗 Demo\n\nComing soon!\n\n## 🚀 Image-to-3D Results\n\n### Qualitative comparison\n\n![image](https://github.com/user-attachments/assets/ce4f0c0c-793b-4354-b3fa-7d30e97a8ddf)\n\n\n### Quantitative comparison\n\n![image](https://github.com/user-attachments/assets/25a9e1d2-124c-426d-a1a4-54a44aa7d0fc)\n\n\n## 👍 **Acknowledgement**\nThis work is built on many amazing research works and open-source projects, thanks a lot to all the authors for sharing!\n* [LGM](https://github.com/3DTopia/LGM)\n* [MasaCtrl](https://github.com/TencentARC/MasaCtrl)\n* [Diffusers](https://github.com/huggingface/diffusers)\n\n## ✏️ Citation\nIf you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:.\n\n```BibTeX\n@misc{tang2024cycle3dhighqualityconsistentimageto3d,\n      title={Cycle3D: High-quality and Consistent Image-to-3D Generation via Generation-Reconstruction Cycle}, \n      author={Zhenyu Tang and Junwu Zhang and Xinhua Cheng and Wangbo Yu and Chaoran Feng and Yatian Pang and Bin Lin and Li Yuan},\n      year={2024},\n      eprint={2407.19548},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      url={https://arxiv.org/abs/2407.19548}, \n}\n```\n<!---->\n"
  },
  {
    "path": "acc_configs/gpu1.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: 'NO'\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n# distributed_type: DEEPSPEED\n# deepspeed_config:\n#   gradient_clipping: 1.0\n#   zero_stage: 2"
  },
  {
    "path": "acc_configs/gpu4.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 4\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "acc_configs/gpu6.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 6\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n  gradient_clipping: 1.0\n  zero_stage: 2"
  },
  {
    "path": "acc_configs/gpu7.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 7\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n  gradient_clipping: 1.0\n  zero_stage: 2"
  },
  {
    "path": "acc_configs/gpu8.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: MULTI_GPU\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n  gradient_clipping: 1.0\n  zero_stage: 2"
  },
  {
    "path": "acc_configs/hostfile",
    "content": "gpu147 slots=8\ngpu176 slots=8\ngpu47 slots=8\ngpu117 slots=8"
  },
  {
    "path": "acc_configs/multi_node.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndistributed_type: DEEPSPEED\ndeepspeed_config:\n  gradient_clipping: 1.0\n  zero_stage: 2\n  deepspeed_hostfile: /remote-home1/yeyang/aigc/aigc/LGM/acc_configs/hostfile\nfsdp_config: {}\nmachine_rank: 0\nmain_process_ip: 219.223.196.147\nmain_process_port: 29504\nmain_training_function: main\nnum_machines: 4\nnum_processes: 32\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false"
  },
  {
    "path": "acc_configs/zero2.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"train_batch_size\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"overlap_comm\": true,\n        \"contiguous_gradients\": true,\n        \"sub_group_size\": 1e9,\n        \"reduce_bucket_size\": \"auto\"\n    }\n}"
  },
  {
    "path": "core/__init__.py",
    "content": ""
  },
  {
    "path": "core/attention.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nimport os\nimport warnings\n\nfrom torch import Tensor\nfrom torch import nn\n\nXFORMERS_ENABLED = os.environ.get(\"XFORMERS_DISABLED\") is None\ntry:\n    if XFORMERS_ENABLED:\n        from xformers.ops import memory_efficient_attention, unbind\n\n        XFORMERS_AVAILABLE = True\n        warnings.warn(\"xFormers is available (Attention)\")\n    else:\n        warnings.warn(\"xFormers is disabled (Attention)\")\n        raise ImportError\nexcept ImportError:\n    XFORMERS_AVAILABLE = False\n    warnings.warn(\"xFormers is not available (Attention)\")\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]\n        attn = q @ k.transpose(-2, -1)\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffAttention(Attention):\n    def forward(self, x: Tensor, attn_bias=None) -> Tensor:\n        if not XFORMERS_AVAILABLE:\n            if attn_bias is not None:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return super().forward(x)\n\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)\n\n        q, k, v = unbind(qkv, 2)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        dim_q: int,\n        dim_k: int,\n        dim_v: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)\n        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)\n        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, 
q: Tensor, k: Tensor, v: Tensor) -> Tensor:\n        # q: [B, N, Cq]\n        # k: [B, M, Ck]\n        # v: [B, M, Cv]\n        # return: [B, N, C]\n\n        B, N, _ = q.shape\n        M = k.shape[1]\n        \n        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh]\n        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]\n        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]\n\n        attn = q @ k.transpose(-2, -1) # [B, nh, N, M]\n\n        attn = attn.softmax(dim=-1) # [B, nh, N, M]\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffCrossAttention(CrossAttention):\n    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:\n        if not XFORMERS_AVAILABLE:\n            if attn_bias is not None:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return super().forward(q, k, v)\n\n        B, N, _ = q.shape\n        M = k.shape[1]\n\n        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh]\n        k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]\n        v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape(B, N, -1)\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n"
  },
  {
    "path": "core/control.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport contextlib\nimport inspect\nfrom typing import Callable, List, Optional, Union, Dict, Any\nimport torchvision.transforms.functional as TF\nimport torch.nn.functional as F\nimport torch\n\nimport PIL\nfrom diffusers.utils import is_accelerate_available\nfrom packaging import version\nfrom tqdm import tqdm\nfrom transformers import (\n    CLIPTextModel,\n    CLIPTokenizer,\n    DPTFeatureExtractor,\n    DPTForDepthEstimation,\n)\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel,  ControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n)\nfrom diffusers.utils import PIL_INTERPOLATION, deprecate, logging\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version\nfrom diffusers.pipelines import StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nimport kiui\n\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. 
If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\nclass ControlNetPipeline(StableDiffusionControlNetPipeline):\n    \n    def pred_x0(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta: float=0.0,\n        verbose=False,\n    ):\n        \"\"\"\n        predict the sampe the next step in the denoise process.\n        \"\"\"\n        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x.device)\n        alpha_prod_t = alphas_cumprod [timestep]\n\n        B = alpha_prod_t.shape[0]\n        alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1)\n        beta_prod_t = 1 - alpha_prod_t\n        \n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        return pred_x0\n    \n    def next_step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta=0.,\n        verbose=False\n    ):\n        \"\"\"\n        Inverse sampling for DDIM Inversion\n        \"\"\"\n        if verbose:\n            print(\"timestep: \", timestep)\n        next_step = timestep\n        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod\n        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output\n        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir\n        return x_next, pred_x0\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = 
None,\n        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        data = None,\n        LGM_unet = None, \n        opt = None, \n        pos_act = None,\n        scale_act = None, \n        opacity_act = None,\n        rot_act = None,\n        rgb_act = None,\n        gs = None, \n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,\n                each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,\n                where a list of image lists can be passed to batch for each prompt and each ControlNet.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. 
Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.\n                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding\n                if `do_classifier_free_guidance` is set to `True`.\n                If not provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that calls every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. 
If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that calls at the end of each denoising steps during the inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeine class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n        self.opt = opt\n        self.pos_act = pos_act\n        self.scale_act = scale_act\n        self.opacity_act = opacity_act\n        self.rot_act = rot_act \n        self.rgb_act = rgb_act\n        self.gs = gs\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        # self.check_inputs(\n        #     prompt,\n        #     image,\n        #     callback_steps,\n        #     negative_prompt,\n        #     prompt_embeds,\n        #     negative_prompt_embeds,\n        #     ip_adapter_image,\n        #     ip_adapter_image_embeds,\n        #     controlnet_conditioning_scale,\n        #     control_guidance_start,\n        #     control_guidance_end,\n        #     callback_on_step_end_tensor_inputs,\n        # )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        if prompt_embeds is not None:\n            prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n                prompt,\n                device,\n                num_images_per_prompt,\n                self.do_classifier_free_guidance,\n                negative_prompt,\n                prompt_embeds=prompt_embeds,\n                negative_prompt_embeds=negative_prompt_embeds,\n                lora_scale=text_encoder_lora_scale,\n                clip_skip=self.clip_skip,\n            )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. 
Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            # Nested lists as ControlNet condition\n            if isinstance(image[0], list):\n                # Transpose the nested image list\n                image = [list(t) for t in zip(*image)]\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        # 7.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n       \n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infered ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred, blocks_sample, tembpred_noise = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                
    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                pred_x0 = self.pred_x0(noise_pred, t, latents)\n                images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample + 1) * 0.5\n                images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False)\n                images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)\n                images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].flatten(0, 1).to(self.opt.weight_dtype)], dim=1)\n                \n                \n                # perform guidance\n                \n\n                # compute the previous noisy sample x_t -> x_t-1\n                #latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n      \n            \n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            has_nsfw_concept = None\n            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return images"
  },
  {
    "path": "core/diffuser_utils.py",
    "content": "\"\"\"\nUtil functions based on Diffuser framework.\n\"\"\"\n\n\nimport os\nimport torch\nimport cv2\nimport numpy as np\n\nimport torch.nn.functional as F\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom torchvision.utils import save_image\nfrom torchvision.io import read_image\n\nfrom diffusers import StableDiffusionPipeline\n\nfrom pytorch_lightning import seed_everything\n\n\nclass MasaCtrlPipeline(StableDiffusionPipeline):\n\n    def next_step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta=0.,\n        verbose=False\n    ):\n        \"\"\"\n        Inverse sampling for DDIM Inversion\n        \"\"\"\n        if verbose:\n            print(\"timestep: \", timestep)\n        next_step = timestep\n        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod\n        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output\n        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir\n        return x_next, pred_x0\n\n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta: float=0.0,\n        verbose=False,\n    ):\n        \"\"\"\n        predict the sampe the next step in the denoise process.\n        \"\"\"\n        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output\n        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir\n        return x_prev, pred_x0\n\n    @torch.no_grad()\n    def image2latent(self, image):\n        DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n        if type(image) is Image:\n            image = np.array(image)\n            image = torch.from_numpy(image).float() / 127.5 - 1\n            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)\n        # input image density range [-1, 1]\n        latents = self.vae.encode(image)['latent_dist'].mean\n        latents = latents * 0.18215\n        return latents\n\n    @torch.no_grad()\n    def latent2image(self, latents, return_type='np'):\n        latents = 1 / 0.18215 * latents.detach()\n        image = self.vae.decode(latents)['sample']\n        if return_type == 'np':\n            image = (image / 2 + 0.5).clamp(0, 1)\n            image = image.to(torch.float).cpu().permute(0, 2, 3, 1).numpy()[0]\n            image = (image * 255).astype(np.uint8)\n        elif return_type == \"pt\":\n            image = (image / 2 + 0.5).clamp(0, 1)\n\n        return image\n\n    def latent2image_grad(self, latents):\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents)['sample']\n\n        return image  # range [-1, 1]\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        
prompt,\n        batch_size=1,\n        height=512,\n        width=512,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        eta=0.0,\n        latents=None,\n        unconditioning=None,\n        neg_prompt=None,\n        ref_intermediate_latents=None,\n        return_intermediates=False,\n        **kwds):\n        DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n        if isinstance(prompt, list):\n            batch_size = len(prompt)\n        elif isinstance(prompt, str):\n            if batch_size > 1:\n                prompt = [prompt] * batch_size\n\n        # text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            return_tensors=\"pt\"\n        )\n\n        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]\n        print(\"input text embeddings :\", text_embeddings.shape)\n        if kwds.get(\"dir\"):\n            dir = text_embeddings[-2] - text_embeddings[-1]\n            u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)\n            text_embeddings[-1] = text_embeddings[-1] + kwds.get(\"dir\") * v\n            print(u.shape)\n            print(v.shape)\n\n        # define initial latents\n        latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)\n        if latents is None:\n            latents = torch.randn(latents_shape, device=DEVICE)\n        else:\n            assert latents.shape == latents_shape, f\"The shape of input latent tensor {latents.shape} should equal to predefined one.\"\n\n        # unconditional embedding for classifier free guidance\n        if guidance_scale > 1.:\n            max_length = text_input.input_ids.shape[-1]\n            if neg_prompt:\n                uc_text = neg_prompt\n            else:\n                uc_text = \"\"\n            # uc_text = \"ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face\"\n            unconditional_input = self.tokenizer(\n                [uc_text] * batch_size,\n                padding=\"max_length\",\n                max_length=77,\n                return_tensors=\"pt\"\n            )\n            # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]\n            unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]\n            text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)\n\n        print(\"latents shape: \", latents.shape)\n        # iterative sampling\n        self.scheduler.set_timesteps(num_inference_steps)\n        # print(\"Valid timesteps: \", reversed(self.scheduler.timesteps))\n        latents_list = [latents]\n        pred_x0_list = [latents]\n        for i, t in enumerate(tqdm(self.scheduler.timesteps, desc=\"DDIM Sampler\")):\n            if ref_intermediate_latents is not None:\n                # note that the batch_size >= 2\n                latents_ref = ref_intermediate_latents[-1 - i]\n                _, latents_cur = latents.chunk(2)\n                latents = torch.cat([latents_ref, latents_cur])\n\n            if guidance_scale > 1.:\n                model_inputs = torch.cat([latents] * 2)\n            else:\n                model_inputs = latents\n            if unconditioning is not None and isinstance(unconditioning, list):\n                _, text_embeddings = text_embeddings.chunk(2)\n                text_embeddings 
= torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings]) \n            # predict tghe noise\n            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample\n            if guidance_scale > 1.:\n                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)\n                noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)\n            # compute the previous noise sample x_t -> x_t-1\n            latents, pred_x0 = self.step(noise_pred, t, latents)\n            latents_list.append(latents)\n            pred_x0_list.append(pred_x0)\n\n        image = self.latent2image(latents, return_type=\"pt\")\n        if return_intermediates:\n            pred_x0_list = [self.latent2image(img, return_type=\"pt\") for img in pred_x0_list]\n            latents_list = [self.latent2image(img, return_type=\"pt\") for img in latents_list]\n            return image, pred_x0_list, latents_list\n        return image\n\n    @torch.no_grad()\n    def invert(\n        self,\n        image: torch.Tensor,\n        prompt,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        eta=0.0,\n        return_intermediates=False,\n        path = None,\n        **kwds):\n        \"\"\"\n        invert a real image into noise map with determinisc DDIM inversion\n        \"\"\"\n        DEVICE = image.device\n        batch_size = image.shape[0]\n        if isinstance(prompt, list):\n            if batch_size == 1:\n                image = image.expand(len(prompt), -1, -1, -1)\n        elif isinstance(prompt, str):\n            if batch_size > 1:\n                prompt = [prompt] * batch_size\n\n        # text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            return_tensors=\"pt\"\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]\n        print(\"input text embeddings :\", text_embeddings.shape)\n        # define initial latents\n        latents = self.image2latent(image)\n        start_latents = latents\n        # print(latents)\n        # exit()\n        # unconditional embedding for classifier free guidance\n        if guidance_scale > 1.:\n            max_length = text_input.input_ids.shape[-1]\n            unconditional_input = self.tokenizer(\n                [\"\"] * batch_size,\n                padding=\"max_length\",\n                max_length=77,\n                return_tensors=\"pt\"\n            )\n            unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]\n            text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)\n\n        print(\"latents shape: \", latents.shape)\n        # interative sampling\n        self.scheduler.set_timesteps(num_inference_steps)\n        print(\"Valid timesteps: \", reversed(self.scheduler.timesteps))\n        # print(\"attributes: \", self.scheduler.__dict__)\n        latents_list = [latents]\n        pred_x0_list = [latents]\n        for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc=\"DDIM Inversion\")):\n            if guidance_scale > 1.:\n                model_inputs = torch.cat([latents] * 2)\n            else:\n                model_inputs = latents\n\n            # predict the noise\n            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample\n            if guidance_scale > 1.:\n              
  noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)\n                noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)\n            # compute the next noisy sample x_t-1 -> x_t\n            latents, pred_x0 = self.next_step(noise_pred, t, latents)\n            #Image.fromarray(self.latent2image(latents[:1])).save(os.path.join(path, str(i)+'_8.png'))\n            # if kwds.get(\"workspace\"):\n            #     Image.fromarray(self.latent2image(pred_x0[:1])).save(kwds.get(\"workspace\")+'/'+str(i)+'_8.png')\n            latents_list.append(latents)\n            pred_x0_list.append(pred_x0)\n\n        if return_intermediates:\n            # return the intermediate latents during inversion\n            # pred_x0_list = [self.latent2image(img, return_type=\"pt\") for img in pred_x0_list]\n            return latents, latents_list\n        return latents, start_latents\n"
  },
  {
    "path": "core/gs.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\n\nfrom core.options import Options\n\nimport kiui\n\nclass GaussianRenderer:\n    def __init__(self, opt: Options):\n        \n        self.opt = opt\n        self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device=\"cuda\")\n        \n        # intrinsics\n        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))\n        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)\n        self.proj_matrix[0, 0] = 1 / self.tan_half_fov\n        self.proj_matrix[1, 1] = 1 / self.tan_half_fov\n        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)\n        self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)\n        self.proj_matrix[2, 3] = 1\n        \n    def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):\n        # gaussians: [B, N, 14]\n        # cam_view, cam_view_proj: [B, V, 4, 4]\n        # cam_pos: [B, V, 3]\n\n        device = gaussians.device\n        B, V = cam_view.shape[:2]\n\n        # loop of loop...\n        images = []\n        alphas = []\n        for b in range(B):\n\n            # pos, opacity, scale, rotation, shs\n            means3D = gaussians[b, :, 0:3].contiguous().float()\n            opacity = gaussians[b, :, 3:4].contiguous().float()\n            scales = gaussians[b, :, 4:7].contiguous().float()\n            rotations = gaussians[b, :, 7:11].contiguous().float()\n            rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]\n\n            for v in range(V):\n                \n                # render novel views\n                view_matrix = cam_view[b, v].float()\n                view_proj_matrix = cam_view_proj[b, v].float()\n                campos = cam_pos[b, v].float()\n\n                raster_settings = GaussianRasterizationSettings(\n                    image_height=self.opt.output_size,\n                    image_width=self.opt.output_size,\n                    tanfovx=self.tan_half_fov,\n                    tanfovy=self.tan_half_fov,\n                    bg=self.bg_color if bg_color is None else bg_color,\n                    scale_modifier=scale_modifier,\n                    viewmatrix=view_matrix,\n                    projmatrix=view_proj_matrix,\n                    sh_degree=0,\n                    campos=campos,\n                    prefiltered=False,\n                    debug=False,\n                )\n\n                rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n                # Rasterize visible Gaussians to image, obtain their radii (on screen).\n                rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(\n                    means3D=means3D,\n                    means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device),\n                    shs=None,\n                    colors_precomp=rgbs,\n                    opacities=opacity,\n                    scales=scales,\n                    rotations=rotations,\n                    cov3D_precomp=None,\n                )\n\n                rendered_image = rendered_image.clamp(0, 1)\n\n                images.append(rendered_image)\n                alphas.append(rendered_alpha)\n\n        images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size)\n        alphas = 
torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)\n\n        return {\n            \"image\": images, # [B, V, 3, H, W]\n            \"alpha\": alphas, # [B, V, 1, H, W]\n        }\n\n\n    def save_ply(self, gaussians, path, compatible=True):\n        # gaussians: [B, N, 14]\n        # compatible: save pre-activated gaussians as in the original paper\n\n        assert gaussians.shape[0] == 1, 'only support batch size 1'\n\n        from plyfile import PlyData, PlyElement\n     \n        means3D = gaussians[0, :, 0:3].contiguous().float()\n        opacity = gaussians[0, :, 3:4].contiguous().float()\n        scales = gaussians[0, :, 4:7].contiguous().float()\n        rotations = gaussians[0, :, 7:11].contiguous().float()\n        shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]\n\n        # prune by opacity\n        mask = opacity.squeeze(-1) >= 0.005\n        means3D = means3D[mask]\n        opacity = opacity[mask]\n        scales = scales[mask]\n        rotations = rotations[mask]\n        shs = shs[mask]\n\n        # invert activation to make it compatible with the original ply format\n        if compatible:\n            opacity = kiui.op.inverse_sigmoid(opacity)\n            scales = torch.log(scales + 1e-8)\n            shs = (shs - 0.5) / 0.28209479177387814\n\n        xyzs = means3D.detach().cpu().numpy()\n        f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        opacities = opacity.detach().cpu().numpy()\n        scales = scales.detach().cpu().numpy()\n        rotations = rotations.detach().cpu().numpy()\n\n        l = ['x', 'y', 'z']\n        # All channels except the 3 DC\n        for i in range(f_dc.shape[1]):\n            l.append('f_dc_{}'.format(i))\n        l.append('opacity')\n        for i in range(scales.shape[1]):\n            l.append('scale_{}'.format(i))\n        for i in range(rotations.shape[1]):\n            l.append('rot_{}'.format(i))\n\n        dtype_full = [(attribute, 'f4') for attribute in l]\n\n        elements = np.empty(xyzs.shape[0], dtype=dtype_full)\n        attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)\n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, 'vertex')\n\n        PlyData([el]).write(path)\n    \n    def load_ply(self, path, compatible=True):\n\n        from plyfile import PlyData, PlyElement\n\n        plydata = PlyData.read(path)\n\n        xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n                        np.asarray(plydata.elements[0][\"y\"]),\n                        np.asarray(plydata.elements[0][\"z\"])),  axis=1)\n        print(\"Number of points at loading : \", xyz.shape[0])\n\n        opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n\n        shs = np.zeros((xyz.shape[0], 3))\n        shs[:, 0] = np.asarray(plydata.elements[0][\"f_dc_0\"])\n        shs[:, 1] = np.asarray(plydata.elements[0][\"f_dc_1\"])\n        shs[:, 2] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n\n        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"scale_\")]\n        scales = np.zeros((xyz.shape[0], len(scale_names)))\n        for idx, attr_name in enumerate(scale_names):\n            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"rot_\")]\n        rots = np.zeros((xyz.shape[0], len(rot_names)))\n   
     for idx, attr_name in enumerate(rot_names):\n            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])\n          \n        gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)\n        gaussians = torch.from_numpy(gaussians).float() # cpu\n\n        if compatible:\n            gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])\n            gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])\n            gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5\n\n        return gaussians"
  },
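  {
    "path": "examples/render_gaussians_sketch.py",
    "content": "# NOTE: illustrative usage sketch added for documentation purposes only; this file is\n# not part of the original Cycle3D repository. It assumes the attributes read by\n# GaussianRenderer (fovy, znear, zfar, output_size) exist on core.options.Options and\n# that the CUDA diff_gaussian_rasterization extension is installed.\n\nimport torch\nfrom types import SimpleNamespace\n\nfrom core.gs import GaussianRenderer\n\n# Stand-in for core.options.Options; only the attributes read by GaussianRenderer are set.\nopt = SimpleNamespace(fovy=49.1, znear=0.5, zfar=2.5, output_size=256)\nrenderer = GaussianRenderer(opt)\n\nB, V, N = 1, 1, 4096\n# Random Gaussians packed as [pos(3), opacity(1), scale(3), rotation(4), rgb(3)] = 14 channels.\ngaussians = torch.zeros(B, N, 14, device=\"cuda\")\ngaussians[..., 0:3] = 0.3 * torch.randn(B, N, 3, device=\"cuda\")            # positions\ngaussians[..., 3:4] = torch.rand(B, N, 1, device=\"cuda\")                   # opacities in [0, 1]\ngaussians[..., 4:7] = 0.02 * torch.rand(B, N, 3, device=\"cuda\")            # post-activation scales\ngaussians[..., 7:11] = torch.nn.functional.normalize(torch.randn(B, N, 4, device=\"cuda\"), dim=-1)  # unit quaternions\ngaussians[..., 11:] = torch.rand(B, N, 3, device=\"cuda\")                   # RGB colors in [0, 1]\n\n# A single placeholder camera; a real caller would build world-to-camera matrices from\n# orbit poses (as in LGM) instead of the identity used here.\ncam_view = torch.eye(4, device=\"cuda\").view(1, 1, 4, 4).repeat(B, V, 1, 1)\ncam_view_proj = cam_view @ renderer.proj_matrix.to(\"cuda\")\ncam_pos = torch.zeros(B, V, 3, device=\"cuda\")\n\nout = renderer.render(gaussians, cam_view, cam_view_proj, cam_pos)\nprint(out[\"image\"].shape, out[\"alpha\"].shape)  # [B, V, 3, H, W], [B, V, 1, H, W]\n\n# The packed tensor can also be exported to a standard 3DGS .ply file (batch size must be 1).\nrenderer.save_ply(gaussians, \"sketch.ply\")"
  },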
  {
    "path": "core/masactrl.py",
    "content": "import os\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom einops import rearrange\n\nfrom core.masactrl_utils import AttentionBase\n\nfrom torchvision.utils import save_image\n\n\nclass MutualSelfAttentionControl(AttentionBase):\n    MODEL_TYPE = {\n        \"SD\": 16,\n        \"SDXL\": 70\n    }\n\n    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type=\"SD\"):\n        \"\"\"\n        Mutual self-attention control for Stable-Diffusion model\n        Args:\n            start_step: the step to start mutual self-attention control\n            start_layer: the layer to start mutual self-attention control\n            layer_idx: list of the layers to apply mutual self-attention control\n            step_idx: list the steps to apply mutual self-attention control\n            total_steps: the total number of steps\n            model_type: the model type, SD or SDXL\n        \"\"\"\n        super().__init__()\n        self.total_steps = total_steps\n        self.total_layers = self.MODEL_TYPE.get(model_type, 16)\n        self.start_step = start_step\n        self.start_layer = start_layer\n        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))\n        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))\n        print(\"MasaCtrl at denoising steps: \", self.step_idx)\n        print(\"MasaCtrl at U-Net layers: \", self.layer_idx)\n\n    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Performing attention for a batch of queries, keys, and values\n        \"\"\"\n        b = q.shape[0] // num_heads\n        q = rearrange(q, \"(b h) n d -> h (b n) d\", h=num_heads)\n        k = rearrange(k, \"(b h) n d -> h (b n) d\", h=num_heads)\n        v = rearrange(v, \"(b h) n d -> h (b n) d\", h=num_heads)\n\n        sim = torch.einsum(\"h i d, h j d -> h i j\", q, k) * kwargs.get(\"scale\")\n        attn = sim.softmax(-1)\n        out = torch.einsum(\"h i j, h j d -> h i d\", attn, v)\n        out = rearrange(out, \"h (b n) d -> b n (h d)\", b=b)\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Attention forward function\n        \"\"\"\n        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n        qu, qc = q.chunk(2)\n        ku, kc = k.chunk(2)\n        vu, vc = v.chunk(2)\n        attnu, attnc = attn.chunk(2)\n\n        out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)\n        out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)\n        out = torch.cat([out_u, out_c], dim=0)\n\n        return out\n\nclass MutualSelfAttention3DControl(AttentionBase):\n    MODEL_TYPE = {\n        \"SD\": 16,\n        \"SDXL\": 70\n    }\n\n    def __init__(self, start_steps=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type=\"SD\"):\n        \"\"\"\n        Mutual self-attention control for Stable-Diffusion model\n        Args:\n            start_step: the step to start mutual self-attention control\n            start_layer: the layer to start 
mutual self-attention control\n            layer_idx: list of the layers to apply mutual self-attention control\n            step_idx: list of the steps to apply mutual self-attention control\n            total_steps: the total number of steps\n            model_type: the model type, SD or SDXL\n        \"\"\"\n        super().__init__()\n        self.total_steps = total_steps\n        self.total_layers = self.MODEL_TYPE.get(model_type, 16)\n        self.start_step = start_steps\n        self.start_layer = start_layer\n        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))\n        self.step_idx = step_idx if step_idx is not None else list(range(start_steps, total_steps))\n        print(\"MasaCtrl at denoising steps: \", self.step_idx)\n        print(\"MasaCtrl at U-Net layers: \", self.layer_idx)\n\n    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Performing attention for a batch of queries, keys, and values\n        \"\"\"\n        b = q.shape[0] // num_heads\n        q = rearrange(q, \"(b h) n d -> h (b n) d\", h=num_heads)\n        k = rearrange(k, \"(b h) n d -> h (b n) d\", h=num_heads)\n        v = rearrange(v, \"(b h) n d -> h (b n) d\", h=num_heads)\n\n        sim = torch.einsum(\"h i d, h j d -> h i j\", q, k) * kwargs.get(\"scale\")\n        attn = sim.softmax(-1)\n        out = torch.einsum(\"h i j, h j d -> h i d\", attn, v)\n        out = rearrange(out, \"h (b n) d -> b n (h d)\", b=b)\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Attention forward function\n        \"\"\"\n        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n        # qu, qc = q.chunk(2)\n        # ku, kc = k.chunk(2)\n        # vu, vc = v.chunk(2)\n        # attnu, attnc = attn.chunk(2)\n        \n        q_t1, q_t2, q_t3, q_t4, q_s = q.chunk(5)\n        k_t1, k_t2, k_t3, k_t4, k_s = k.chunk(5)\n        v_t1, v_t2, v_t3, v_t4, v_s = v.chunk(5)\n        attn_t1, attn_t2, attn_t3, attn_t4, attn_s = attn.chunk(5)\n        \n        out_s = super().forward(q_s, k_s, v_s, sim, attn_s, is_cross, place_in_unet, num_heads, **kwargs)\n        out_t1 = self.attn_batch(q_t1, torch.cat([k_s, k_t1]), torch.cat([v_s, v_t1]), sim[:num_heads], attn_t1, is_cross, place_in_unet, num_heads, **kwargs)\n        out_t2 = self.attn_batch(q_t2, torch.cat([k_s, k_t2]), torch.cat([v_s, v_t2]), sim[:num_heads], attn_t2, is_cross, place_in_unet, num_heads, **kwargs)\n        out_t3 = self.attn_batch(q_t3, torch.cat([k_s, k_t3]), torch.cat([v_s, v_t3]), sim[:num_heads], attn_t3, is_cross, place_in_unet, num_heads, **kwargs)\n        out_t4 = self.attn_batch(q_t4, torch.cat([k_s, k_t4]), torch.cat([v_s, v_t4]), sim[:num_heads], attn_t4, is_cross, place_in_unet, num_heads, **kwargs)\n        # out_t1 = self.attn_batch(q_t1, k_s, v_s, sim[:num_heads], attn_t1, is_cross, place_in_unet, num_heads, **kwargs)\n        # out_t2 = self.attn_batch(q_t2, k_s, v_s, sim[:num_heads], attn_t2, is_cross, place_in_unet, num_heads, **kwargs)\n        # out_t3 = self.attn_batch(q_t3, k_s, v_s, sim[:num_heads], attn_t3, is_cross, place_in_unet, num_heads, **kwargs)\n        # out_t4 = self.attn_batch(q_t4, k_s, v_s, sim[:num_heads], attn_t4, is_cross, place_in_unet, num_heads, 
**kwargs)\n\n        # out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)\n        # out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)\n        out = torch.cat([out_t1, out_t2, out_t3, out_t4, out_s], dim=0)\n\n        return out\n    \nclass MutualSelfAttentionControlUnion(MutualSelfAttentionControl):\n    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type=\"SD\"):\n        \"\"\"\n        Mutual self-attention control for Stable-Diffusion model with unition source and target [K, V]\n        Args:\n            start_step: the step to start mutual self-attention control\n            start_layer: the layer to start mutual self-attention control\n            layer_idx: list of the layers to apply mutual self-attention control\n            step_idx: list the steps to apply mutual self-attention control\n            total_steps: the total number of steps\n            model_type: the model type, SD or SDXL\n        \"\"\"\n        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Attention forward function\n        \"\"\"\n        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n        qu_s, qu_t, qc_s, qc_t = q.chunk(4)\n        ku_s, ku_t, kc_s, kc_t = k.chunk(4)\n        vu_s, vu_t, vc_s, vc_t = v.chunk(4)\n        attnu_s, attnu_t, attnc_s, attnc_t = attn.chunk(4)\n\n        # source image branch\n        out_u_s = super().forward(qu_s, ku_s, vu_s, sim, attnu_s, is_cross, place_in_unet, num_heads, **kwargs)\n        out_c_s = super().forward(qc_s, kc_s, vc_s, sim, attnc_s, is_cross, place_in_unet, num_heads, **kwargs)\n\n        # target image branch, concatenating source and target [K, V]\n        out_u_t = self.attn_batch(qu_t, torch.cat([ku_s, ku_t]), torch.cat([vu_s, vu_t]), sim[:num_heads], attnu_t, is_cross, place_in_unet, num_heads, **kwargs)\n        out_c_t = self.attn_batch(qc_t, torch.cat([kc_s, kc_t]), torch.cat([vc_s, vc_t]), sim[:num_heads], attnc_t, is_cross, place_in_unet, num_heads, **kwargs)\n\n        out = torch.cat([out_u_s, out_u_t, out_c_s, out_c_t], dim=0)\n\n        return out\n\n\nclass MutualSelfAttentionControlMask(MutualSelfAttentionControl):\n    def __init__(self,  start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50,  mask_s=None, mask_t=None, mask_save_dir=None, model_type=\"SD\"):\n        \"\"\"\n        Maske-guided MasaCtrl to alleviate the problem of fore- and background confusion\n        Args:\n            start_step: the step to start mutual self-attention control\n            start_layer: the layer to start mutual self-attention control\n            layer_idx: list of the layers to apply mutual self-attention control\n            step_idx: list the steps to apply mutual self-attention control\n            total_steps: the total number of steps\n            mask_s: source mask with shape (h, w)\n            mask_t: target mask with same shape as source mask\n            mask_save_dir: the path to save the mask image\n            model_type: the model type, SD or SDXL\n        \"\"\"\n        
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)\n        self.mask_s = mask_s  # source mask with shape (h, w)\n        self.mask_t = mask_t  # target mask with same shape as source mask\n        print(\"Using mask-guided MasaCtrl\")\n        if mask_save_dir is not None:\n            os.makedirs(mask_save_dir, exist_ok=True)\n            save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, \"mask_s.png\"))\n            save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, \"mask_t.png\"))\n\n    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        B = q.shape[0] // num_heads\n        H = W = int(np.sqrt(q.shape[1]))\n        q = rearrange(q, \"(b h) n d -> h (b n) d\", h=num_heads)\n        k = rearrange(k, \"(b h) n d -> h (b n) d\", h=num_heads)\n        v = rearrange(v, \"(b h) n d -> h (b n) d\", h=num_heads)\n\n        sim = torch.einsum(\"h i d, h j d -> h i j\", q, k) * kwargs.get(\"scale\")\n        if kwargs.get(\"is_mask_attn\") and self.mask_s is not None:\n            print(\"masked attention\")\n            mask = self.mask_s.unsqueeze(0).unsqueeze(0)\n            mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0)\n            mask = mask.flatten()\n            # background\n            sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)\n            # object\n            sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)\n            sim = torch.cat([sim_fg, sim_bg], dim=0)\n        attn = sim.softmax(-1)\n        if len(attn) == 2 * len(v):\n            v = torch.cat([v] * 2)\n        out = torch.einsum(\"h i j, h j d -> h i d\", attn, v)\n        out = rearrange(out, \"(h1 h) (b n) d -> (h1 b) n (h d)\", b=B, h=num_heads)\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Attention forward function\n        \"\"\"\n        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n        B = q.shape[0] // num_heads // 2\n        H = W = int(np.sqrt(q.shape[1]))\n        qu, qc = q.chunk(2)\n        ku, kc = k.chunk(2)\n        vu, vc = v.chunk(2)\n        attnu, attnc = attn.chunk(2)\n\n        out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)\n        out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)\n\n        out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)\n        out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)\n\n        if self.mask_s is not None and self.mask_t is not None:\n            out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0)\n            out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0)\n\n            mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W))\n            mask = mask.reshape(-1, 1)  # (hw, 1)\n            out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - 
mask)\n            out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask)\n\n        out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)\n        return out\n\n\nclass MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):\n    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None, model_type=\"SD\"):\n        \"\"\"\n        MasaCtrl with mask auto generation from cross-attention map\n        Args:\n            start_step: the step to start mutual self-attention control\n            start_layer: the layer to start mutual self-attention control\n            layer_idx: list of the layers to apply mutual self-attention control\n            step_idx: list the steps to apply mutual self-attention control\n            total_steps: the total number of steps\n            thres: the thereshold for mask thresholding\n            ref_token_idx: the token index list for cross-attention map aggregation\n            cur_token_idx: the token index list for cross-attention map aggregation\n            mask_save_dir: the path to save the mask image\n        \"\"\"\n        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)\n        print(\"Using MutualSelfAttentionControlMaskAuto\")\n        self.thres = thres\n        self.ref_token_idx = ref_token_idx\n        self.cur_token_idx = cur_token_idx\n\n        self.self_attns = []\n        self.cross_attns = []\n\n        self.cross_attns_mask = None\n        self.self_attns_mask = None\n\n        self.mask_save_dir = mask_save_dir\n        if self.mask_save_dir is not None:\n            os.makedirs(self.mask_save_dir, exist_ok=True)\n\n    def after_step(self):\n        self.self_attns = []\n        self.cross_attns = []\n\n    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Performing attention for a batch of queries, keys, and values\n        \"\"\"\n        B = q.shape[0] // num_heads\n        H = W = int(np.sqrt(q.shape[1]))\n        q = rearrange(q, \"(b h) n d -> h (b n) d\", h=num_heads)\n        k = rearrange(k, \"(b h) n d -> h (b n) d\", h=num_heads)\n        v = rearrange(v, \"(b h) n d -> h (b n) d\", h=num_heads)\n\n        sim = torch.einsum(\"h i d, h j d -> h i j\", q, k) * kwargs.get(\"scale\")\n        if self.self_attns_mask is not None:\n            # binarize the mask\n            mask = self.self_attns_mask\n            thres = self.thres\n            mask[mask >= thres] = 1\n            mask[mask < thres] = 0\n            sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)\n            sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)\n            sim = torch.cat([sim_fg, sim_bg])\n\n        attn = sim.softmax(-1)\n\n        if len(attn) == 2 * len(v):\n            v = torch.cat([v] * 2)\n        out = torch.einsum(\"h i j, h j d -> h i d\", attn, v)\n        out = rearrange(out, \"(h1 h) (b n) d -> (h1 b) n (h d)\", b=B, h=num_heads)\n        return out\n\n    def aggregate_cross_attn_map(self, idx):\n        attn_map = torch.stack(self.cross_attns, dim=1).mean(1)  # (B, N, dim)\n        B = attn_map.shape[0]\n        res = int(np.sqrt(attn_map.shape[-2]))\n        attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1])\n        image = attn_map[..., idx]\n        if isinstance(idx, list):\n            image = image.sum(-1)\n     
   image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]\n        image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0]\n        image = (image - image_min) / (image_max - image_min)\n        return image\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        Attention forward function\n        \"\"\"\n        if is_cross:\n            # save cross attention map with res 16 * 16\n            if attn.shape[1] == 16 * 16:\n                self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1))\n\n        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n        B = q.shape[0] // num_heads // 2\n        H = W = int(np.sqrt(q.shape[1]))\n        qu, qc = q.chunk(2)\n        ku, kc = k.chunk(2)\n        vu, vc = v.chunk(2)\n        attnu, attnc = attn.chunk(2)\n\n        out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)\n        out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)\n\n        if len(self.cross_attns) == 0:\n            self.self_attns_mask = None\n            out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)\n            out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)\n        else:\n            mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx)  # (2, H, W)\n            mask_source = mask[-2]  # (H, W)\n            res = int(np.sqrt(q.shape[1]))\n            self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten()\n            if self.mask_save_dir is not None:\n                H = W = int(np.sqrt(self.self_attns_mask.shape[0]))\n                mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0)\n                save_image(mask_image, os.path.join(self.mask_save_dir, f\"mask_s_{self.cur_step}_{self.cur_att_layer}.png\"))\n            out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)\n            out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)\n\n        if self.self_attns_mask is not None:\n            mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx)  # (2, H, W)\n            mask_target = mask[-1]  # (H, W)\n            res = int(np.sqrt(q.shape[1]))\n            spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1)\n            if self.mask_save_dir is not None:\n                H = W = int(np.sqrt(spatial_mask.shape[0]))\n                mask_image = spatial_mask.reshape(H, W).unsqueeze(0)\n                save_image(mask_image, os.path.join(self.mask_save_dir, f\"mask_t_{self.cur_step}_{self.cur_att_layer}.png\"))\n            # binarize the mask\n            thres = self.thres\n            spatial_mask[spatial_mask >= thres] = 1\n            spatial_mask[spatial_mask < thres] = 0\n            
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2)\n            out_c_target_fg, out_c_target_bg = out_c_target.chunk(2)\n\n            out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask)\n            out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask)\n\n            # set self self-attention mask to None\n            self.self_attns_mask = None\n\n        out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)\n        return out"
  },
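  {
    "path": "examples/masactrl_3d_attn_sketch.py",
    "content": "# NOTE: illustrative sketch added for documentation purposes only; not part of the\n# original Cycle3D repository. It runs MutualSelfAttention3DControl.attn_batch on dummy\n# tensors to show the expected (batch * heads, tokens, dim) layout and how the\n# reference-view keys/values are concatenated in front of each target view's.\n\nimport torch\n\nfrom core.masactrl import MutualSelfAttention3DControl\n\nnum_heads, n_tokens, head_dim = 8, 64, 40\neditor = MutualSelfAttention3DControl(start_steps=0, start_layer=0, total_steps=1)\n\n# One target view and the reference view, each packed as (num_heads, tokens, dim).\nq_t = torch.randn(num_heads, n_tokens, head_dim)\nk_t = torch.randn(num_heads, n_tokens, head_dim)\nv_t = torch.randn(num_heads, n_tokens, head_dim)\nk_s = torch.randn(num_heads, n_tokens, head_dim)\nv_s = torch.randn(num_heads, n_tokens, head_dim)\n\n# Target queries attend over [reference keys; target keys], mirroring the forward pass.\nout = editor.attn_batch(\n    q_t,\n    torch.cat([k_s, k_t]),\n    torch.cat([v_s, v_t]),\n    sim=None, attn=None, is_cross=False, place_in_unet=\"up\",\n    num_heads=num_heads, scale=head_dim ** -0.5,\n)\nprint(out.shape)  # torch.Size([1, 64, 320]) == (1, tokens, num_heads * head_dim)"
  },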
  {
    "path": "core/masactrl_utils.py",
    "content": "import os\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom typing import Optional, Union, Tuple, List, Callable, Dict\n\nfrom torchvision.utils import save_image\nfrom einops import rearrange, repeat\n\n\nclass AttentionBase:\n    def __init__(self):\n        self.cur_step = 0\n        self.num_att_layers = -1\n        self.cur_att_layer = 0\n\n    def after_step(self):\n        pass\n\n    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n        self.cur_att_layer += 1\n        if self.cur_att_layer == self.num_att_layers:\n            self.cur_att_layer = 0\n            self.cur_step += 1\n            # after step\n            self.after_step()\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        out = torch.einsum('b i j, b j d -> b i d', attn, v)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)\n        return out\n\n    def reset(self):\n        self.cur_step = 0\n        self.cur_att_layer = 0\n\n\nclass AttentionStore(AttentionBase):\n    def __init__(self, res=[32], min_step=0, max_step=1000):\n        super().__init__()\n        self.res = res\n        self.min_step = min_step\n        self.max_step = max_step\n        self.valid_steps = 0\n\n        self.self_attns = []  # store the all attns\n        self.cross_attns = []\n\n        self.self_attns_step = []  # store the attns in each step\n        self.cross_attns_step = []\n\n    def after_step(self):\n        if self.cur_step > self.min_step and self.cur_step < self.max_step:\n            self.valid_steps += 1\n            if len(self.self_attns) == 0:\n                self.self_attns = self.self_attns_step\n                self.cross_attns = self.cross_attns_step\n            else:\n                for i in range(len(self.self_attns)):\n                    self.self_attns[i] += self.self_attns_step[i]\n                    self.cross_attns[i] += self.cross_attns_step[i]\n        self.self_attns_step.clear()\n        self.cross_attns_step.clear()\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        if attn.shape[1] <= 64 ** 2:  # avoid OOM\n            if is_cross:\n                self.cross_attns_step.append(attn)\n            else:\n                self.self_attns_step.append(attn)\n        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n\n\ndef regiter_attention_editor_diffusers(unet, editor: AttentionBase):\n    \"\"\"\n    Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]\n    \"\"\"\n    def ca_forward(self, place_in_unet):\n        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):\n            \"\"\"\n            The attention is similar to the original implementation of LDM CrossAttention class\n            except adding some modifications on the attention\n            \"\"\"\n            if encoder_hidden_states is not None:\n                context = encoder_hidden_states\n            if attention_mask is not None:\n                mask = attention_mask\n\n            to_out = self.to_out\n            if isinstance(to_out, nn.modules.container.ModuleList):\n                to_out = self.to_out[0]\n            else:\n                to_out = self.to_out\n\n            
h = self.heads\n            q = self.to_q(x)\n            is_cross = context is not None\n            context = context if is_cross else x\n            k = self.to_k(context)\n            v = self.to_v(context)\n            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))\n\n            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale\n\n            if mask is not None:\n                mask = rearrange(mask, 'b ... -> b (...)')\n                max_neg_value = -torch.finfo(sim.dtype).max\n                mask = repeat(mask, 'b j -> (b h) () j', h=h)\n                mask = mask[:, None, :].repeat(h, 1, 1)\n                sim.masked_fill_(~mask, max_neg_value)\n\n            attn = sim.softmax(dim=-1)\n            # the only difference\n            out = editor(\n                q, k, v, sim, attn, is_cross, place_in_unet,\n                self.heads, scale=self.scale)\n\n            return to_out(out)\n\n        return forward\n\n    def register_editor(net, count, place_in_unet):\n        for name, subnet in net.named_children():\n            if net.__class__.__name__ == 'Attention':  # spatial Transformer layer\n                net.forward = ca_forward(net, place_in_unet)\n                return count + 1\n            elif hasattr(net, 'children'):\n                count = register_editor(subnet, count, place_in_unet)\n        return count\n\n    cross_att_count = 0\n    for net_name, net in unet.named_children():\n        if \"down\" in net_name:\n            cross_att_count += register_editor(net, 0, \"down\")\n        elif \"mid\" in net_name:\n            cross_att_count += register_editor(net, 0, \"mid\")\n        elif \"up\" in net_name:\n            cross_att_count += register_editor(net, 0, \"up\")\n    editor.num_att_layers = cross_att_count\n\n\ndef regiter_attention_editor_ldm(model, editor: AttentionBase):\n    \"\"\"\n    Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]\n    \"\"\"\n    def ca_forward(self, place_in_unet):\n        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):\n            \"\"\"\n            The attention is similar to the original implementation of LDM CrossAttention class\n            except adding some modifications on the attention\n            \"\"\"\n            if encoder_hidden_states is not None:\n                context = encoder_hidden_states\n            if attention_mask is not None:\n                mask = attention_mask\n\n            to_out = self.to_out\n            if isinstance(to_out, nn.modules.container.ModuleList):\n                to_out = self.to_out[0]\n            else:\n                to_out = self.to_out\n\n            h = self.heads\n            q = self.to_q(x)\n            is_cross = context is not None\n            context = context if is_cross else x\n            k = self.to_k(context)\n            v = self.to_v(context)\n            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))\n\n            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale\n\n            if mask is not None:\n                mask = rearrange(mask, 'b ... 
-> b (...)')\n                max_neg_value = -torch.finfo(sim.dtype).max\n                mask = repeat(mask, 'b j -> (b h) () j', h=h)\n                mask = mask[:, None, :].repeat(h, 1, 1)\n                sim.masked_fill_(~mask, max_neg_value)\n\n            attn = sim.softmax(dim=-1)\n            # the only difference\n            out = editor(\n                q, k, v, sim, attn, is_cross, place_in_unet,\n                self.heads, scale=self.scale)\n\n            return to_out(out)\n\n        return forward\n\n    def register_editor(net, count, place_in_unet):\n        for name, subnet in net.named_children():\n            if net.__class__.__name__ == 'CrossAttention':  # spatial Transformer layer\n                net.forward = ca_forward(net, place_in_unet)\n                return count + 1\n            elif hasattr(net, 'children'):\n                count = register_editor(subnet, count, place_in_unet)\n        return count\n\n    cross_att_count = 0\n    for net_name, net in model.model.diffusion_model.named_children():\n        if \"input\" in net_name:\n            cross_att_count += register_editor(net, 0, \"input\")\n        elif \"middle\" in net_name:\n            cross_att_count += register_editor(net, 0, \"middle\")\n        elif \"output\" in net_name:\n            cross_att_count += register_editor(net, 0, \"output\")\n    editor.num_att_layers = cross_att_count"
  },
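  {
    "path": "examples/register_editor_sketch.py",
    "content": "# NOTE: illustrative sketch added for documentation purposes only; not part of the\n# original Cycle3D repository. It shows how an AttentionBase editor is hooked into a\n# diffusers UNet with regiter_attention_editor_diffusers, roughly following the MasaCtrl\n# usage pattern; the checkpoint name and prompts below are placeholders.\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\nfrom core.masactrl import MutualSelfAttentionControl\nfrom core.masactrl_utils import regiter_attention_editor_diffusers\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"runwayml/stable-diffusion-v1-5\", torch_dtype=torch.float16\n).to(\"cuda\")\n\n# Share self-attention keys/values across the batch from denoising step 4 / U-Net layer 10 on.\neditor = MutualSelfAttentionControl(start_step=4, start_layer=10, total_steps=50)\nregiter_attention_editor_diffusers(pipe.unet, editor)\n\n# The editor's step/layer counters advance as the hooked attention layers are called,\n# so reset it before every new sampling run.\neditor.reset()\nimages = pipe(\n    [\"a photo of a corgi\", \"a photo of a corgi, side view\"],\n    num_inference_steps=50,\n).images"
  },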
  {
    "path": "core/models/__init__.py",
    "content": ""
  },
  {
    "path": "core/models/transformer_mv2d.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.embeddings import ImagePositionalEmbeddings\nfrom diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph\nfrom diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention\nfrom diffusers.models.embeddings import PatchEmbed\nfrom diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils.import_utils import is_xformers_available\n\nfrom einops import rearrange, repeat\nimport pdb\nimport random\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 0:\n        return nn.Linear(*args, **kwargs)\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\nif is_xformers_available():\n    import xformers\n    import xformers.ops\nelse:\n    xformers = None\n\ndef my_repeat(tensor, num_repeats):\n    \"\"\"\n    Repeat a tensor along a given dimension\n    \"\"\"\n    if len(tensor.shape) == 3:\n        return repeat(tensor,  \"b d c -> (b v) d c\", v=num_repeats)\n    elif len(tensor.shape) == 4:\n        return repeat(tensor,  \"a b d c -> (a v) b d c\", v=num_repeats)\n\n\n@dataclass\nclass TransformerMV2DModelOutput(BaseOutput):\n    \"\"\"\n    The output of [`Transformer2DModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):\n            The hidden states output conditioned on the `encoder_hidden_states` input. 
If discrete, returns probability\n            distributions for the unnoised latent pixels.\n    \"\"\"\n\n    sample: torch.FloatTensor\n\n\nclass TransformerMV2DModel(ModelMixin, ConfigMixin):\n    \"\"\"\n    A 2D Transformer model for image-like data.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            The number of channels in the input and output (specify if the input is **continuous**).\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.\n        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).\n            This is fixed during training since it is used to learn a number of position embeddings.\n        num_vector_embeds (`int`, *optional*):\n            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).\n            Includes the class for the masked latent pixel.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to use in feed-forward.\n        num_embeds_ada_norm ( `int`, *optional*):\n            The number of diffusion steps used during training. Pass if at least one of the norm_layers is\n            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are\n            added to the hidden states.\n\n            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.\n        attention_bias (`bool`, *optional*):\n            Configure if the `TransformerBlocks` attention should contain a bias parameter.\n    \"\"\"\n\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        out_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        sample_size: Optional[int] = None,\n        num_vector_embeds: Optional[int] = None,\n        patch_size: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n        norm_type: str = \"layer_norm\",\n        norm_elementwise_affine: bool = True,\n        num_views: int = 1,\n        cd_attention_last: bool=False,\n        cd_attention_mid: bool=False,\n        multiview_attention: bool=True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool=False\n    ):\n        super().__init__()\n        self.use_linear_projection = use_linear_projection\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n        inner_dim = num_attention_heads * attention_head_dim\n\n        # 1. 
Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`\n        # Define whether input is continuous or discrete depending on configuration\n        self.is_input_continuous = (in_channels is not None) and (patch_size is None)\n        self.is_input_vectorized = num_vector_embeds is not None\n        self.is_input_patches = in_channels is not None and patch_size is not None\n\n        if norm_type == \"layer_norm\" and num_embeds_ada_norm is not None:\n            deprecation_message = (\n                f\"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or\"\n                \" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config.\"\n                \" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect\"\n                \" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it\"\n                \" would be very nice if you could open a Pull request for the `transformer/config.json` file\"\n            )\n            deprecate(\"norm_type!=num_embeds_ada_norm\", \"1.0.0\", deprecation_message, standard_warn=False)\n            norm_type = \"ada_norm\"\n\n        if self.is_input_continuous and self.is_input_vectorized:\n            raise ValueError(\n                f\"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make\"\n                \" sure that either `in_channels` or `num_vector_embeds` is None.\"\n            )\n        elif self.is_input_vectorized and self.is_input_patches:\n            raise ValueError(\n                f\"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make\"\n                \" sure that either `num_vector_embeds` or `num_patches` is None.\"\n            )\n        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:\n            raise ValueError(\n                f\"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:\"\n                f\" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None.\"\n            )\n\n        # 2. 
Define input layers\n        if self.is_input_continuous:\n            self.in_channels = in_channels\n\n            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n            if use_linear_projection:\n                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)\n            else:\n                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)\n        elif self.is_input_vectorized:\n            assert sample_size is not None, \"Transformer2DModel over discrete input must provide sample_size\"\n            assert num_vector_embeds is not None, \"Transformer2DModel over discrete input must provide num_embed\"\n\n            self.height = sample_size\n            self.width = sample_size\n            self.num_vector_embeds = num_vector_embeds\n            self.num_latent_pixels = self.height * self.width\n\n            self.latent_image_embedding = ImagePositionalEmbeddings(\n                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width\n            )\n        elif self.is_input_patches:\n            assert sample_size is not None, \"Transformer2DModel over patched input must provide sample_size\"\n\n            self.height = sample_size\n            self.width = sample_size\n\n            self.patch_size = patch_size\n            self.pos_embed = PatchEmbed(\n                height=sample_size,\n                width=sample_size,\n                patch_size=patch_size,\n                in_channels=in_channels,\n                embed_dim=inner_dim,\n            )\n\n        # 3. Define transformers blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicMVTransformerBlock(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    dropout=dropout,\n                    cross_attention_dim=cross_attention_dim,\n                    activation_fn=activation_fn,\n                    num_embeds_ada_norm=num_embeds_ada_norm,\n                    attention_bias=attention_bias,\n                    only_cross_attention=only_cross_attention,\n                    upcast_attention=upcast_attention,\n                    norm_type=norm_type,\n                    norm_elementwise_affine=norm_elementwise_affine,\n                    num_views=num_views,\n                    cd_attention_last=cd_attention_last,\n                    cd_attention_mid=cd_attention_mid,\n                    multiview_attention=multiview_attention,\n                    sparse_mv_attention=sparse_mv_attention,\n                    mvcd_attention=mvcd_attention\n                )\n                for d in range(num_layers)\n            ]\n        )\n\n        # 4. 
Define output layers\n        self.out_channels = in_channels if out_channels is None else out_channels\n        if self.is_input_continuous:\n            # TODO: should use out_channels for continuous projections\n            if use_linear_projection:\n                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)\n            else:\n                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)\n        elif self.is_input_vectorized:\n            self.norm_out = nn.LayerNorm(inner_dim)\n            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)\n        elif self.is_input_patches:\n            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)\n            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)\n            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)\n        \n        self.post_init()\n    \n    def post_init(self):\n        conv_block = self.proj_in\n        conv_params = {\n            k: getattr(conv_block, k)\n            for k in [\n                \"in_channels\",\n                \"out_channels\",\n                \"kernel_size\",\n                \"stride\",\n                \"padding\",\n            ]\n        }\n        conv_params[\"in_channels\"] += 6\n        conv_params[\"dims\"] = 2\n        conv_params[\"device\"] = conv_block.weight.device\n        inflated_proj_in = conv_nd(**conv_params)\n        inp_weight = conv_block.weight.data\n        feat_shape = inp_weight.shape\n        feat_weight = torch.zeros(\n            (feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device\n        )\n        inflated_proj_in.weight.data.copy_(\n            torch.cat([inp_weight, feat_weight], dim=1)\n        )\n        inflated_proj_in.bias.data.copy_(conv_block.bias.data)\n        self.proj_in = inflated_proj_in\n        self.post_intialized = True\n    \n    def post_linear_init(self):\n        linear_block = self.proj_in\n        linear_params = {\n            k: getattr(linear_block, k)\n            for k in [\n                \"in_features\",\n                \"out_features\"\n            ]\n        }\n        linear_params[\"in_features\"] += 6\n        linear_params[\"dims\"] = 0\n        linear_params[\"device\"] = linear_block.weight.device\n        inflated_proj_in = conv_nd(**linear_params)\n        inp_weight = linear_block.weight.data\n        feat_shape = inp_weight.shape\n        feat_weight = torch.zeros(\n            (feat_shape[0], 6), device=inp_weight.device\n        )\n        inflated_proj_in.weight.data.copy_(\n            torch.cat([inp_weight, feat_weight], dim=1)\n        )\n        inflated_proj_in.bias.data.copy_(linear_block.bias.data)\n        self.proj_in = inflated_proj_in\n        self.post_intialized = True\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        class_labels: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        ray_embedding: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ):\n        \"\"\"\n        The [`Transformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` 
if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):\n                Input `hidden_states`.\n            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.LongTensor`, *optional*):\n                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.\n            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):\n                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in\n                `AdaLayerZeroNorm`.\n            encoder_attention_mask ( `torch.Tensor`, *optional*):\n                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:\n\n                    * Mask `(batch, sequence_length)` True = keep, False = discard.\n                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.\n\n                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format\n                above. This bias will be added to the cross-attention scores.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a\n            `tuple` where the first element is the sample tensor.\n        \"\"\"\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.\n        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.\n        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None and attention_mask.ndim == 2:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 1. 
Input\n        if self.is_input_continuous:\n            batch, _, height, width = hidden_states.shape\n            residual = hidden_states\n\n            hidden_states = self.norm(hidden_states)\n            if self.post_intialized:\n                #ray_embedding = rearrange(ray_embedding, \"n v c h w -> (n v) c h w\")\n                ray_embedding_interpolated = F.interpolate(ray_embedding, size=hidden_states.shape[-2:], align_corners=False, mode=\"bilinear\")\n                #ray_embedding_interpolated = rearrange(ray_embedding_interpolated, \"(n v) c h w -> n v c h w\", v=4)\n\n            # concat plucker to x\n                hidden_states = torch.cat([hidden_states, ray_embedding_interpolated], dim=1)\n                #hidden_states = rearrange(hidden_states, \"n v c h w -> (n v) c h w\")\n                # x = self.proj_in(x)\n                # x = rearrange(x, \"(n v) c h w -> n v c h w\", v=4)\n\n            if not self.use_linear_projection:\n                hidden_states = self.proj_in(hidden_states)\n                inner_dim = hidden_states.shape[1]\n                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)\n            else:\n                inner_dim = hidden_states.shape[1]\n                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)\n                hidden_states = self.proj_in(hidden_states)\n                inner_dim = inner_dim -6 \n                \n        elif self.is_input_vectorized:\n            hidden_states = self.latent_image_embedding(hidden_states)\n        elif self.is_input_patches:\n            hidden_states = self.pos_embed(hidden_states)\n\n        # 2. Blocks\n        for block in self.transformer_blocks:\n            hidden_states = block(\n                hidden_states,\n                attention_mask=attention_mask,\n                encoder_hidden_states=encoder_hidden_states,\n                encoder_attention_mask=encoder_attention_mask,\n                timestep=timestep,\n                cross_attention_kwargs=cross_attention_kwargs,\n                class_labels=class_labels,\n            )\n\n        # 3. 
Output\n        if self.is_input_continuous:\n            if not self.use_linear_projection:\n                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n                hidden_states = self.proj_out(hidden_states)\n            else:\n                hidden_states = self.proj_out(hidden_states)\n                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n\n            output = hidden_states + residual\n        elif self.is_input_vectorized:\n            hidden_states = self.norm_out(hidden_states)\n            logits = self.out(hidden_states)\n            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)\n            logits = logits.permute(0, 2, 1)\n\n            # log(p(x_0))\n            output = F.log_softmax(logits.double(), dim=1).float()\n        elif self.is_input_patches:\n            # TODO: cleanup!\n            conditioning = self.transformer_blocks[0].norm1.emb(\n                timestep, class_labels, hidden_dtype=hidden_states.dtype\n            )\n            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)\n            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]\n            hidden_states = self.proj_out_2(hidden_states)\n\n            # unpatchify\n            height = width = int(hidden_states.shape[1] ** 0.5)\n            hidden_states = hidden_states.reshape(\n                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)\n            )\n            hidden_states = torch.einsum(\"nhwpqc->nchpwq\", hidden_states)\n            output = hidden_states.reshape(\n                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)\n            )\n\n        if not return_dict:\n            return (output,)\n\n        return TransformerMV2DModelOutput(sample=output)\n\n\n@maybe_allow_in_graph\nclass BasicMVTransformerBlock(nn.Module):\n    r\"\"\"\n    A basic Transformer block.\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.\n        only_cross_attention (`bool`, *optional*):\n            Whether to use only cross-attention layers. In this case two cross attention layers are used.\n        double_self_attention (`bool`, *optional*):\n            Whether to use two self-attention layers. In this case no cross attention layers are used.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm (:\n            obj: `int`, *optional*): The number of diffusion steps used during training. 
See `Transformer2DModel`.\n        attention_bias (:\n            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout=0.0,\n        cross_attention_dim: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        attention_bias: bool = False,\n        only_cross_attention: bool = False,\n        double_self_attention: bool = False,\n        upcast_attention: bool = False,\n        norm_elementwise_affine: bool = True,\n        norm_type: str = \"layer_norm\",\n        final_dropout: bool = False,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool = False\n    ):\n        super().__init__()\n        self.only_cross_attention = only_cross_attention\n\n        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == \"ada_norm_zero\"\n        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == \"ada_norm\"\n\n        if norm_type in (\"ada_norm\", \"ada_norm_zero\") and num_embeds_ada_norm is None:\n            raise ValueError(\n                f\"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to\"\n                f\" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}.\"\n            )\n\n        # Define 3 blocks. Each block has its own normalization layer.\n        # 1. Self-Attn\n        if self.use_ada_layer_norm:\n            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)\n        elif self.use_ada_layer_norm_zero:\n            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)\n        else:\n            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)\n\n        self.multiview_attention = multiview_attention\n        self.sparse_mv_attention = sparse_mv_attention\n        self.mvcd_attention = mvcd_attention\n        \n        self.attn1 = CustomAttention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            cross_attention_dim=cross_attention_dim if only_cross_attention else None,\n            upcast_attention=upcast_attention,\n            processor=MVAttnProcessor()\n        )\n\n        # 2. Cross-Attn\n        if cross_attention_dim is not None or double_self_attention:\n            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.\n            # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during\n            # the second cross attention block.\n            self.norm2 = (\n                AdaLayerNorm(dim, num_embeds_ada_norm)\n                if self.use_ada_layer_norm\n                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)\n            )\n            self.attn2 = Attention(\n                query_dim=dim,\n                cross_attention_dim=cross_attention_dim if not double_self_attention else None,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                upcast_attention=upcast_attention,\n            )  # is self-attn if encoder_hidden_states is none\n        else:\n            self.norm2 = None\n            self.attn2 = None\n\n        # 3. Feed-forward\n        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)\n        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)\n\n        # let chunk size default to None\n        self._chunk_size = None\n        self._chunk_dim = 0\n\n        self.num_views = num_views\n\n        self.cd_attention_last = cd_attention_last\n\n        if self.cd_attention_last:\n            # Joint task -Attn\n            self.attn_joint_last = CustomJointAttention(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                cross_attention_dim=cross_attention_dim if only_cross_attention else None,\n                upcast_attention=upcast_attention,\n                processor=JointAttnProcessor()\n            )\n            nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)\n            self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n\n\n        self.cd_attention_mid = cd_attention_mid\n\n        if self.cd_attention_mid:\n            # print(\"cross-domain attn in the middle\")\n            # Joint task -Attn\n            self.attn_joint_mid = CustomJointAttention(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                cross_attention_dim=cross_attention_dim if only_cross_attention else None,\n                upcast_attention=upcast_attention,\n                processor=JointAttnProcessor()\n            )\n            nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)\n            self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)\n\n    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):\n        # Sets chunk feed-forward\n        self._chunk_size = chunk_size\n        self._chunk_dim = dim\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        class_labels: Optional[torch.LongTensor] = None,\n    ):\n        assert attention_mask is None # 
not supported yet\n        # Notice that normalization is always applied before the real computation in the following blocks.\n        # 1. Self-Attention\n        if self.use_ada_layer_norm:\n            norm_hidden_states = self.norm1(hidden_states, timestep)\n        elif self.use_ada_layer_norm_zero:\n            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n            )\n        else:\n            norm_hidden_states = self.norm1(hidden_states)\n\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        attn_output = self.attn1(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n            attention_mask=attention_mask,\n            num_views=self.num_views,\n            multiview_attention=self.multiview_attention,\n            sparse_mv_attention=self.sparse_mv_attention,\n            mvcd_attention=self.mvcd_attention,\n            **cross_attention_kwargs,\n        )\n\n\n        if self.use_ada_layer_norm_zero:\n            attn_output = gate_msa.unsqueeze(1) * attn_output\n        hidden_states = attn_output + hidden_states\n\n        # joint attention twice\n        if self.cd_attention_mid:\n            norm_hidden_states = (\n                self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)\n            )\n            hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states\n\n        # 2. Cross-Attention\n        if self.attn2 is not None:\n            norm_hidden_states = (\n                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n            )\n\n            attn_output = self.attn2(\n                norm_hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=encoder_attention_mask,\n                **cross_attention_kwargs,\n            )\n            hidden_states = attn_output + hidden_states\n\n        # 3. Feed-forward\n        norm_hidden_states = self.norm3(hidden_states)\n\n        if self.use_ada_layer_norm_zero:\n            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n        if self._chunk_size is not None:\n            # \"feed_forward_chunk_size\" can be used to save memory\n            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:\n                raise ValueError(\n                    f\"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. 
Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.\"\n                )\n\n            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size\n            ff_output = torch.cat(\n                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],\n                dim=self._chunk_dim,\n            )\n        else:\n            ff_output = self.ff(norm_hidden_states)\n\n        if self.use_ada_layer_norm_zero:\n            ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n        hidden_states = ff_output + hidden_states\n\n        if self.cd_attention_last:\n            norm_hidden_states = (\n                self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)\n            )\n            hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states\n\n        return hidden_states\n\n\nclass CustomAttention(Attention):\n    def set_use_memory_efficient_attention_xformers(\n        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs\n    ):\n        processor = XFormersMVAttnProcessor()\n        self.set_processor(processor)\n        # print(\"using xformers attention processor\")\n\n\nclass CustomJointAttention(Attention):\n    def set_use_memory_efficient_attention_xformers(\n        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs\n    ):\n        processor = XFormersJointAttnProcessor()\n        self.set_processor(processor)\n        # print(\"using xformers attention processor\")\n\n\nclass MVAttnProcessor:\n    r\"\"\"\n    Processor for dense multi-view self-attention: keys and values from every view of an object are shared across all of its views.\n    \"\"\"\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        num_views=1,\n        multiview_attention=True,\n        # accepted for interface parity with XFormersMVAttnProcessor; this processor always runs dense multi-view attention\n        sparse_mv_attention=False,\n        mvcd_attention=False,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        # print('query', query.shape, 'key', key.shape, 'value', value.shape)\n        #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])\n        # pdb.set_trace()\n        # multi-view self-attention: gather keys/values from all views, then give every view the full set\n        if multiview_attention:\n            key = rearrange(key, \"(b t) d c -> b (t d) c\", t=num_views).repeat_interleave(num_views, dim=0)\n            value = rearrange(value, \"(b 
t) d c -> b (t d) c\", t=num_views).repeat_interleave(num_views, dim=0)\n\n        query = attn.head_to_batch_dim(query).contiguous()\n        key = attn.head_to_batch_dim(key).contiguous()\n        value = attn.head_to_batch_dim(value).contiguous()\n        \n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n        \n        return hidden_states\n\n\nclass XFormersMVAttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        num_views=1.,\n        multiview_attention=True,\n        sparse_mv_attention=False,\n        mvcd_attention=False,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        # from yuancheng; here attention_mask is None\n        if attention_mask is not None:\n            # expand our mask's singleton query_tokens dimension:\n            #   [batch*heads,            1, key_tokens] ->\n            #   [batch*heads, query_tokens, key_tokens]\n            # so that it can be added as a bias onto the attention scores that xformers computes:\n            #   [batch*heads, query_tokens, key_tokens]\n            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.\n            _, query_tokens, _ = hidden_states.shape\n            attention_mask = attention_mask.expand(-1, query_tokens, -1)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key_raw = attn.to_k(encoder_hidden_states)\n        value_raw = attn.to_v(encoder_hidden_states)\n\n        # print('query', query.shape, 'key', key.shape, 'value', value.shape)\n        #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])\n        # pdb.set_trace()\n        # multi-view self-attention\n        if multiview_attention:\n            if not sparse_mv_attention:\n                key = 
my_repeat(rearrange(key_raw, \"(b t) d c -> b (t d) c\", t=num_views), num_views)\n                value = my_repeat(rearrange(value_raw, \"(b t) d c -> b (t d) c\", t=num_views), num_views)\n            else:\n                key_front = my_repeat(rearrange(key_raw, \"(b t) d c -> b t d c\", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]\n                value_front = my_repeat(rearrange(value_raw, \"(b t) d c -> b t d c\", t=num_views)[:, 0, :, :], num_views)\n                key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c\n                value = torch.cat([value_front, value_raw], dim=1)\n\n        else:\n            # print(\"don't use multiview attention.\")\n            key = key_raw\n            value = value_raw\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n        \n        return hidden_states\n\n\n\nclass XFormersJointAttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        num_tasks=2\n    ):\n        \n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        # from yuancheng; here attention_mask is None\n        if attention_mask is not None:\n            # expand our mask's singleton query_tokens dimension:\n            #   [batch*heads,            1, key_tokens] ->\n            #   [batch*heads, query_tokens, key_tokens]\n            # so that it can be added as a bias onto the attention scores that xformers computes:\n            #   [batch*heads, query_tokens, key_tokens]\n            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.\n            _, query_tokens, _ = hidden_states.shape\n            attention_mask = attention_mask.expand(-1, query_tokens, -1)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        assert num_tasks == 2  # only support two tasks now\n\n        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)  # keys shape (b t) d c\n        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)\n        key = torch.cat([key_0, key_1], dim=1)  # (b t) 2d c\n        value = torch.cat([value_0, value_1], dim=1)  # (b t) 2d c\n        key = torch.cat([key]*2, dim=0)   # ( 2 b t) 2d c\n        value = torch.cat([value]*2, dim=0)  # (2 b t) 2d c\n\n        \n        query = attn.head_to_batch_dim(query).contiguous()\n        key = attn.head_to_batch_dim(key).contiguous()\n        value = attn.head_to_batch_dim(value).contiguous()\n\n        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n        \n        return hidden_states\n\n\nclass JointAttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        num_tasks=2\n    ):\n        \n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        assert num_tasks == 2  # only support two tasks now\n\n        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)  # keys shape (b t) d c\n        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)\n        key = torch.cat([key_0, key_1], dim=1)  # (b t) 2d c\n        value = torch.cat([value_0, value_1], dim=1)  # (b t) 2d c\n        key = torch.cat([key]*2, dim=0)   # ( 2 b t) 2d c\n        value = torch.cat([value]*2, dim=0)  # (2 b t) 2d c\n\n        \n        query = attn.head_to_batch_dim(query).contiguous()\n        key = 
attn.head_to_batch_dim(key).contiguous()\n        value = attn.head_to_batch_dim(value).contiguous()\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n        \n        return hidden_states"
  },
  {
    "path": "core/models/unet_mv2d_blocks.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.utils import is_torch_version, logging\nfrom diffusers.models.attention import AdaGroupNorm\nfrom diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0\nfrom diffusers.models.dual_transformer_2d import DualTransformer2DModel\nfrom diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D\nfrom .transformer_mv2d import TransformerMV2DModel\n\nfrom diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D\nfrom diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nclass IdentityMLP(nn.Module):\n    def __init__(self, size):\n        super(IdentityMLP, self).__init__()\n        self.linear = nn.Linear(size, size)\n        self.init_identity()\n\n    def forward(self, x):\n\n        return self.linear(x)\n    \n    def init_identity(self):\n        # Initialize the weights to an identity matrix and biases to zero\n        identity_matrix = torch.eye(self.linear.in_features)\n        self.linear.weight.data.copy_(identity_matrix)\n        self.linear.bias.data.zero_()\n \n \ndef get_down_block(\n    down_block_type,\n    num_layers,\n    in_channels,\n    out_channels,\n    temb_channels,\n    add_downsample,\n    resnet_eps,\n    resnet_act_fn,\n    transformer_layers_per_block=1,\n    num_attention_heads=None,\n    resnet_groups=None,\n    cross_attention_dim=None,\n    downsample_padding=None,\n    dual_cross_attention=False,\n    use_linear_projection=False,\n    only_cross_attention=False,\n    upcast_attention=False,\n    resnet_time_scale_shift=\"default\",\n    resnet_skip_time_act=False,\n    resnet_out_scale_factor=1.0,\n    cross_attention_norm=None,\n    attention_head_dim=None,\n    downsample_type=None,\n    num_views=1,\n    cd_attention_last: bool = False,\n    cd_attention_mid: bool = False,\n    multiview_attention: bool = True,\n    sparse_mv_attention: bool = False,\n    mvcd_attention: bool=False\n):\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warn(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_down_block`. 
Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    down_block_type = down_block_type[7:] if down_block_type.startswith(\"UNetRes\") else down_block_type\n    if down_block_type == \"DownBlock2D\":\n        return DownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"ResnetDownsampleBlock2D\":\n        return ResnetDownsampleBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n        )\n    elif down_block_type == \"AttnDownBlock2D\":\n        if add_downsample is False:\n            downsample_type = None\n        else:\n            downsample_type = downsample_type or \"conv\"  # default to 'conv'\n        return AttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            downsample_type=downsample_type,\n        )\n    elif down_block_type == \"CrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnDownBlock2D\")\n        return CrossAttnDownBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    # custom MV2D attention block\n    elif down_block_type == \"CrossAttnDownBlockMV2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnDownBlockMV2D\")\n        return CrossAttnDownBlockMV2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n 
           in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            num_views=num_views,\n            cd_attention_last=cd_attention_last,\n            cd_attention_mid=cd_attention_mid,\n            multiview_attention=multiview_attention,\n            sparse_mv_attention=sparse_mv_attention,\n            mvcd_attention=mvcd_attention\n        )\n    elif down_block_type == \"SimpleCrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D\")\n        return SimpleCrossAttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n            only_cross_attention=only_cross_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n    elif down_block_type == \"SkipDownBlock2D\":\n        return SkipDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"AttnSkipDownBlock2D\":\n        return AttnSkipDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"DownEncoderBlock2D\":\n        return DownEncoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"AttnDownEncoderBlock2D\":\n        return AttnDownEncoderBlock2D(\n 
           num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"KDownBlock2D\":\n        return KDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n        )\n    elif down_block_type == \"KCrossAttnDownBlock2D\":\n        return KCrossAttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            add_self_attention=True if not add_downsample else False,\n        )\n    raise ValueError(f\"{down_block_type} does not exist.\")\n\n\ndef get_up_block(\n    up_block_type,\n    num_layers,\n    in_channels,\n    out_channels,\n    prev_output_channel,\n    temb_channels,\n    add_upsample,\n    resnet_eps,\n    resnet_act_fn,\n    transformer_layers_per_block=1,\n    num_attention_heads=None,\n    resnet_groups=None,\n    cross_attention_dim=None,\n    dual_cross_attention=False,\n    use_linear_projection=False,\n    only_cross_attention=False,\n    upcast_attention=False,\n    resnet_time_scale_shift=\"default\",\n    resnet_skip_time_act=False,\n    resnet_out_scale_factor=1.0,\n    cross_attention_norm=None,\n    attention_head_dim=None,\n    upsample_type=None,\n    num_views=1,\n    cd_attention_last: bool = False,\n    cd_attention_mid: bool = False,\n    multiview_attention: bool = True,\n    sparse_mv_attention: bool = False,\n    mvcd_attention: bool=False\n):\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warn(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_up_block`. 
Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    up_block_type = up_block_type[7:] if up_block_type.startswith(\"UNetRes\") else up_block_type\n    if up_block_type == \"UpBlock2D\":\n        return UpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"ResnetUpsampleBlock2D\":\n        return ResnetUpsampleBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n        )\n    elif up_block_type == \"CrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnUpBlock2D\")\n        return CrossAttnUpBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    # custom MV2D attention block\n    elif up_block_type == \"CrossAttnUpBlockMV2D\":\n        # if cross_attention_dim is None:\n        #     raise ValueError(\"cross_attention_dim must be specified for CrossAttnUpBlockMV2D\")\n        return CrossAttnUpBlockMV2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            
num_views=num_views,\n            cd_attention_last=cd_attention_last,\n            cd_attention_mid=cd_attention_mid,\n            multiview_attention=multiview_attention,\n            sparse_mv_attention=sparse_mv_attention,\n            mvcd_attention=mvcd_attention\n        )    \n    elif up_block_type == \"SimpleCrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D\")\n        return SimpleCrossAttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n            only_cross_attention=only_cross_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n    elif up_block_type == \"AttnUpBlock2D\":\n        if add_upsample is False:\n            upsample_type = None\n        else:\n            upsample_type = upsample_type or \"conv\"  # default to 'conv'\n\n        return AttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            upsample_type=upsample_type,\n        )\n    elif up_block_type == \"SkipUpBlock2D\":\n        return SkipUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"AttnSkipUpBlock2D\":\n        return AttnSkipUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"UpDecoderBlock2D\":\n        return UpDecoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n    elif up_block_type == 
\"AttnUpDecoderBlock2D\":\n        return AttnUpDecoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n    elif up_block_type == \"KUpBlock2D\":\n        return KUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n        )\n    elif up_block_type == \"KCrossAttnUpBlock2D\":\n        return KCrossAttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n        )\n\n    raise ValueError(f\"{up_block_type} does not exist.\")\n\n\nclass UNetMidBlockMV2DCrossAttn(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads=1,\n        output_scale_factor=1.0,\n        cross_attention_dim=1280,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        upcast_attention=False,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool=False\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n\n        for _ in range(num_layers):\n            if not dual_cross_attention:\n                attentions.append(\n                    TransformerMV2DModel(\n                        num_attention_heads,\n                        in_channels // num_attention_heads,\n                        in_channels=in_channels,\n                        num_layers=transformer_layers_per_block,\n                        
cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        upcast_attention=upcast_attention,\n                        num_views=num_views,\n                        cd_attention_last=cd_attention_last,\n                        cd_attention_mid=cd_attention_mid,\n                        multiview_attention=multiview_attention,\n                        sparse_mv_attention=sparse_mv_attention,\n                        mvcd_attention=mvcd_attention\n                    )\n                )\n            else:\n                raise NotImplementedError\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        ray_embedding: Optional[torch.Tensor] = None,\n    ) -> torch.FloatTensor:\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            hidden_states = attn(\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                cross_attention_kwargs=cross_attention_kwargs,\n                attention_mask=attention_mask,\n                encoder_attention_mask=encoder_attention_mask,\n                ray_embedding=ray_embedding,\n                return_dict=False,\n            )[0]\n            hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass CrossAttnUpBlockMV2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        only_cross_attention=False,\n        upcast_attention=False,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool=False\n    ):\n        super().__init__()\n        resnets = []\n        
attentions = []\n        mlps = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            mlps.append(IdentityMLP(res_skip_channels))\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if not dual_cross_attention:\n                attentions.append(\n                    TransformerMV2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=transformer_layers_per_block,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        only_cross_attention=only_cross_attention,\n                        upcast_attention=upcast_attention,\n                        num_views=num_views,\n                        cd_attention_last=cd_attention_last,\n                        cd_attention_mid=cd_attention_mid,\n                        multiview_attention=multiview_attention,\n                        sparse_mv_attention=sparse_mv_attention,\n                        mvcd_attention=mvcd_attention\n                    )\n                )\n            else:\n                raise NotImplementedError\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n        self.mlps = nn.ModuleList(mlps)\n        \n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        ray_embedding: Optional[torch.Tensor] = None,\n    ):\n        for resnet, attn, mlp in zip(self.resnets, self.attentions, self.mlps):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            B, _, H, W = res_hidden_states.shape\n            res_hidden_states = res_hidden_states.permute(0, 2, 3, 1).reshape(B, H * W, _)\n            res_hidden_states = mlp(res_hidden_states)\n            
res_hidden_states = res_hidden_states.reshape(B, H, W, _).permute(0, 3, 1, 2).contiguous()\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(attn, return_dict=False),\n                    hidden_states,\n                    encoder_hidden_states,\n                    None,  # timestep\n                    None,  # class_labels\n                    cross_attention_kwargs,\n                    attention_mask,\n                    encoder_attention_mask,\n                    ray_embedding,\n                    **ckpt_kwargs,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    ray_embedding = ray_embedding,\n                    return_dict=False,\n                )[0]\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size)\n\n        return hidden_states\n\n\nclass CrossAttnDownBlockMV2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        downsample_padding=1,\n        add_downsample=True,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        only_cross_attention=False,\n        upcast_attention=False,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool=False\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n  
              ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if not dual_cross_attention:\n                attentions.append(\n                    TransformerMV2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=transformer_layers_per_block,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        only_cross_attention=only_cross_attention,\n                        upcast_attention=upcast_attention,\n                        num_views=num_views,\n                        cd_attention_last=cd_attention_last,\n                        cd_attention_mid=cd_attention_mid,\n                        multiview_attention=multiview_attention,\n                        sparse_mv_attention=sparse_mv_attention,\n                        mvcd_attention=mvcd_attention\n                    )\n                )\n            else:\n                raise NotImplementedError\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        additional_residuals=None,\n    ):\n        output_states = ()\n\n        blocks = list(zip(self.resnets, self.attentions))\n\n        for i, (resnet, attn) in enumerate(blocks):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n  
              )\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(attn, return_dict=False),\n                    hidden_states,\n                    encoder_hidden_states,\n                    None,  # timestep\n                    None,  # class_labels\n                    cross_attention_kwargs,\n                    attention_mask,\n                    encoder_attention_mask,\n                    **ckpt_kwargs,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n            # apply additional residuals to the output of the last pair of resnet and attention blocks\n            if i == len(blocks) - 1 and additional_residuals is not None:\n                hidden_states = hidden_states + additional_residuals\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states"
  },
  {
    "path": "core/models/unet_mv2d_condition.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport os\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention_processor import AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model\nfrom diffusers.models.unet_2d_blocks import (\n    CrossAttnDownBlock2D,\n    CrossAttnUpBlock2D,\n    DownBlock2D,\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    UpBlock2D,\n)\nfrom diffusers.utils import (\n    CONFIG_NAME,\n    DIFFUSERS_CACHE,\n    FLAX_WEIGHTS_NAME,\n    HF_HUB_OFFLINE,\n    SAFETENSORS_WEIGHTS_NAME,\n    WEIGHTS_NAME,\n    _add_variant,\n    _get_model_file,\n    deprecate,\n    is_accelerate_available,\n    is_safetensors_available,\n    is_torch_version,\n    logging,\n)\nfrom diffusers import __version__\nfrom .unet_mv2d_blocks import (\n    CrossAttnDownBlockMV2D,\n    CrossAttnUpBlockMV2D,\n    UNetMidBlockMV2DCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNetMV2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. 
Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for the middle of the UNet; it can be either `UNetMidBlock2DCrossAttn` or\n            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers are skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. 
Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlockMV2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool = False\n    ):\n        super().__init__()\n\n        self.sample_size = 
sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `encoder_hid_dim_type == \"text_image_proj\"` (Kandinsky 2.1)\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kadinsky 2.1)`\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. 
The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n            )\n        # custom MV2D attention block  \n        elif mid_block_type == \"UNetMidBlockMV2DCrossAttn\":\n            self.mid_block = UNetMidBlockMV2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                
resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        ### FIXME \n        #up_cross_attention_dim = (None, None, None, None)\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                
temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"set_processor\"):\n                processors[f\"{name}.processor\"] = module.processor\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        self.set_attn_processor(AttnProcessor())\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNetMV2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that, if specified, is passed along to the [`AttnProcessor`].\n            added_cond_kwargs (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be at least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n    
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kadinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 3. 
down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n            )\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNetMV2DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_2d(\n            cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],\n            camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 4, sample_size: int = 64,\n            zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,\n            projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, \n            cd_attention_mid: bool = False, multiview_attention: bool = True, \n            sparse_mv_attention: bool = False, mvcd_attention: bool = False,\n            in_channels: int = 10, out_channels: int = 15, \n            **kwargs\n        ):\n        r\"\"\"\n        Instantiate a pretrained PyTorch model from a pretrained model configuration.\n\n        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To\n        train the model, set it back in training mode with `model.train()`.\n\n        Parameters:\n            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved\n                      with [`~ModelMixin.save_pretrained`].\n\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. 
If `\"auto\"` is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info (`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only(`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            use_auth_token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            from_flax (`bool`, *optional*, defaults to `False`):\n                Load the model weights from a Flax checkpoint save file.\n            subfolder (`str`, *optional*, defaults to `\"\"`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn't need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if `device_map` contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            variant (`str`, *optional*):\n                Load weights from a specified `variant` filename such as `\"fp16\"` or `\"ema\"`. This is ignored when\n                loading `from_flax`.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the\n                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`\n                weights. If set to `False`, `safetensors` weights are not loaded.\n\n        <Tip>\n\n        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in with\n        `huggingface-cli login`. 
You can also activate the special\n        [\"offline-mode\"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a\n        firewalled environment.\n\n        </Tip>\n\n        Example:\n\n        ```py\n        from diffusers import UNet2DConditionModel\n\n        unet = UNet2DConditionModel.from_pretrained(\"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\")\n        ```\n\n        If you get the error message below, you need to finetune the weights for your downstream task:\n\n        ```bash\n        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:\n        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated\n        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n        ```\n        \"\"\"\n        cache_dir = kwargs.pop(\"cache_dir\", DIFFUSERS_CACHE)\n        ignore_mismatched_sizes = kwargs.pop(\"ignore_mismatched_sizes\", False)\n        force_download = kwargs.pop(\"force_download\", False)\n        from_flax = kwargs.pop(\"from_flax\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        output_loading_info = kwargs.pop(\"output_loading_info\", False)\n        local_files_only = kwargs.pop(\"local_files_only\", HF_HUB_OFFLINE)\n        use_auth_token = kwargs.pop(\"use_auth_token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        torch_dtype = kwargs.pop(\"torch_dtype\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n        device_map = kwargs.pop(\"device_map\", None)\n        max_memory = kwargs.pop(\"max_memory\", None)\n        offload_folder = kwargs.pop(\"offload_folder\", None)\n        offload_state_dict = kwargs.pop(\"offload_state_dict\", False)\n        variant = kwargs.pop(\"variant\", None)\n        use_safetensors = kwargs.pop(\"use_safetensors\", None)\n\n        if use_safetensors and not is_safetensors_available():\n            raise ValueError(\n                \"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors\"\n            )\n\n        allow_pickle = False\n        if use_safetensors is None:\n            use_safetensors = is_safetensors_available()\n            allow_pickle = True\n\n        if device_map is not None and not is_accelerate_available():\n            raise NotImplementedError(\n                \"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set\"\n                \" `device_map=None`. You can install accelerate with `pip install accelerate`.\"\n            )\n\n        # Check if we can handle device_map and dispatching the weights\n        if device_map is not None and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Loading and dispatching requires torch >= 1.9.0. 
Please either update your PyTorch version or set\"\n                \" `device_map=None`.\"\n            )\n\n        # Load config if we don't provide a configuration\n        config_path = pretrained_model_name_or_path\n\n        user_agent = {\n            \"diffusers\": __version__,\n            \"file_type\": \"model\",\n            \"framework\": \"pytorch\",\n        }\n\n        # load config\n        config, unused_kwargs, commit_hash = cls.load_config(\n            config_path,\n            cache_dir=cache_dir,\n            return_unused_kwargs=True,\n            return_commit_hash=True,\n            force_download=force_download,\n            resume_download=resume_download,\n            proxies=proxies,\n            local_files_only=local_files_only,\n            use_auth_token=use_auth_token,\n            revision=revision,\n            subfolder=subfolder,\n            device_map=device_map,\n            max_memory=max_memory,\n            offload_folder=offload_folder,\n            offload_state_dict=offload_state_dict,\n            user_agent=user_agent,\n            **kwargs,\n        )\n\n        # modify config\n        config[\"_class_name\"] = cls.__name__\n        config['in_channels'] = in_channels\n        config['out_channels'] = out_channels\n        config['sample_size'] = sample_size # training resolution\n        config['num_views'] = num_views\n        config['cd_attention_last'] = cd_attention_last\n        config['cd_attention_mid'] = cd_attention_mid\n        config['multiview_attention'] = multiview_attention\n        config['sparse_mv_attention'] = sparse_mv_attention\n        config['mvcd_attention'] = mvcd_attention\n        config[\"down_block_types\"] = [\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"DownBlock2D\"\n        ]\n        config['mid_block_type'] = \"UNetMidBlockMV2DCrossAttn\"\n        config[\"up_block_types\"] = [\n            \"UpBlock2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\"\n        ]        \n        #config['class_embed_type'] = 'projection'\n        if camera_embedding_type == 'e_de_da_sincos':\n            config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6\n        else:\n            raise NotImplementedError\n\n        # load model\n        model_file = None\n        if from_flax:\n            raise NotImplementedError\n        else:\n            if use_safetensors:\n                try:\n                    model_file = _get_model_file(\n                        pretrained_model_name_or_path,\n                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),\n                        cache_dir=cache_dir,\n                        force_download=force_download,\n                        resume_download=resume_download,\n                        proxies=proxies,\n                        local_files_only=local_files_only,\n                        use_auth_token=use_auth_token,\n                        revision=revision,\n                        subfolder=subfolder,\n                        user_agent=user_agent,\n                        commit_hash=commit_hash,\n                    )\n                except IOError as e:\n                    if not allow_pickle:\n                        raise e\n                    pass\n            if model_file is None:\n                model_file = 
_get_model_file(\n                    pretrained_model_name_or_path,\n                    weights_name=_add_variant(WEIGHTS_NAME, variant),\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    use_auth_token=use_auth_token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n                    commit_hash=commit_hash,\n                )\n\n            model = cls.from_config(config, **unused_kwargs)\n            import copy\n            state_dict_v0 = load_state_dict(model_file, variant=variant)\n            state_dict = copy.deepcopy(state_dict_v0)\n            # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last\n            # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid\n            for key in state_dict_v0:\n                if 'attn_joint.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"attn_joint.\", \"attn_joint_last.\")] = state_dict.pop(tmp)\n                if 'norm_joint.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"norm_joint.\", \"norm_joint_last.\")] = state_dict.pop(tmp)\n                if 'attn_joint_twice.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"attn_joint_twice.\", \"attn_joint_mid.\")] = state_dict.pop(tmp)\n                if 'norm_joint_twice.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"norm_joint_twice.\", \"norm_joint_mid.\")] = state_dict.pop(tmp)\n            \n            model._convert_deprecated_attention_blocks(state_dict)\n\n            conv_in_weight = state_dict['conv_in.weight']\n            conv_out_weight = state_dict['conv_out.weight']\n            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(\n                model,\n                state_dict,\n                model_file,\n                pretrained_model_name_or_path,\n                ignore_mismatched_sizes=True,\n            )\n            if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):\n                # initialize from the original SD structure\n                model.conv_in.weight.data[:,:4] = conv_in_weight\n\n            # whether to place all zero to new layers?\n            if zero_init_conv_in:\n                model.conv_in.weight.data[:,4:] = 0.\n\n            if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):\n                # initialize from the original SD structure\n                model.conv_out.weight.data[-4:, ] = conv_out_weight\n                # model.conv_out.weight.data[:,:4] = conv_out_weight\n                # if out_channels == 8: # copy for the last 4 channels\n                #     model.conv_out.weight.data[:, 4:] = conv_out_weight\n            \n            if zero_init_camera_projection:\n                for p in model.class_embedding.parameters():\n                    torch.nn.init.zeros_(p)\n\n            loading_info = {\n                \"missing_keys\": missing_keys,\n                \"unexpected_keys\": unexpected_keys,\n                \"mismatched_keys\": mismatched_keys,\n                \"error_msgs\": error_msgs,\n            
}\n\n        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):\n            raise ValueError(\n                f\"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.\"\n            )\n        elif torch_dtype is not None:\n            model = model.to(torch_dtype)\n\n        model.register_to_config(_name_or_path=pretrained_model_name_or_path)\n\n        # Set model in evaluation mode to deactivate DropOut modules by default\n        model.eval()\n        if output_loading_info:\n            return model, loading_info\n\n        return model\n\n    @classmethod\n    def _load_pretrained_model_2d(\n        cls,\n        model,\n        state_dict,\n        resolved_archive_file,\n        pretrained_model_name_or_path,\n        ignore_mismatched_sizes=False,\n    ):\n        # Retrieve missing & unexpected_keys\n        model_state_dict = model.state_dict()\n        loaded_keys = list(state_dict.keys())\n\n        expected_keys = list(model_state_dict.keys())\n\n        original_loaded_keys = loaded_keys\n\n        missing_keys = list(set(expected_keys) - set(loaded_keys))\n        unexpected_keys = list(set(loaded_keys) - set(expected_keys))\n\n        # Make sure we are able to load base models as well as derived models (with heads)\n        model_to_load = model\n\n        def _find_mismatched_keys(\n            state_dict,\n            model_state_dict,\n            loaded_keys,\n            ignore_mismatched_sizes,\n        ):\n            mismatched_keys = []\n            if ignore_mismatched_sizes:\n                for checkpoint_key in loaded_keys:\n                    model_key = checkpoint_key\n\n                    if (\n                        model_key in model_state_dict\n                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape\n                    ):\n                        mismatched_keys.append(\n                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)\n                        )\n                        del state_dict[checkpoint_key]\n            return mismatched_keys\n\n        if state_dict is not None:\n            # Whole checkpoint\n            mismatched_keys = _find_mismatched_keys(\n                state_dict,\n                model_state_dict,\n                original_loaded_keys,\n                ignore_mismatched_sizes,\n            )\n            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)\n\n        if len(error_msgs) > 0:\n            error_msg = \"\\n\\t\".join(error_msgs)\n            if \"size mismatch\" in error_msg:\n                error_msg += (\n                    \"\\n\\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.\"\n                )\n            raise RuntimeError(f\"Error(s) in loading state_dict for {model.__class__.__name__}:\\n\\t{error_msg}\")\n\n        if len(unexpected_keys) > 0:\n            logger.warning(\n                f\"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when\"\n                f\" initializing {model.__class__.__name__}: {unexpected_keys}\\n- This IS expected if you are\"\n                f\" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task\"\n                \" or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a\"\n                \" BertForPreTraining model).\\n- This IS NOT expected if you are initializing\"\n                f\" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly\"\n                \" identical (initializing a BertForSequenceClassification model from a\"\n                \" BertForSequenceClassification model).\"\n            )\n        else:\n            logger.info(f\"All model checkpoint weights were used when initializing {model.__class__.__name__}.\\n\")\n        if len(missing_keys) > 0:\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\\nYou should probably\"\n                \" TRAIN this model on a down-stream task to be able to use it for predictions and inference.\"\n            )\n        elif len(mismatched_keys) == 0:\n            logger.info(\n                f\"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path}.\\nIf your task is similar to the task the model of the\"\n                f\" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions\"\n                \" without further training.\"\n            )\n        if len(mismatched_keys) > 0:\n            mismatched_warning = \"\\n\".join(\n                [\n                    f\"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated\"\n                    for key, shape1, shape2 in mismatched_keys\n                ]\n            )\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized because the shapes did not\"\n                f\" match:\\n{mismatched_warning}\\nYou should probably TRAIN this model on a down-stream task to be\"\n                \" able to use it for predictions and inference.\"\n            )\n\n        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs"
  },
  {
    "path": "core/models/unet_mv2d_condition_depth.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport os\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention_processor import AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model\nfrom diffusers.models.unet_2d_blocks import (\n    CrossAttnDownBlock2D,\n    CrossAttnUpBlock2D,\n    DownBlock2D,\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    UpBlock2D,\n)\nfrom diffusers.utils import (\n    CONFIG_NAME,\n    DIFFUSERS_CACHE,\n    FLAX_WEIGHTS_NAME,\n    HF_HUB_OFFLINE,\n    SAFETENSORS_WEIGHTS_NAME,\n    WEIGHTS_NAME,\n    _add_variant,\n    _get_model_file,\n    deprecate,\n    is_accelerate_available,\n    is_safetensors_available,\n    is_torch_version,\n    logging,\n)\nfrom diffusers import __version__\nfrom .unet_mv2d_blocks import (\n    CrossAttnDownBlockMV2D,\n    CrossAttnUpBlockMV2D,\n    UNetMidBlockMV2DCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNetMV2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. 
Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or\n            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers are skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. 
Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlockMV2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool = False\n    ):\n        super().__init__()\n\n        self.sample_size = 
sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image_proj\"` (Kadinsky 2.1)`\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kadinsky 2.1)`\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. 
The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n            )\n        # custom MV2D attention block  \n        elif mid_block_type == \"UNetMidBlockMV2DCrossAttn\":\n            self.mid_block = UNetMidBlockMV2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                
resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        ### FIXME \n        #up_cross_attention_dim = (None, None, None, None)\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                
temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"set_processor\"):\n                processors[f\"{name}.processor\"] = module.processor\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        self.set_attn_processor(AttnProcessor())\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNetMV2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n    
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 3. 
down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n            )\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNetMV2DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_2d(\n            cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],\n            camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 4, sample_size: int = 64,\n            zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,\n            projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, \n            cd_attention_mid: bool = False, multiview_attention: bool = True, \n            sparse_mv_attention: bool = False, mvcd_attention: bool = False,\n            in_channels: int = 10, out_channels: int = 13, \n            **kwargs\n        ):\n        r\"\"\"\n        Instantiate a pretrained PyTorch model from a pretrained model configuration.\n\n        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To\n        train the model, set it back in training mode with `model.train()`.\n\n        Parameters:\n            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved\n                      with [`~ModelMixin.save_pretrained`].\n\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. 
If `\"auto\"` is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info (`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only(`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            use_auth_token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            from_flax (`bool`, *optional*, defaults to `False`):\n                Load the model weights from a Flax checkpoint save file.\n            subfolder (`str`, *optional*, defaults to `\"\"`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn't need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if `device_map` contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            variant (`str`, *optional*):\n                Load weights from a specified `variant` filename such as `\"fp16\"` or `\"ema\"`. This is ignored when\n                loading `from_flax`.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the\n                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`\n                weights. If set to `False`, `safetensors` weights are not loaded.\n\n        <Tip>\n\n        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with\n        `huggingface-cli login`. 
You can also activate the special\n        [\"offline-mode\"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a\n        firewalled environment.\n\n        </Tip>\n\n        Example:\n\n        ```py\n        from diffusers import UNet2DConditionModel\n\n        unet = UNet2DConditionModel.from_pretrained(\"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\")\n        ```\n\n        If you get the error message below, you need to finetune the weights for your downstream task:\n\n        ```bash\n        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:\n        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated\n        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n        ```\n        \"\"\"\n        cache_dir = kwargs.pop(\"cache_dir\", DIFFUSERS_CACHE)\n        ignore_mismatched_sizes = kwargs.pop(\"ignore_mismatched_sizes\", False)\n        force_download = kwargs.pop(\"force_download\", False)\n        from_flax = kwargs.pop(\"from_flax\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        output_loading_info = kwargs.pop(\"output_loading_info\", False)\n        local_files_only = kwargs.pop(\"local_files_only\", HF_HUB_OFFLINE)\n        use_auth_token = kwargs.pop(\"use_auth_token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        torch_dtype = kwargs.pop(\"torch_dtype\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n        device_map = kwargs.pop(\"device_map\", None)\n        max_memory = kwargs.pop(\"max_memory\", None)\n        offload_folder = kwargs.pop(\"offload_folder\", None)\n        offload_state_dict = kwargs.pop(\"offload_state_dict\", False)\n        variant = kwargs.pop(\"variant\", None)\n        use_safetensors = kwargs.pop(\"use_safetensors\", None)\n\n        if use_safetensors and not is_safetensors_available():\n            raise ValueError(\n                \"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors\"\n            )\n\n        allow_pickle = False\n        if use_safetensors is None:\n            use_safetensors = is_safetensors_available()\n            allow_pickle = True\n\n        if device_map is not None and not is_accelerate_available():\n            raise NotImplementedError(\n                \"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set\"\n                \" `device_map=None`. You can install accelerate with `pip install accelerate`.\"\n            )\n\n        # Check if we can handle device_map and dispatching the weights\n        if device_map is not None and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Loading and dispatching requires torch >= 1.9.0. 
Please either update your PyTorch version or set\"\n                \" `device_map=None`.\"\n            )\n\n        # Load config if we don't provide a configuration\n        config_path = pretrained_model_name_or_path\n\n        user_agent = {\n            \"diffusers\": __version__,\n            \"file_type\": \"model\",\n            \"framework\": \"pytorch\",\n        }\n\n        # load config\n        config, unused_kwargs, commit_hash = cls.load_config(\n            config_path,\n            cache_dir=cache_dir,\n            return_unused_kwargs=True,\n            return_commit_hash=True,\n            force_download=force_download,\n            resume_download=resume_download,\n            proxies=proxies,\n            local_files_only=local_files_only,\n            use_auth_token=use_auth_token,\n            revision=revision,\n            subfolder=subfolder,\n            device_map=device_map,\n            max_memory=max_memory,\n            offload_folder=offload_folder,\n            offload_state_dict=offload_state_dict,\n            user_agent=user_agent,\n            **kwargs,\n        )\n\n        # modify config\n        config[\"_class_name\"] = cls.__name__\n        config['in_channels'] = in_channels\n        config['out_channels'] = out_channels\n        config['sample_size'] = sample_size # training resolution\n        config['num_views'] = num_views\n        config['cd_attention_last'] = cd_attention_last\n        config['cd_attention_mid'] = cd_attention_mid\n        config['multiview_attention'] = multiview_attention\n        config['sparse_mv_attention'] = sparse_mv_attention\n        config['mvcd_attention'] = mvcd_attention\n        config[\"down_block_types\"] = [\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"DownBlock2D\"\n        ]\n        config['mid_block_type'] = \"UNetMidBlockMV2DCrossAttn\"\n        config[\"up_block_types\"] = [\n            \"UpBlock2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\"\n        ]        \n        #config['class_embed_type'] = 'projection'\n        if camera_embedding_type == 'e_de_da_sincos':\n            config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6\n        else:\n            raise NotImplementedError\n\n        # load model\n        model_file = None\n        if from_flax:\n            raise NotImplementedError\n        else:\n            if use_safetensors:\n                try:\n                    model_file = _get_model_file(\n                        pretrained_model_name_or_path,\n                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),\n                        cache_dir=cache_dir,\n                        force_download=force_download,\n                        resume_download=resume_download,\n                        proxies=proxies,\n                        local_files_only=local_files_only,\n                        use_auth_token=use_auth_token,\n                        revision=revision,\n                        subfolder=subfolder,\n                        user_agent=user_agent,\n                        commit_hash=commit_hash,\n                    )\n                except IOError as e:\n                    if not allow_pickle:\n                        raise e\n                    pass\n            if model_file is None:\n                model_file = 
_get_model_file(\n                    pretrained_model_name_or_path,\n                    weights_name=_add_variant(WEIGHTS_NAME, variant),\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    use_auth_token=use_auth_token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n                    commit_hash=commit_hash,\n                )\n\n            model = cls.from_config(config, **unused_kwargs)\n            import copy\n            state_dict_v0 = load_state_dict(model_file, variant=variant)\n            state_dict = copy.deepcopy(state_dict_v0)\n            # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last\n            # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid\n            for key in state_dict_v0:\n                if 'attn_joint.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"attn_joint.\", \"attn_joint_last.\")] = state_dict.pop(tmp)\n                if 'norm_joint.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"norm_joint.\", \"norm_joint_last.\")] = state_dict.pop(tmp)\n                if 'attn_joint_twice.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"attn_joint_twice.\", \"attn_joint_mid.\")] = state_dict.pop(tmp)\n                if 'norm_joint_twice.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"norm_joint_twice.\", \"norm_joint_mid.\")] = state_dict.pop(tmp)\n            \n            model._convert_deprecated_attention_blocks(state_dict)\n\n            conv_in_weight = state_dict['conv_in.weight']\n            conv_out_weight = state_dict['conv_out.weight']\n            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(\n                model,\n                state_dict,\n                model_file,\n                pretrained_model_name_or_path,\n                ignore_mismatched_sizes=True,\n            )\n            if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):\n                # initialize from the original SD structure\n                model.conv_in.weight.data[:,:4] = conv_in_weight\n\n            # whether to place all zero to new layers?\n            if zero_init_conv_in:\n                model.conv_in.weight.data[:,4:] = 0.\n\n            if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):\n                # initialize from the original SD structure\n                model.conv_out.weight.data[-4:, ] = conv_out_weight\n                # model.conv_out.weight.data[:,:4] = conv_out_weight\n                # if out_channels == 8: # copy for the last 4 channels\n                #     model.conv_out.weight.data[:, 4:] = conv_out_weight\n            \n            if zero_init_camera_projection:\n                for p in model.class_embedding.parameters():\n                    torch.nn.init.zeros_(p)\n\n            loading_info = {\n                \"missing_keys\": missing_keys,\n                \"unexpected_keys\": unexpected_keys,\n                \"mismatched_keys\": mismatched_keys,\n                \"error_msgs\": error_msgs,\n            
}\n\n        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):\n            raise ValueError(\n                f\"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.\"\n            )\n        elif torch_dtype is not None:\n            model = model.to(torch_dtype)\n\n        model.register_to_config(_name_or_path=pretrained_model_name_or_path)\n\n        # Set model in evaluation mode to deactivate DropOut modules by default\n        model.eval()\n        if output_loading_info:\n            return model, loading_info\n\n        return model\n\n    @classmethod\n    def _load_pretrained_model_2d(\n        cls,\n        model,\n        state_dict,\n        resolved_archive_file,\n        pretrained_model_name_or_path,\n        ignore_mismatched_sizes=False,\n    ):\n        # Retrieve missing & unexpected_keys\n        model_state_dict = model.state_dict()\n        loaded_keys = list(state_dict.keys())\n\n        expected_keys = list(model_state_dict.keys())\n\n        original_loaded_keys = loaded_keys\n\n        missing_keys = list(set(expected_keys) - set(loaded_keys))\n        unexpected_keys = list(set(loaded_keys) - set(expected_keys))\n\n        # Make sure we are able to load base models as well as derived models (with heads)\n        model_to_load = model\n\n        def _find_mismatched_keys(\n            state_dict,\n            model_state_dict,\n            loaded_keys,\n            ignore_mismatched_sizes,\n        ):\n            mismatched_keys = []\n            if ignore_mismatched_sizes:\n                for checkpoint_key in loaded_keys:\n                    model_key = checkpoint_key\n\n                    if (\n                        model_key in model_state_dict\n                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape\n                    ):\n                        mismatched_keys.append(\n                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)\n                        )\n                        del state_dict[checkpoint_key]\n            return mismatched_keys\n\n        if state_dict is not None:\n            # Whole checkpoint\n            mismatched_keys = _find_mismatched_keys(\n                state_dict,\n                model_state_dict,\n                original_loaded_keys,\n                ignore_mismatched_sizes,\n            )\n            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)\n\n        if len(error_msgs) > 0:\n            error_msg = \"\\n\\t\".join(error_msgs)\n            if \"size mismatch\" in error_msg:\n                error_msg += (\n                    \"\\n\\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.\"\n                )\n            raise RuntimeError(f\"Error(s) in loading state_dict for {model.__class__.__name__}:\\n\\t{error_msg}\")\n\n        if len(unexpected_keys) > 0:\n            logger.warning(\n                f\"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when\"\n                f\" initializing {model.__class__.__name__}: {unexpected_keys}\\n- This IS expected if you are\"\n                f\" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task\"\n                \" or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a\"\n                \" BertForPreTraining model).\\n- This IS NOT expected if you are initializing\"\n                f\" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly\"\n                \" identical (initializing a BertForSequenceClassification model from a\"\n                \" BertForSequenceClassification model).\"\n            )\n        else:\n            logger.info(f\"All model checkpoint weights were used when initializing {model.__class__.__name__}.\\n\")\n        if len(missing_keys) > 0:\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\\nYou should probably\"\n                \" TRAIN this model on a down-stream task to be able to use it for predictions and inference.\"\n            )\n        elif len(mismatched_keys) == 0:\n            logger.info(\n                f\"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path}.\\nIf your task is similar to the task the model of the\"\n                f\" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions\"\n                \" without further training.\"\n            )\n        if len(mismatched_keys) > 0:\n            mismatched_warning = \"\\n\".join(\n                [\n                    f\"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated\"\n                    for key, shape1, shape2 in mismatched_keys\n                ]\n            )\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized because the shapes did not\"\n                f\" match:\\n{mismatched_warning}\\nYou should probably TRAIN this model on a down-stream task to be\"\n                \" able to use it for predictions and inference.\"\n            )\n\n        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs"
  },
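  {
    "path": "examples/load_unet_mv2d_sketch.py",
    "content": "# Editorial usage sketch (not part of the original release): it shows how the multiview UNet defined\n# above could be adapted from a pretrained 2D Stable Diffusion UNet via `from_pretrained_2d`,\n# following the signature in core/models. The import path, example checkpoint id, and this file's\n# location under examples/ are assumptions; adjust them to the actual repository layout.\nimport torch\n\nfrom core.models.unet_mv2d_condition import UNetMV2DConditionModel  # assumed module path\n\nif __name__ == \"__main__\":\n    # Adapt the pretrained 2D UNet; the extra input/output channels are zero-initialized so the\n    # model starts from the original SD weights (see `from_pretrained_2d`).\n    unet = UNetMV2DConditionModel.from_pretrained_2d(\n        \"runwayml/stable-diffusion-v1-5\",\n        subfolder=\"unet\",\n        camera_embedding_type=\"e_de_da_sincos\",\n        num_views=4,\n        sample_size=64,\n        zero_init_conv_in=True,\n        projection_class_embeddings_input_dim=6,\n        in_channels=10,\n        out_channels=13,\n    )\n    unet.eval()\n\n    # Dummy forward pass: one latent per view; conditioning features must match the checkpoint's\n    # cross_attention_dim (768 for SD 1.5).\n    latents = torch.randn(4, 10, 64, 64)\n    timestep = torch.tensor([999])\n    cond = torch.randn(4, 77, 768)\n    with torch.no_grad():\n        out = unet(latents, timestep, encoder_hidden_states=cond)\n    print(out.sample.shape)  # expected: torch.Size([4, 13, 64, 64])\n"
  },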
  {
    "path": "core/models/unet_mv2d_condition_depth_diffusion.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport os\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention_processor import AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model\nfrom diffusers.models.unet_2d_blocks import (\n    CrossAttnDownBlock2D,\n    CrossAttnUpBlock2D,\n    DownBlock2D,\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    UpBlock2D,\n)\nfrom diffusers.utils import (\n    CONFIG_NAME,\n    DIFFUSERS_CACHE,\n    FLAX_WEIGHTS_NAME,\n    HF_HUB_OFFLINE,\n    SAFETENSORS_WEIGHTS_NAME,\n    WEIGHTS_NAME,\n    _add_variant,\n    _get_model_file,\n    deprecate,\n    is_accelerate_available,\n    is_safetensors_available,\n    is_torch_version,\n    logging,\n)\nfrom diffusers import __version__\nfrom .unet_mv2d_blocks import (\n    CrossAttnDownBlockMV2D,\n    CrossAttnUpBlockMV2D,\n    UNetMidBlockMV2DCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nclass IdentityMLP(nn.Module):\n    def __init__(self, size):\n        super(IdentityMLP, self).__init__()\n        self.linear = nn.Linear(size, size)\n        self.init_identity()\n\n    def forward(self, x):\n        return self.linear(x)\n    \n    def init_identity(self):\n        # Initialize the weights to an identity matrix and biases to zero\n        identity_matrix = torch.eye(self.linear.in_features)\n        self.linear.weight.data.copy_(identity_matrix)\n        self.linear.bias.data.zero_()\n        \n@dataclass\nclass UNetMV2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. 
Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or\n            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers are skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. 
Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlockMV2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool = False\n    ):\n        super().__init__()\n\n        self.sample_size = 
sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `encoder_hid_dim_type == \"text_image_proj\"` (Kandinsky 2.1)\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kandinsky 2.1)\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. 
The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n            )\n        # custom MV2D attention block  \n        elif mid_block_type == \"UNetMidBlockMV2DCrossAttn\":\n            self.mid_block = UNetMidBlockMV2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                
resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        ### FIXME \n        #up_cross_attention_dim = (None, None, None, None)\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                
temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by their weight names.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"set_processor\"):\n                processors[f\"{name}.processor\"] = module.processor\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        self.set_attn_processor(AttnProcessor())\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        ray_embedding: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNetMV2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n    
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 3. 
down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n                ray_embedding = ray_embedding,\n            )\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    ray_embedding = ray_embedding,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNetMV2DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_2d(\n            cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],\n            camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 4, sample_size: int = 64,\n            zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,\n            projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, \n            cd_attention_mid: bool = False, multiview_attention: bool = True, \n            sparse_mv_attention: bool = False, mvcd_attention: bool = False,\n            in_channels: int = 4, out_channels: int = 13, \n            **kwargs\n        ):\n        r\"\"\"\n        Instantiate a pretrained PyTorch model from a pretrained model configuration.\n\n        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To\n        train the model, set it back in training mode with `model.train()`.\n\n        Parameters:\n            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved\n                      with [`~ModelMixin.save_pretrained`].\n\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. 
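For example, pass\n                `torch.float16` to load the model weights in half precision. 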
If `\"auto\"` is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info (`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only(`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            use_auth_token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            from_flax (`bool`, *optional*, defaults to `False`):\n                Load the model weights from a Flax checkpoint save file.\n            subfolder (`str`, *optional*, defaults to `\"\"`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn't need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if `device_map` contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            variant (`str`, *optional*):\n                Load weights from a specified `variant` filename such as `\"fp16\"` or `\"ema\"`. This is ignored when\n                loading `from_flax`.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the\n                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`\n                weights. If set to `False`, `safetensors` weights are not loaded.\n\n        <Tip>\n\n        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with\n        `huggingface-cli login`. 
You can also activate the special\n        [\"offline-mode\"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a\n        firewalled environment.\n\n        </Tip>\n\n        Example:\n\n        ```py\n        from diffusers import UNet2DConditionModel\n\n        unet = UNet2DConditionModel.from_pretrained(\"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\")\n        ```\n\n        If you get the error message below, you need to fine-tune the weights for your downstream task:\n\n        ```bash\n        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:\n        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated\n        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n        ```\n        \"\"\"\n        cache_dir = kwargs.pop(\"cache_dir\", DIFFUSERS_CACHE)\n        ignore_mismatched_sizes = kwargs.pop(\"ignore_mismatched_sizes\", False)\n        force_download = kwargs.pop(\"force_download\", False)\n        from_flax = kwargs.pop(\"from_flax\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        output_loading_info = kwargs.pop(\"output_loading_info\", False)\n        local_files_only = kwargs.pop(\"local_files_only\", HF_HUB_OFFLINE)\n        use_auth_token = kwargs.pop(\"use_auth_token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        torch_dtype = kwargs.pop(\"torch_dtype\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n        device_map = kwargs.pop(\"device_map\", None)\n        max_memory = kwargs.pop(\"max_memory\", None)\n        offload_folder = kwargs.pop(\"offload_folder\", None)\n        offload_state_dict = kwargs.pop(\"offload_state_dict\", False)\n        variant = kwargs.pop(\"variant\", None)\n        use_safetensors = kwargs.pop(\"use_safetensors\", None)\n\n        if use_safetensors and not is_safetensors_available():\n            raise ValueError(\n                \"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`.\"\n            )\n\n        allow_pickle = False\n        if use_safetensors is None:\n            use_safetensors = is_safetensors_available()\n            allow_pickle = True\n\n        if device_map is not None and not is_accelerate_available():\n            raise NotImplementedError(\n                \"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set\"\n                \" `device_map=None`. You can install accelerate with `pip install accelerate`.\"\n            )\n\n        # Check if we can handle device_map and dispatching the weights\n        if device_map is not None and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Loading and dispatching requires torch >= 1.9.0. 
Please either update your PyTorch version or set\"\n                \" `device_map=None`.\"\n            )\n\n        # Load config if we don't provide a configuration\n        config_path = pretrained_model_name_or_path\n\n        user_agent = {\n            \"diffusers\": __version__,\n            \"file_type\": \"model\",\n            \"framework\": \"pytorch\",\n        }\n\n        # load config\n        config, unused_kwargs, commit_hash = cls.load_config(\n            config_path,\n            cache_dir=cache_dir,\n            return_unused_kwargs=True,\n            return_commit_hash=True,\n            force_download=force_download,\n            resume_download=resume_download,\n            proxies=proxies,\n            local_files_only=local_files_only,\n            use_auth_token=use_auth_token,\n            revision=revision,\n            subfolder=subfolder,\n            device_map=device_map,\n            max_memory=max_memory,\n            offload_folder=offload_folder,\n            offload_state_dict=offload_state_dict,\n            user_agent=user_agent,\n            **kwargs,\n        )\n\n        # modify config\n        config[\"_class_name\"] = cls.__name__\n        config['in_channels'] = in_channels\n        config['out_channels'] = out_channels\n        config['sample_size'] = sample_size # training resolution\n        config['num_views'] = num_views\n        config['cd_attention_last'] = cd_attention_last\n        config['cd_attention_mid'] = cd_attention_mid\n        config['multiview_attention'] = multiview_attention\n        config['sparse_mv_attention'] = sparse_mv_attention\n        config['mvcd_attention'] = mvcd_attention\n        config[\"down_block_types\"] = [\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\"\n        ]\n        config['mid_block_type'] = \"UNetMidBlockMV2DCrossAttn\"\n        config[\"up_block_types\"] = [\n            \"UpBlock2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\"\n        ]        \n        #config['class_embed_type'] = 'projection'\n        if camera_embedding_type == 'e_de_da_sincos':\n            config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6\n        else:\n            raise NotImplementedError\n\n        # load model\n        model_file = None\n        if from_flax:\n            raise NotImplementedError\n        else:\n            if use_safetensors:\n                try:\n                    model_file = _get_model_file(\n                        pretrained_model_name_or_path,\n                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),\n                        cache_dir=cache_dir,\n                        force_download=force_download,\n                        resume_download=resume_download,\n                        proxies=proxies,\n                        local_files_only=local_files_only,\n                        use_auth_token=use_auth_token,\n                        revision=revision,\n                        subfolder=subfolder,\n                        user_agent=user_agent,\n                        commit_hash=commit_hash,\n                    )\n                except IOError as e:\n                    if not allow_pickle:\n                        raise e\n                    pass\n            if model_file is None:\n                model_file = _get_model_file(\n   
                 pretrained_model_name_or_path,\n                    weights_name=_add_variant(WEIGHTS_NAME, variant),\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    use_auth_token=use_auth_token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n                    commit_hash=commit_hash,\n                )\n\n            model = cls.from_config(config, **unused_kwargs)\n            import copy\n            state_dict_v0 = load_state_dict(model_file, variant=variant)\n            state_dict = copy.deepcopy(state_dict_v0)\n            # rename checkpoint keys (str keys are immutable, so they can be reused directly):\n            # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last\n            # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid\n            for key in state_dict_v0:\n                if 'attn_joint.' in key:\n                    state_dict[key.replace(\"attn_joint.\", \"attn_joint_last.\")] = state_dict.pop(key)\n                if 'norm_joint.' in key:\n                    state_dict[key.replace(\"norm_joint.\", \"norm_joint_last.\")] = state_dict.pop(key)\n                if 'attn_joint_twice.' in key:\n                    state_dict[key.replace(\"attn_joint_twice.\", \"attn_joint_mid.\")] = state_dict.pop(key)\n                if 'norm_joint_twice.' in key:\n                    state_dict[key.replace(\"norm_joint_twice.\", \"norm_joint_mid.\")] = state_dict.pop(key)\n\n            model._convert_deprecated_attention_blocks(state_dict)\n\n            conv_in_weight = state_dict['conv_in.weight']\n            conv_out_weight = state_dict['conv_out.weight']\n            conv_out_bias = state_dict['conv_out.bias']\n            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(\n                model,\n                state_dict,\n                model_file,\n                pretrained_model_name_or_path,\n                ignore_mismatched_sizes=True,\n            )\n            # if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):\n            #     # initialize from the original SD structure\n            #     model.conv_in.weight.data[:,:4] = conv_in_weight\n\n            # # whether to place all zero to new layers?\n            # if zero_init_conv_in:\n            #     model.conv_in.weight.data[:,4:] = 0.\n\n            if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):\n                # initialize from the original SD structure\n                model.conv_out.weight.data[-4:, ] = conv_out_weight\n                model.conv_out.bias.data[-4:] = conv_out_bias\n                # model.conv_out.weight.data[:,:4] = conv_out_weight\n                # if out_channels == 8: # copy for the last 4 channels\n                #     model.conv_out.weight.data[:, 4:] = conv_out_weight\n\n            if zero_init_camera_projection:\n                for p in model.class_embedding.parameters():\n                    torch.nn.init.zeros_(p)\n\n            loading_info = {\n                \"missing_keys\": missing_keys,\n                \"unexpected_keys\": unexpected_keys,\n     
           \"mismatched_keys\": mismatched_keys,\n                \"error_msgs\": error_msgs,\n            }\n\n        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):\n            raise ValueError(\n                f\"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.\"\n            )\n        elif torch_dtype is not None:\n            model = model.to(torch_dtype)\n\n        model.register_to_config(_name_or_path=pretrained_model_name_or_path)\n\n        # Set model in evaluation mode to deactivate DropOut modules by default\n        model.eval()\n        if output_loading_info:\n            return model, loading_info\n\n        return model\n\n    @classmethod\n    def _load_pretrained_model_2d(\n        cls,\n        model,\n        state_dict,\n        resolved_archive_file,\n        pretrained_model_name_or_path,\n        ignore_mismatched_sizes=False,\n    ):\n        # Retrieve missing & unexpected_keys\n        model_state_dict = model.state_dict()\n        loaded_keys = list(state_dict.keys())\n\n        expected_keys = list(model_state_dict.keys())\n\n        original_loaded_keys = loaded_keys\n\n        missing_keys = list(set(expected_keys) - set(loaded_keys))\n        unexpected_keys = list(set(loaded_keys) - set(expected_keys))\n\n        # Make sure we are able to load base models as well as derived models (with heads)\n        model_to_load = model\n\n        def _find_mismatched_keys(\n            state_dict,\n            model_state_dict,\n            loaded_keys,\n            ignore_mismatched_sizes,\n        ):\n            mismatched_keys = []\n            if ignore_mismatched_sizes:\n                for checkpoint_key in loaded_keys:\n                    model_key = checkpoint_key\n\n                    if (\n                        model_key in model_state_dict\n                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape\n                    ):\n                        mismatched_keys.append(\n                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)\n                        )\n                        if 'proj_in' in checkpoint_key:\n                            state_dict[checkpoint_key] = torch.cat([state_dict[checkpoint_key], model_state_dict[checkpoint_key][:, -6:]], dim=1)\n                        else:\n                            del state_dict[checkpoint_key]\n            return mismatched_keys\n\n        if state_dict is not None:\n            # Whole checkpoint\n            mismatched_keys = _find_mismatched_keys(\n                state_dict,\n                model_state_dict,\n                original_loaded_keys,\n                ignore_mismatched_sizes,\n            )\n            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)\n\n        if len(error_msgs) > 0:\n            error_msg = \"\\n\\t\".join(error_msgs)\n            if \"size mismatch\" in error_msg:\n                error_msg += (\n                    \"\\n\\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.\"\n                )\n            raise RuntimeError(f\"Error(s) in loading state_dict for {model.__class__.__name__}:\\n\\t{error_msg}\")\n\n        if len(unexpected_keys) > 0:\n            logger.warning(\n                f\"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when\"\n                f\" initializing 
{model.__class__.__name__}: {unexpected_keys}\\n- This IS expected if you are\"\n                f\" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task\"\n                \" or with another architecture (e.g. initializing a BertForSequenceClassification model from a\"\n                \" BertForPreTraining model).\\n- This IS NOT expected if you are initializing\"\n                f\" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly\"\n                \" identical (initializing a BertForSequenceClassification model from a\"\n                \" BertForSequenceClassification model).\"\n            )\n        else:\n            logger.info(f\"All model checkpoint weights were used when initializing {model.__class__.__name__}.\\n\")\n        if len(missing_keys) > 0:\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\\nYou should probably\"\n                \" TRAIN this model on a down-stream task to be able to use it for predictions and inference.\"\n            )\n        elif len(mismatched_keys) == 0:\n            logger.info(\n                f\"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path}.\\nIf your task is similar to the task the model of the\"\n                f\" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions\"\n                \" without further training.\"\n            )\n        if len(mismatched_keys) > 0:\n            mismatched_warning = \"\\n\".join(\n                [\n                    f\"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated\"\n                    for key, shape1, shape2 in mismatched_keys\n                ]\n            )\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized because the shapes did not\"\n                f\" match:\\n{mismatched_warning}\\nYou should probably TRAIN this model on a down-stream task to be\"\n                \" able to use it for predictions and inference.\"\n            )\n\n        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs"
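The loader above rewrites legacy checkpoint key names before loading (`attn_joint` → `attn_joint_last`, `attn_joint_twice` → `attn_joint_mid`, plus the matching `norm_joint*` keys), then loads with `ignore_mismatched_sizes=True` so that only the shape-mismatched tensors (e.g. the last four `conv_out` channels) are re-initialized from the original Stable Diffusion weights. A minimal sketch of the renaming step is shown below; `rename_legacy_keys` and `LEGACY_KEY_MAP` are illustrative helpers, not identifiers from this repository.

```python
# Mapping used by the loader above: legacy sub-module name fragments -> current names.
LEGACY_KEY_MAP = {
    "attn_joint.": "attn_joint_last.",
    "norm_joint.": "norm_joint_last.",
    "attn_joint_twice.": "attn_joint_mid.",
    "norm_joint_twice.": "norm_joint_mid.",
}

def rename_legacy_keys(state_dict):
    """Return a copy of `state_dict` with legacy sub-module names rewritten."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in LEGACY_KEY_MAP.items():
            if old in new_key:
                new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed

# Usage sketch: state_dict = rename_legacy_keys(load_state_dict(model_file, variant=variant))
```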
  },
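The next file defines `UNetMV2DConditionModel`, whose `from_pretrained_2d` classmethod adapts a pretrained 2D UNet for multiview use. A hedged usage sketch follows; the parameter names come from the signature later in the file, while the checkpoint path and argument values are placeholders, not values prescribed by the repository.

```python
import torch

from core.models.unet_mv2d_condition_depth_diffusion_test import UNetMV2DConditionModel

# Placeholder path: any Stable-Diffusion-style UNet checkpoint directory is assumed here.
unet = UNetMV2DConditionModel.from_pretrained_2d(
    "path/to/stable-diffusion/unet",
    num_views=4,                       # number of views attended over jointly (illustrative)
    sample_size=64,                    # latent resolution (illustrative)
    cd_attention_mid=True,             # illustrative flag; defaults are given in the signature
    zero_init_camera_projection=True,  # zero-init the camera projection layers
)
unet = unet.to(torch.float16).eval()   # evaluation mode; the dtype cast is optional
```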
  {
    "path": "core/models/unet_mv2d_condition_depth_diffusion_test.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport os\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention_processor import AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model\nfrom diffusers.models.unet_2d_blocks import (\n    CrossAttnDownBlock2D,\n    CrossAttnUpBlock2D,\n    DownBlock2D,\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    UpBlock2D,\n)\nfrom diffusers.utils import (\n    CONFIG_NAME,\n    DIFFUSERS_CACHE,\n    FLAX_WEIGHTS_NAME,\n    HF_HUB_OFFLINE,\n    SAFETENSORS_WEIGHTS_NAME,\n    WEIGHTS_NAME,\n    _add_variant,\n    _get_model_file,\n    deprecate,\n    is_accelerate_available,\n    is_safetensors_available,\n    is_torch_version,\n    logging,\n)\nfrom diffusers import __version__\nfrom .unet_mv2d_blocks import (\n    CrossAttnDownBlockMV2D,\n    CrossAttnUpBlockMV2D,\n    UNetMidBlockMV2DCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nclass IdentityMLP(nn.Module):\n    def __init__(self, size):\n        super(IdentityMLP, self).__init__()\n        self.linear = nn.Linear(size, size)\n        self.init_identity()\n\n    def forward(self, x):\n        return self.linear(x)\n    \n    def init_identity(self):\n        # Initialize the weights to an identity matrix and biases to zero\n        identity_matrix = torch.eye(self.linear.in_features)\n        self.linear.weight.data.copy_(identity_matrix)\n        self.linear.bias.data.zero_()\n        \n@dataclass\nclass UNetMV2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. 
Check the superclass documentation for it's generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or\n            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers is skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. 
Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"CrossAttnDownBlockMV2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlockMV2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\", \"CrossAttnUpBlockMV2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n        num_views: int = 1,\n        cd_attention_last: bool = False,\n        cd_attention_mid: bool = False,\n        multiview_attention: bool = True,\n        sparse_mv_attention: bool = False,\n        mvcd_attention: bool = False\n    ):\n        super().__init__()\n\n        self.sample_size = 
sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image_proj\"` (Kadinsky 2.1)`\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kadinsky 2.1)`\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. 
The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n            )\n        # custom MV2D attention block  \n        elif mid_block_type == \"UNetMidBlockMV2DCrossAttn\":\n            self.mid_block = UNetMidBlockMV2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                
resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        ### FIXME \n        #up_cross_attention_dim = (None, None, None, None)\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        #reversed_cross_attention_dim = list(reversed(up_cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                
temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                num_views=num_views,\n                cd_attention_last=cd_attention_last,\n                cd_attention_mid=cd_attention_mid,\n                multiview_attention=multiview_attention,\n                sparse_mv_attention=sparse_mv_attention,\n                mvcd_attention=mvcd_attention\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"set_processor\"):\n                processors[f\"{name}.processor\"] = module.processor\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        self.set_attn_processor(AttnProcessor())\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        ray_embedding: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNetMV2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n    
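            # The per-sample `time_ids` (in SDXL-style conditioning these are typically the
            # original size, crop coordinates, and target size) were flattened and embedded
            # sinusoidally by `add_time_proj` above; below they are reshaped back to one
            # vector per sample, concatenated with `text_embeds`, and projected by
            # `add_embedding` into the time-embedding dimension.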
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kadinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 3. 
down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n                ray_embedding = ray_embedding,\n            )\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    ray_embedding = ray_embedding,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNetMV2DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_2d(\n            cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],\n            camera_embedding_type: str = 'e_de_da_sincos', num_views: int = 1, sample_size: int = 64,\n            zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,\n            projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False, \n            cd_attention_mid: bool = False, multiview_attention: bool = True, \n            sparse_mv_attention: bool = False, mvcd_attention: bool = False,\n            in_channels: int = 4, out_channels: int = 13, \n            **kwargs\n        ):\n        r\"\"\"\n        Instantiate a pretrained PyTorch model from a pretrained model configuration.\n\n        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To\n        train the model, set it back in training mode with `model.train()`.\n\n        Parameters:\n            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved\n                      with [`~ModelMixin.save_pretrained`].\n\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. 
If `\"auto\"` is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info (`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only(`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            use_auth_token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            from_flax (`bool`, *optional*, defaults to `False`):\n                Load the model weights from a Flax checkpoint save file.\n            subfolder (`str`, *optional*, defaults to `\"\"`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn't need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if `device_map` contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            variant (`str`, *optional*):\n                Load weights from a specified `variant` filename such as `\"fp16\"` or `\"ema\"`. This is ignored when\n                loading `from_flax`.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the\n                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`\n                weights. If set to `False`, `safetensors` weights are not loaded.\n\n        <Tip>\n\n        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with\n        `huggingface-cli login`. 
You can also activate the special\n        [\"offline-mode\"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a\n        firewalled environment.\n\n        </Tip>\n\n        Example:\n\n        ```py\n        from diffusers import UNet2DConditionModel\n\n        unet = UNet2DConditionModel.from_pretrained(\"runwayml/stable-diffusion-v1-5\", subfolder=\"unet\")\n        ```\n\n        If you get the error message below, you need to finetune the weights for your downstream task:\n\n        ```bash\n        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:\n        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated\n        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n        ```\n        \"\"\"\n        cache_dir = kwargs.pop(\"cache_dir\", DIFFUSERS_CACHE)\n        ignore_mismatched_sizes = kwargs.pop(\"ignore_mismatched_sizes\", False)\n        force_download = kwargs.pop(\"force_download\", False)\n        from_flax = kwargs.pop(\"from_flax\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        output_loading_info = kwargs.pop(\"output_loading_info\", False)\n        local_files_only = kwargs.pop(\"local_files_only\", HF_HUB_OFFLINE)\n        use_auth_token = kwargs.pop(\"use_auth_token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        torch_dtype = kwargs.pop(\"torch_dtype\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n        device_map = kwargs.pop(\"device_map\", None)\n        max_memory = kwargs.pop(\"max_memory\", None)\n        offload_folder = kwargs.pop(\"offload_folder\", None)\n        offload_state_dict = kwargs.pop(\"offload_state_dict\", False)\n        variant = kwargs.pop(\"variant\", None)\n        use_safetensors = kwargs.pop(\"use_safetensors\", None)\n\n        if use_safetensors and not is_safetensors_available():\n            raise ValueError(\n                \"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors\"\n            )\n\n        allow_pickle = False\n        if use_safetensors is None:\n            use_safetensors = is_safetensors_available()\n            allow_pickle = True\n\n        if device_map is not None and not is_accelerate_available():\n            raise NotImplementedError(\n                \"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set\"\n                \" `device_map=None`. You can install accelerate with `pip install accelerate`.\"\n            )\n\n        # Check if we can handle device_map and dispatching the weights\n        if device_map is not None and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Loading and dispatching requires torch >= 1.9.0. 
Please either update your PyTorch version or set\"\n                \" `device_map=None`.\"\n            )\n\n        # Load config if we don't provide a configuration\n        config_path = pretrained_model_name_or_path\n\n        user_agent = {\n            \"diffusers\": __version__,\n            \"file_type\": \"model\",\n            \"framework\": \"pytorch\",\n        }\n\n        # load config\n        config, unused_kwargs, commit_hash = cls.load_config(\n            config_path,\n            cache_dir=cache_dir,\n            return_unused_kwargs=True,\n            return_commit_hash=True,\n            force_download=force_download,\n            resume_download=resume_download,\n            proxies=proxies,\n            local_files_only=local_files_only,\n            use_auth_token=use_auth_token,\n            revision=revision,\n            subfolder=subfolder,\n            device_map=device_map,\n            max_memory=max_memory,\n            offload_folder=offload_folder,\n            offload_state_dict=offload_state_dict,\n            user_agent=user_agent,\n            **kwargs,\n        )\n\n        # modify config\n        config[\"_class_name\"] = cls.__name__\n        config['in_channels'] = in_channels\n        config['out_channels'] = out_channels\n        config['sample_size'] = sample_size # training resolution\n        config['num_views'] = num_views\n        config['cd_attention_last'] = cd_attention_last\n        config['cd_attention_mid'] = cd_attention_mid\n        config['multiview_attention'] = multiview_attention\n        config['sparse_mv_attention'] = sparse_mv_attention\n        config['mvcd_attention'] = mvcd_attention\n        config[\"down_block_types\"] = [\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\"\n        ]\n        config['mid_block_type'] = \"UNetMidBlockMV2DCrossAttn\"\n        config[\"up_block_types\"] = [\n            \"UpBlock2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\",\n            \"CrossAttnUpBlockMV2D\"\n        ]\n\n        #config['class_embed_type'] = 'projection'\n        if camera_embedding_type == 'e_de_da_sincos':\n            config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6\n        else:\n            raise NotImplementedError\n\n        # load model\n        model_file = None\n        if from_flax:\n            raise NotImplementedError\n        else:\n            if use_safetensors:\n                try:\n                    model_file = _get_model_file(\n                        pretrained_model_name_or_path,\n                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),\n                        cache_dir=cache_dir,\n                        force_download=force_download,\n                        resume_download=resume_download,\n                        proxies=proxies,\n                        local_files_only=local_files_only,\n                        use_auth_token=use_auth_token,\n                        revision=revision,\n                        subfolder=subfolder,\n                        user_agent=user_agent,\n                        commit_hash=commit_hash,\n                    )\n                except IOError as e:\n                    if not allow_pickle:\n                        raise e\n                    pass\n            if model_file is None:\n                model_file = _get_model_file(\n         
           pretrained_model_name_or_path,\n                    weights_name=_add_variant(WEIGHTS_NAME, variant),\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    use_auth_token=use_auth_token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n                    commit_hash=commit_hash,\n                )\n\n            model = cls.from_config(config, **unused_kwargs)\n            import copy\n            state_dict_v0 = load_state_dict(model_file, variant=variant)\n            state_dict = copy.deepcopy(state_dict_v0)\n            # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last\n            # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid\n            for key in state_dict_v0:\n                if 'attn_joint.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"attn_joint.\", \"attn_joint_last.\")] = state_dict.pop(tmp)\n                if 'norm_joint.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"norm_joint.\", \"norm_joint_last.\")] = state_dict.pop(tmp)\n                if 'attn_joint_twice.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"attn_joint_twice.\", \"attn_joint_mid.\")] = state_dict.pop(tmp)\n                if 'norm_joint_twice.' in key:\n                    tmp = copy.deepcopy(key)\n                    state_dict[key.replace(\"norm_joint_twice.\", \"norm_joint_mid.\")] = state_dict.pop(tmp)\n            \n            model._convert_deprecated_attention_blocks(state_dict)\n\n            conv_in_weight = state_dict['conv_in.weight']\n            conv_out_weight = state_dict['conv_out.weight']\n            conv_out_bias = state_dict['conv_out.bias']\n            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(\n                model,\n                state_dict,\n                model_file,\n                pretrained_model_name_or_path,\n                ignore_mismatched_sizes=True,\n            )\n            # if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):\n            #     # initialize from the original SD structure\n            #     model.conv_in.weight.data[:,:4] = conv_in_weight\n\n            # # whether to place all zero to new layers?\n            # if zero_init_conv_in:\n            #     model.conv_in.weight.data[:,4:] = 0.\n\n            if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):\n                # initialize from the original SD structure\n                model.conv_out.weight.data[-4:, ] = conv_out_weight\n                model.conv_out.bias.data[-4:] = conv_out_bias\n                # model.conv_out.weight.data[:,:4] = conv_out_weight\n                # if out_channels == 8: # copy for the last 4 channels\n                #     model.conv_out.weight.data[:, 4:] = conv_out_weight\n                \n            if zero_init_camera_projection:\n                for p in model.class_embedding.parameters():\n                    torch.nn.init.zeros_(p)\n\n            loading_info = {\n                \"missing_keys\": missing_keys,\n                \"unexpected_keys\": unexpected_keys,\n       
         \"mismatched_keys\": mismatched_keys,\n                \"error_msgs\": error_msgs,\n            }\n\n        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):\n            raise ValueError(\n                f\"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.\"\n            )\n        elif torch_dtype is not None:\n            model = model.to(torch_dtype)\n\n        model.register_to_config(_name_or_path=pretrained_model_name_or_path)\n\n        # Set model in evaluation mode to deactivate DropOut modules by default\n        model.eval()\n        if output_loading_info:\n            return model, loading_info\n\n        return model\n\n    @classmethod\n    def _load_pretrained_model_2d(\n        cls,\n        model,\n        state_dict,\n        resolved_archive_file,\n        pretrained_model_name_or_path,\n        ignore_mismatched_sizes=False,\n    ):\n        # Retrieve missing & unexpected_keys\n        model_state_dict = model.state_dict()\n        loaded_keys = list(state_dict.keys())\n\n        expected_keys = list(model_state_dict.keys())\n\n        original_loaded_keys = loaded_keys\n\n        missing_keys = list(set(expected_keys) - set(loaded_keys))\n        unexpected_keys = list(set(loaded_keys) - set(expected_keys))\n\n        # Make sure we are able to load base models as well as derived models (with heads)\n        model_to_load = model\n\n        def _find_mismatched_keys(\n            state_dict,\n            model_state_dict,\n            loaded_keys,\n            ignore_mismatched_sizes,\n        ):\n            mismatched_keys = []\n            if ignore_mismatched_sizes:\n                for checkpoint_key in loaded_keys:\n                    model_key = checkpoint_key\n\n                    if (\n                        model_key in model_state_dict\n                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape\n                    ):\n                        mismatched_keys.append(\n                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)\n                        )\n                        if 'proj_in' in checkpoint_key:\n                            state_dict[checkpoint_key] = torch.cat([state_dict[checkpoint_key], model_state_dict[checkpoint_key][:, -6:]], dim=1)\n                        else:\n                            del state_dict[checkpoint_key]\n            return mismatched_keys\n\n        if state_dict is not None:\n            # Whole checkpoint\n            mismatched_keys = _find_mismatched_keys(\n                state_dict,\n                model_state_dict,\n                original_loaded_keys,\n                ignore_mismatched_sizes,\n            )\n            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)\n\n        if len(error_msgs) > 0:\n            error_msg = \"\\n\\t\".join(error_msgs)\n            if \"size mismatch\" in error_msg:\n                error_msg += (\n                    \"\\n\\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.\"\n                )\n            raise RuntimeError(f\"Error(s) in loading state_dict for {model.__class__.__name__}:\\n\\t{error_msg}\")\n\n        if len(unexpected_keys) > 0:\n            logger.warning(\n                f\"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when\"\n                f\" initializing 
{model.__class__.__name__}: {unexpected_keys}\\n- This IS expected if you are\"\n                f\" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task\"\n                \" or with another architecture (e.g. initializing a BertForSequenceClassification model from a\"\n                \" BertForPreTraining model).\\n- This IS NOT expected if you are initializing\"\n                f\" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly\"\n                \" identical (initializing a BertForSequenceClassification model from a\"\n                \" BertForSequenceClassification model).\"\n            )\n        else:\n            logger.info(f\"All model checkpoint weights were used when initializing {model.__class__.__name__}.\\n\")\n        if len(missing_keys) > 0:\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\\nYou should probably\"\n                \" TRAIN this model on a down-stream task to be able to use it for predictions and inference.\"\n            )\n        elif len(mismatched_keys) == 0:\n            logger.info(\n                f\"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path}.\\nIf your task is similar to the task the model of the\"\n                f\" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions\"\n                \" without further training.\"\n            )\n        if len(mismatched_keys) > 0:\n            mismatched_warning = \"\\n\".join(\n                [\n                    f\"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated\"\n                    for key, shape1, shape2 in mismatched_keys\n                ]\n            )\n            logger.warning(\n                f\"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at\"\n                f\" {pretrained_model_name_or_path} and are newly initialized because the shapes did not\"\n                f\" match:\\n{mismatched_warning}\\nYou should probably TRAIN this model on a down-stream task to be\"\n                \" able to use it for predictions and inference.\"\n            )\n\n        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs"
  },
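The `from_pretrained_2d` classmethod in the multi-view UNet file above adapts a plain 2D Stable Diffusion checkpoint to the multi-view model: it rewrites the block types in the config, renames the legacy joint-attention keys, and re-initializes the `conv_out` channels whose shapes no longer match. Below is a minimal sketch of the key-renaming and shape-mismatch filtering steps, assuming an ordinary PyTorch state dict; it omits the `proj_in` concatenation special case, and the helper names (`RENAME_RULES`, `remap_legacy_keys`, `drop_mismatched`) are illustrative, not part of the repository.

```python
import copy

# Rename rules applied in from_pretrained_2d: legacy joint-attention modules
# map onto the *_last / *_mid modules of the multi-view UNet.
RENAME_RULES = {
    "attn_joint.": "attn_joint_last.",
    "norm_joint.": "norm_joint_last.",
    "attn_joint_twice.": "attn_joint_mid.",
    "norm_joint_twice.": "norm_joint_mid.",
}


def remap_legacy_keys(state_dict: dict) -> dict:
    """Return a copy of `state_dict` with legacy joint-attention keys renamed."""
    remapped = copy.deepcopy(state_dict)
    for key in list(state_dict.keys()):
        for old, new in RENAME_RULES.items():
            if old in key:
                remapped[key.replace(old, new)] = remapped.pop(key)
                break
    return remapped


def drop_mismatched(state_dict: dict, model_state_dict: dict):
    """Record (key, checkpoint_shape, model_shape) for size-mismatched tensors and
    remove them from `state_dict` so the remaining keys load without shape errors."""
    mismatched = []
    for key, tensor in list(state_dict.items()):
        if key in model_state_dict and tensor.shape != model_state_dict[key].shape:
            mismatched.append((key, tensor.shape, model_state_dict[key].shape))
            del state_dict[key]
    return mismatched
```

In the repository the filtered state dict is then loaded through `_load_pretrained_model_2d` with `ignore_mismatched_sizes=True`, and the surviving `conv_out` weights are copied back into the last four output channels.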
  {
    "path": "core/models_LGM_compos_diffusion.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport kiui\nfrom kiui.lpips import LPIPS\n\nfrom core.unet_LGM_compos import UNet\nfrom core.options_latents_diffusion import Options\nfrom core.gs import GaussianRenderer\nfrom diffusers import AutoencoderKL, DDPMScheduler,  UNet2DConditionModel\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom typing import Optional\nimport random\nimport torchvision.transforms.functional as TF\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n\nclass LGM(nn.Module):\n    def __init__(\n        self,\n        opt: Options,\n    ):\n        super().__init__()\n\n        self.opt = opt\n\n        # unet\n        self.unet = UNet(\n            9, 14, \n            down_channels=self.opt.down_channels,\n            down_attention=self.opt.down_attention,\n            mid_attention=self.opt.mid_attention,\n            up_channels=self.opt.up_channels,\n            up_attention=self.opt.up_attention,\n        )\n\n        # last conv\n        self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again\n\n        # Gaussian Renderer\n        self.gs = GaussianRenderer(opt)\n\n        # activations...\n        self.pos_act = lambda x: x.clamp(-1, 1)\n        self.scale_act = lambda x: 0.1 * F.softplus(x)\n        self.opacity_act = lambda x: torch.sigmoid(x)\n        self.rot_act = lambda x: F.normalize(x, dim=-1)\n        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again\n\n        # LPIPS loss\n        if self.opt.lambda_lpips > 0:\n            self.lpips_loss = LPIPS(net='vgg')\n            self.lpips_loss.requires_grad_(False)\n        \n        model_key = opt.pretrained_model_name_or_path\n        self.unet2 = UNet2DConditionModel.from_pretrained(model_key, subfolder=\"unet\", low_cpu_mem_usage=False,device_map=None,ignore_mismatched_sizes=True)\n        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder=\"text_encoder\")\n        self.tokenizer =  CLIPTokenizer.from_pretrained(model_key, subfolder=\"tokenizer\")\n        self.scheduler = DDPMScheduler.from_pretrained(model_key, subfolder=\"scheduler\")\n        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder=\"vae\").to(self.opt.weight_dtype)\n        self.vae.requires_grad_(False)\n        self.unet2.requires_grad_(False)\n        #self.tokenizer.requires_grad_(False)\n        self.text_encoder.requires_grad_(False)\n\n    def state_dict(self, **kwargs):\n        # remove lpips_loss\n        state_dict = super().state_dict(**kwargs)\n        for k in list(state_dict.keys()):\n            if 'lpips_loss' in k:\n                del state_dict[k]\n        return state_dict\n\n    \n    def prepare_default_rays(self, device, elevation=0):\n        \n        from kiui.cam import orbit_camera\n        from core.utils import get_rays\n\n        cam_poses = np.stack([\n            orbit_camera(elevation, 0, radius=self.opt.cam_radius),\n            orbit_camera(elevation, 90, radius=self.opt.cam_radius),\n            orbit_camera(elevation, 180, radius=self.opt.cam_radius),\n            orbit_camera(elevation, 270, radius=self.opt.cam_radius),\n        ], axis=0) # [4, 4, 4]\n        cam_poses = torch.from_numpy(cam_poses)\n\n        rays_embeddings = []\n        for i in range(cam_poses.shape[0]):\n            rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, 
self.opt.fovy) # [h, w, 3]\n            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]\n            rays_embeddings.append(rays_plucker)\n\n            ## visualize rays for plotting figure\n            # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)\n\n        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]\n        \n        return rays_embeddings\n        \n\n    def forward_gaussians(self, images, encoder_hidden_states, data):\n        # images: [B, 4, 9, H, W]\n        # return: Gaussians: [B, dim_t]\n\n        B, V, C, H, W = images.shape\n        images = images.view(B*V, C, H, W)\n        timestep = data[\"timesteps\"].flatten(0, 1)\n        pred_noise, blocks_sample, temb= self.unet2(images, timestep, encoder_hidden_states, return_dict=False)\n        \n        pred_x0 = self.pred_x0(pred_noise, timestep, images)\n        images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5\n        images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False)\n        images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)\n        images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].flatten(0, 1).to(self.opt.weight_dtype) ], dim=1)\n\n        x = self.unet(images_256, blocks_sample, temb) # [B*4, 14, h, w]\n        x = self.conv(x) # [B*4, 14, h, w]\n\n        x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)\n        \n        ## visualize multi-view gaussian features for plotting figure\n        # tmp_alpha = self.opacity_act(x[0, :, 3:4])\n        # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)\n        # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5\n        # kiui.vis.plot_image(tmp_img_rgb, save=True)\n        # kiui.vis.plot_image(tmp_img_pos, save=True)\n\n        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)\n        \n        pos = self.pos_act(x[..., 0:3]) # [B, N, 3]\n        opacity = self.opacity_act(x[..., 3:4])\n        scale = self.scale_act(x[..., 4:7])\n        rotation = self.rot_act(x[..., 7:11])\n        rgbs = self.rgb_act(x[..., 11:])\n\n        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]\n        \n        return gaussians, images_512\n    \n    def pred_x0(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta: float=0.0,\n        verbose=False,\n    ):\n        \"\"\"\n        predict the sampe the next step in the denoise process.\n        \"\"\"\n        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x.device)\n        alpha_prod_t = alphas_cumprod [timestep]\n\n        B = alpha_prod_t.shape[0]\n        alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1)\n        beta_prod_t = 1 - alpha_prod_t\n        \n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        return pred_x0\n    \n    def encode_prompt(\n            self,\n            prompt,\n            device,\n            prompt_embeds: Optional[torch.FloatTensor] = None,\n        ):\n            if prompt is not None and isinstance(prompt, str):\n                batch_size = 1\n            elif prompt is not None and isinstance(prompt, list):\n                batch_size = len(prompt)\n            else:\n                batch_size = prompt_embeds.shape[0]\n\n            if prompt_embeds is 
None:\n                text_inputs = self.tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=self.tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = self.tokenizer.batch_decode(\n                        untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                    )\n\n                if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                    attention_mask = text_inputs.attention_mask.to(device)\n                else:\n                    attention_mask = None\n\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n\n            if self.text_encoder is not None:\n                prompt_embeds_dtype = self.text_encoder.dtype\n            elif self.unet is not None:\n                prompt_embeds_dtype = self.unet.dtype\n            else:\n                prompt_embeds_dtype = prompt_embeds.dtype\n\n            prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            bs_embed, seq_len, _ = prompt_embeds.shape\n\n            return prompt_embeds\n    \n    def compute_snr(self, timesteps):\n        \"\"\"\n        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849\n        \"\"\"\n        alphas_cumprod = self.scheduler.alphas_cumprod\n        sqrt_alphas_cumprod = alphas_cumprod**0.5\n        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5\n\n        # Expand the tensors.\n        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026\n        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()\n        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):\n            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]\n        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)\n\n        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()\n        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):\n            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]\n        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)\n\n        # Compute SNR.\n        snr = (alpha / sigma) ** 2\n        return snr\n    \n    def forward(self, data, step_ratio=1):\n        # data: output of the dataloader\n        # return: loss\n\n        results = {}\n        loss = 0\n        start_idx = None\n        images = data['input'].to(self.opt.weight_dtype) # [B, 4, 9, h, W], input features\n        \n        num_views = images.shape[1]\n        #ray_embedding = images[:, :, 4:]\n        latents = images.flatten(0,1)\n        latent = latents[:,:4]\n        bsz, c, h, w = 
latent.shape\n       \n        # timesteps\n        timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (bsz // num_views,), device=images.device)\n        timesteps_pred = timesteps.repeat_interleave(self.opt.num_views)\n        timesteps = timesteps.repeat_interleave(num_views)\n        timesteps = timesteps.long()\n        if(random.random() < 0.7):\n            start_idx = torch.randint(0,4, (1,)).item()\n            timesteps[start_idx ::num_views] = 0\n            timesteps_pred[start_idx ::self.opt.num_views] = 0\n            \n        if(random.random() < 0.7):\n            prompt = data[\"prompt\"]\n            \n            # prompt = [prompt[i][j] for j in range(len(prompt[0])) for i in range(len(prompt))]\n            # encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            prompt = [prompt[0][i] for i in range(len(prompt[0]))]\n            #print(prompt)\n            encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n            encoder_hidden_states = encoder_hidden_states.flatten(0,1)\n        else:\n            prompt = ['']*images.shape[0]\n            encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n            encoder_hidden_states = encoder_hidden_states.flatten(0,1)\n        \n        noise = torch.randn_like(latent).to(device=images.device)\n        noisy_latents = self.scheduler.add_noise(latent, noise, timesteps).to(device=images.device)\n        data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w)\n        data['timesteps'] = timesteps.reshape(bsz // num_views, num_views)\n\n        snr = self.compute_snr(timesteps_pred)\n        mse_loss_weights = torch.stack([snr, self.opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] \n        # use the first view to predict gaussians\n        images = data['noisy_latents']\n        gaussians, noise_images = self.forward_gaussians(images, encoder_hidden_states, data) # [B, N, 14]\n\n        results['gaussians'] = gaussians\n\n        # always use white bg\n        bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)\n        \n        # use the other views for rendering and supervision\n        results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)\n        pred_images = results['image'].to(self.opt.weight_dtype) # [B, V, C, output_size, output_size]\n        pred_alphas = results['alpha'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size]\n\n        results['images_pred'] = pred_images\n        results['alphas_pred'] = pred_alphas\n\n        gt_images = data['images2_output'].to(self.opt.weight_dtype) # [B, V, 3, output_size, output_size], ground-truth novel views\n        gt_masks = data['masks_output'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size], ground-truth masks\n\n        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1).to(self.opt.weight_dtype) * (1 - gt_masks)\n\n        loss_mse_image = F.mse_loss(pred_images.flatten(0,1), gt_images.flatten(0,1), reduction=\"none\") \n        loss_mse_alpha = F.mse_loss(pred_alphas.flatten(0,1), gt_masks.flatten(0,1), reduction=\"none\")\n        loss_mse_image  = 
(loss_mse_image.mean(dim=list(range(1, len(loss_mse_image.shape)))) * mse_loss_weights).mean()\n        loss_mse_alpha = (loss_mse_alpha.mean(dim=list(range(1, len(loss_mse_alpha.shape)))) * mse_loss_weights).mean()\n        results['loss_mse_image'] = loss_mse_image\n        results['loss_mse_alpha'] = loss_mse_alpha\n        loss_mse = loss_mse_image + loss_mse_alpha\n        results['loss_mse'] = loss_mse\n        loss = loss + loss_mse\n        \n        results['gt_noise'] = noise_images.reshape(bsz // num_views, num_views, 3, 512, 512)\n\n        if self.opt.lambda_lpips > 0 and step_ratio > 0:\n            loss_lpips = self.lpips_loss(\n                # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,\n                # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,\n                # downsampled to at most 256 to reduce memory cost\n                F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), \n                F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),\n            )\n            lpips_loss_weights = torch.ones_like(mse_loss_weights)\n            if start_idx is not None:\n                lpips_loss_weights[start_idx::self.opt.num_views] = 5.0\n            loss_lpips = (loss_lpips.mean(dim=list(range(1, len(loss_lpips.shape)))) * lpips_loss_weights).mean()\n            results['loss_lpips'] = loss_lpips\n            #loss = loss + self.opt.lambda_lpips * (step_ratio-0.25) * loss_lpips\n            loss = loss + self.opt.lambda_lpips * loss_lpips\n            \n        results['loss'] = loss\n\n        # metric\n        with torch.no_grad():\n            psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))\n            results['psnr'] = psnr\n\n        return results\n"
  },
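Two numerical details of the training step above are worth spelling out: `pred_x0` inverts the epsilon prediction into a clean-latent estimate, `x0_hat = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)`, and the rendering losses are weighted per sample by `min(SNR(t), gamma)`, the Min-SNR truncation applied through `opt.snr_gamma`. A minimal sketch, assuming a diffusers-style 1-D `alphas_cumprod` tensor indexed by integer timesteps; the default `gamma` value below is illustrative.

```python
import torch


def pred_x0(noisy_latents, eps_pred, timesteps, alphas_cumprod):
    """x0_hat = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t), broadcast over [B, C, H, W]."""
    abar_t = alphas_cumprod.to(noisy_latents.device)[timesteps].view(-1, 1, 1, 1)
    return (noisy_latents - (1 - abar_t).sqrt() * eps_pred) / abar_t.sqrt()


def min_snr_weights(timesteps, alphas_cumprod, gamma=5.0):
    """Per-sample weights min(SNR(t), gamma), with SNR(t) = abar_t / (1 - abar_t)."""
    abar_t = alphas_cumprod.to(timesteps.device)[timesteps].float()
    snr = abar_t / (1.0 - abar_t)
    return torch.minimum(snr, torch.full_like(snr, gamma))
```

In `forward`, these weights multiply the per-view image and alpha MSE terms before the LPIPS loss is added.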
  {
    "path": "core/models_LGM_compos_diffusion_validate_inversion_2_masa.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport kiui\nfrom kiui.lpips import LPIPS\n\nfrom core.unet_LGM_compos import UNet\nfrom core.options_latents_diffusion import Options\nfrom core.gs import GaussianRenderer\nfrom diffusers import AutoencoderKL, DDPMScheduler,  UNet2DConditionModel, DDIMScheduler\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom typing import Optional\nimport random\nimport torchvision.transforms.functional as TF\nimport tqdm\nfrom core.control import ControlNetPipeline\nfrom core.masactrl import MutualSelfAttention3DControl\nfrom core.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers\n\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n\nclass LGM(nn.Module):\n    def __init__(\n        self,\n        opt: Options,\n    ):\n        super().__init__()\n\n        self.opt = opt\n\n        # unet\n        self.unet = UNet(\n            9, 14, \n            down_channels=self.opt.down_channels,\n            down_attention=self.opt.down_attention,\n            mid_attention=self.opt.mid_attention,\n            up_channels=self.opt.up_channels,\n            up_attention=self.opt.up_attention,\n        ).to(self.opt.weight_dtype)\n\n        # last conv\n        self.conv = nn.Conv2d(14, 14, kernel_size=1).to(self.opt.weight_dtype) # NOTE: maybe remove it if train again\n\n        # Gaussian Renderer\n        self.gs = GaussianRenderer(opt)\n\n        # activations...\n        self.pos_act = lambda x: x.clamp(-1, 1)\n        self.scale_act = lambda x: 0.1 * F.softplus(x)\n        self.opacity_act = lambda x: torch.sigmoid(x)\n        self.rot_act = lambda x: F.normalize(x, dim=-1)\n        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again\n\n        # LPIPS loss\n        if self.opt.lambda_lpips > 0:\n            self.lpips_loss = LPIPS(net='vgg')\n            self.lpips_loss.requires_grad_(False)\n        \n        model_key = opt.pretrained_model_name_or_path\n\n        self.unet2 = UNet2DConditionModel.from_pretrained(model_key, subfolder=\"unet\", low_cpu_mem_usage=False,device_map=None,ignore_mismatched_sizes=True).to(self.opt.weight_dtype)\n        self.unet3 = UNet2DConditionModel.from_pretrained(model_key, subfolder=\"unet\", low_cpu_mem_usage=False,device_map=None,ignore_mismatched_sizes=True).to(self.opt.weight_dtype)\n        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder=\"text_encoder\").to(self.opt.weight_dtype)\n        self.tokenizer =  CLIPTokenizer.from_pretrained(model_key, subfolder=\"tokenizer\")\n        self.scheduler = DDPMScheduler.from_pretrained(model_key, subfolder=\"scheduler\")\n        self.scheduler2 = DDIMScheduler.from_pretrained(model_key, subfolder=\"scheduler\")\n        self.test_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n\n        \n        #self.pipe = MasaCtrlPipeline.from_pretrained(model_key, scheduler=self.test_scheduler)\n        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder=\"vae\").to(self.opt.weight_dtype)\n        self.vae.requires_grad_(False)\n        self.unet2.requires_grad_(False)\n        self.unet3.requires_grad_(False)\n        #self.tokenizer.requires_grad_(False)\n        self.text_encoder.requires_grad_(False)\n        self.steps = 2\n        self.layer = 10\n        #self.masa_editor = 
MutualSelfAttention3DControl(step, layer, total_steps=30)\n        self.base_editor = AttentionBase()\n\n    def state_dict(self, **kwargs):\n        # remove lpips_loss\n        state_dict = super().state_dict(**kwargs)\n        for k in list(state_dict.keys()):\n            if 'lpips_loss' in k:\n                del state_dict[k]\n        return state_dict\n\n    \n    def prepare_default_rays(self, device, elevation=0, proj_matrix=None):\n        \n        from kiui.cam import orbit_camera\n        from core.utils import get_rays\n\n        cam_poses = np.stack([\n            orbit_camera(elevation, 0, radius=self.opt.cam_radius, opengl=True),\n            orbit_camera(elevation, 90, radius=self.opt.cam_radius, opengl=True),\n            orbit_camera(elevation, 180, radius=self.opt.cam_radius, opengl=True),\n            orbit_camera(elevation, 270, radius=self.opt.cam_radius, opengl=True),\n        ], axis=0) # [4, 4, 4]\n        cam_poses = torch.from_numpy(cam_poses)\n\n        rays_embeddings = []\n        for i in range(cam_poses.shape[0]):\n            rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_ray_size, self.opt.input_ray_size, self.opt.fovy) # [h, w, 3]\n            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]\n            rays_embeddings.append(rays_plucker)\n\n            ## visualize rays for plotting figure\n            # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)\n\n        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]\n        cam_poses[:, :3, 1:3] *= -1\n        cam_poses = cam_poses.to(device)\n        cam_view = torch.inverse(cam_poses).transpose(1, 2)\n        cam_view_proj = cam_view @ proj_matrix\n        cam_pos = - cam_poses[:, :3, 3]\n\n        return rays_embeddings, cam_view, cam_view_proj, cam_pos\n        \n    \n    def prepare_default_rays_zero123(self, device, elevation=0, proj_matrix=None):\n        \n        from kiui.cam import orbit_camera\n        from core.utils import get_rays\n\n        cam_poses = np.stack([\n            orbit_camera(0, 0, radius=self.opt.cam_radius, opengl=True),\n            orbit_camera(-10, 90, radius=self.opt.cam_radius, opengl=True),\n            orbit_camera(-10, 210, radius=self.opt.cam_radius, opengl=True),\n            orbit_camera(20, 270, radius=self.opt.cam_radius, opengl=True),\n        ], axis=0) # [4, 4, 4]\n        # cam_poses = np.stack([\n        #     orbit_camera(0,  0, radius=self.opt.cam_radius, opengl=True),\n        #     orbit_camera(0, 120, radius=self.opt.cam_radius, opengl=True),\n        #     orbit_camera(0, 240, radius=self.opt.cam_radius, opengl=True),\n        #     orbit_camera(-30, 300, radius=self.opt.cam_radius, opengl=True),\n        # ], axis=0) # [4, 4, 4]\n        cam_poses = torch.from_numpy(cam_poses)\n\n        rays_embeddings = []\n        for i in range(cam_poses.shape[0]):\n            rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_ray_size, self.opt.input_ray_size, self.opt.fovy) # [h, w, 3]\n            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]\n            rays_embeddings.append(rays_plucker)\n\n            ## visualize rays for plotting figure\n            # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)\n\n        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]\n        cam_poses[:, :3, 1:3] *= -1\n        cam_poses = 
cam_poses.to(device)\n        cam_view = torch.inverse(cam_poses).transpose(1, 2)\n        cam_view_proj = cam_view @ proj_matrix\n        cam_pos = - cam_poses[:, :3, 3]\n\n        return rays_embeddings, cam_view, cam_view_proj, cam_pos\n    \n    def unet_step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep,\n        x: torch.FloatTensor,\n        eta: float=0.0,\n        verbose=False,\n    ):\n        \"\"\"\n        predict the sampe the next step in the denoise process.\n        \"\"\"\n        prev_timestep = timestep - self.test_scheduler.config.num_train_timesteps // self.test_scheduler.num_inference_steps\n        alpha_prod_t = self.test_scheduler.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.test_scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.test_scheduler.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output\n        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir\n        return x_prev\n    \n    def forward_gaussians(self, images, encoder_hidden_states, data, uncon_encoder_hidden_states=None):\n        # images: [B, 4, 9, H, W]\n        # return: Gaussians: [B, dim_t]\n\n        B, V, C, H, W = images.shape\n        images = images.view(B*V, C, H, W)\n        timestep = data[\"timesteps\"].flatten(0, 1)\n        pred_noise, blocks_sample, temb= self.unet2(images, timestep, encoder_hidden_states, return_dict=False)\n        if uncon_encoder_hidden_states is not None:\n            uncon_pred_noise, _, _= self.unet3(images, timestep, uncon_encoder_hidden_states, return_dict=False)\n            pred_noise = uncon_pred_noise + 3 * (pred_noise - uncon_pred_noise)\n            # print(3.5)\n        if pred_noise.shape[0] == 5:\n            pred_x0 = self.pred_x0(pred_noise[:4], timestep[:4], images[:4])\n            masa_latent = self.unet_step(pred_noise[4:,], timestep[4].item(), images[4:])\n            temb = temb[:4]\n            blocks_sample = [i[:4] for i in blocks_sample]\n        else:\n            pred_x0 = self.pred_x0(pred_noise, timestep, images)\n        images_512 = (self.vae.decode(pred_x0.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5\n        images_256 = F.interpolate(images_512.clamp(0, 1), (256, 256), mode='bilinear', align_corners=False)\n        images_256 = TF.normalize(images_256, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)\n        images_256 = torch.cat([images_256.to(self.opt.weight_dtype), data['ray'].to(self.opt.weight_dtype) ], dim=1)\n\n        x = self.unet(images_256, blocks_sample, temb) # [B*4, 14, h, w]\n        x = self.conv(x) # [B*4, 14, h, w]\n\n        x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)\n        \n        ## visualize multi-view gaussian features for plotting figure\n        # tmp_alpha = self.opacity_act(x[0, :, 3:4])\n        # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)\n        # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5\n        # kiui.vis.plot_image(tmp_img_rgb, save=True)\n        # kiui.vis.plot_image(tmp_img_pos, save=True)\n\n        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)\n        \n        pos = self.pos_act(x[..., 0:3]) # [B, N, 3]\n        opacity = self.opacity_act(x[..., 3:4])\n        scale = self.scale_act(x[..., 4:7])\n        rotation = self.rot_act(x[..., 7:11])\n        rgbs = self.rgb_act(x[..., 11:])\n\n        gaussians 
= torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]\n        \n        return gaussians, images_512, masa_latent\n    \n    def pred_x0(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta: float=0.0,\n        verbose=False,\n    ):\n        \"\"\"\n        predict the sampe the next step in the denoise process.\n        \"\"\"\n        alphas_cumprod = self.test_scheduler.alphas_cumprod.to(device=x.device)\n        alpha_prod_t = alphas_cumprod [timestep]\n\n        B = alpha_prod_t.shape[0]\n        alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1)\n        beta_prod_t = 1 - alpha_prod_t\n        \n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        return pred_x0\n    \n    def step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta: float=0.0,\n        verbose=False,\n    ):\n        \"\"\"\n        predict the sampe the next step in the denoise process.\n        \"\"\"\n        prev_timestep = timestep - self.test_scheduler.config.num_train_timesteps // self.test_scheduler.num_inference_steps\n        prev_timestep[timestep==0] = 0\n        alphas_cumprod = self.test_scheduler.alphas_cumprod.to(device=x.device)\n        alpha_prod_t = alphas_cumprod [timestep]\n        #alpha_prod_t_prev = self.test_scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.test_scheduler.final_alpha_cumprod\n        alpha_prod_t_prev = torch.where(prev_timestep >0, self.test_scheduler.alphas_cumprod[prev_timestep], self.test_scheduler.final_alpha_cumprod).to(device=x.device)\n        B = alpha_prod_t.shape[0]\n        alpha_prod_t = alpha_prod_t.view(B, 1, 1, 1)\n        alpha_prod_t_prev = alpha_prod_t_prev.view(B, 1, 1, 1)\n        beta_prod_t = 1 - alpha_prod_t\n        \n        #pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_noise = (x - alpha_prod_t**0.5 * model_output) / beta_prod_t**0.5\n        \n        pred_dir = (1 - alpha_prod_t_prev)**0.5 * pred_noise\n        #x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir\n        x_prev = alpha_prod_t_prev**0.5 * model_output + pred_dir\n        return x_prev\n    \n    def encode_prompt(\n            self,\n            prompt,\n            device,\n            prompt_embeds: Optional[torch.FloatTensor] = None,\n        ):\n            if prompt is not None and isinstance(prompt, str):\n                batch_size = 1\n            elif prompt is not None and isinstance(prompt, list):\n                batch_size = len(prompt)\n            else:\n                batch_size = prompt_embeds.shape[0]\n\n            if prompt_embeds is None:\n                text_inputs = self.tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=self.tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = self.tokenizer.batch_decode(\n                        untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n          
          )\n\n                if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                    attention_mask = text_inputs.attention_mask.to(device)\n                else:\n                    attention_mask = None\n\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n\n            if self.text_encoder is not None:\n                prompt_embeds_dtype = self.text_encoder.dtype\n            elif self.unet is not None:\n                prompt_embeds_dtype = self.unet.dtype\n            else:\n                prompt_embeds_dtype = prompt_embeds.dtype\n\n            prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            bs_embed, seq_len, _ = prompt_embeds.shape\n\n            return prompt_embeds\n    \n    def compute_snr(self, timesteps):\n        \"\"\"\n        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849\n        \"\"\"\n        alphas_cumprod = self.scheduler.alphas_cumprod\n        sqrt_alphas_cumprod = alphas_cumprod**0.5\n        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5\n\n        # Expand the tensors.\n        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026\n        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()\n        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):\n            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]\n        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)\n\n        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()\n        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):\n            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]\n        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)\n\n        # Compute SNR.\n        snr = (alpha / sigma) ** 2\n        return snr\n    \n    def forward(self, data, step_ratio=1):\n        # data: output of the dataloader\n        # return: loss\n\n        results = {}\n        loss = 0\n        \n        images = data['input'].to(self.opt.weight_dtype) # [B, 4, 9, h, W], input features\n        \n        num_views = images.shape[1]\n        #ray_embedding = images[:, :, 4:]\n        latents = images.flatten(0,1)\n        latent = latents[:,:4]\n        bsz, c, h, w = latent.shape\n       \n        # timesteps\n        timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (bsz // num_views,), device=images.device)\n        timesteps_pred = timesteps.repeat_interleave(self.opt.num_views)\n        timesteps = timesteps.repeat_interleave(num_views)\n        timesteps = timesteps.long()\n        if(random.random() < 0.7):\n            timesteps[::num_views] = 0\n            timesteps_pred[::self.opt.num_views] = 0\n            \n        if(random.random() < 0.7):\n            prompt = data[\"prompt\"]\n            \n            # prompt = [prompt[i][j] for j in range(len(prompt[0])) for i in range(len(prompt))]\n            # encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            prompt = [prompt[0][i] for i in 
range(len(prompt[0]))]\n            #print(prompt)\n            encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n            encoder_hidden_states = encoder_hidden_states.flatten(0,1)\n        else:\n            prompt = ['']*images.shape[0]\n            encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n            encoder_hidden_states = encoder_hidden_states.flatten(0,1)\n        \n        noise = torch.randn_like(latent).to(device=images.device)\n        noisy_latents = self.scheduler.add_noise(latent, noise, timesteps).to(device=images.device)\n        data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w)\n        data['timesteps'] = timesteps.reshape(bsz // num_views, num_views)\n\n        snr = self.compute_snr(timesteps_pred)\n        mse_loss_weights = torch.stack([snr, self.opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] \n        # use the first view to predict gaussians\n        images = data['noisy_latents']\n        gaussians, noise_images = self.forward_gaussians(images, encoder_hidden_states, data) # [B, N, 14]\n\n        results['gaussians'] = gaussians\n\n        # always use white bg\n        bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)\n        \n        # use the other views for rendering and supervision\n        results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)\n        pred_images = results['image'].to(self.opt.weight_dtype) # [B, V, C, output_size, output_size]\n        pred_alphas = results['alpha'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size]\n\n        results['images_pred'] = pred_images\n        results['alphas_pred'] = pred_alphas\n\n        gt_images = data['images2_output'].to(self.opt.weight_dtype) # [B, V, 3, output_size, output_size], ground-truth novel views\n        gt_masks = data['masks_output'].to(self.opt.weight_dtype) # [B, V, 1, output_size, output_size], ground-truth masks\n\n        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1).to(self.opt.weight_dtype) * (1 - gt_masks)\n\n        loss_mse_image = F.mse_loss(pred_images.flatten(0,1), gt_images.flatten(0,1), reduction=\"none\") \n        loss_mse_alpha = F.mse_loss(pred_alphas.flatten(0,1), gt_masks.flatten(0,1), reduction=\"none\")\n        loss_mse_image  = (loss_mse_image.mean(dim=list(range(1, len(loss_mse_image.shape)))) * mse_loss_weights).mean()\n        loss_mse_alpha = (loss_mse_alpha.mean(dim=list(range(1, len(loss_mse_alpha.shape)))) * mse_loss_weights).mean()\n        results['loss_mse_image'] = loss_mse_image\n        results['loss_mse_alpha'] = loss_mse_alpha\n        loss_mse = loss_mse_image + loss_mse_alpha\n        results['loss_mse'] = loss_mse\n        loss = loss + loss_mse\n        \n        results['gt_noise'] = noise_images.reshape(bsz // num_views, num_views, 3, 512, 512)\n\n        if self.opt.lambda_lpips > 0 and step_ratio > 0:\n            loss_lpips = self.lpips_loss(\n                # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,\n                # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,\n                # downsampled to at most 256 to reduce memory cost\n         
       F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), \n                F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),\n            )\n            lpips_loss_weights = torch.ones_like(mse_loss_weights)\n            if timesteps[0] == 0:\n                lpips_loss_weights[::self.opt.num_views] = 5.0\n            loss_lpips = (loss_lpips.mean(dim=list(range(1, len(loss_lpips.shape)))) * lpips_loss_weights).mean()\n            results['loss_lpips'] = loss_lpips\n            #loss = loss + self.opt.lambda_lpips * (step_ratio-0.25) * loss_lpips\n            loss = loss + self.opt.lambda_lpips * loss_lpips\n            \n        results['loss'] = loss\n\n        # metric\n        with torch.no_grad():\n            psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))\n            results['psnr'] = psnr\n\n        return results\n    \n    def next_step(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n        eta=0.,\n        verbose=False\n    ):\n        \"\"\"\n        Inverse sampling for DDIM Inversion\n        \"\"\"\n        if verbose:\n            print(\"timestep: \", timestep)\n        next_step = timestep\n        timestep = min(timestep - self.test_scheduler.config.num_train_timesteps // self.test_scheduler.num_inference_steps, 999)\n        alpha_prod_t = self.test_scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.test_scheduler.final_alpha_cumprod\n        alpha_prod_t_next = self.test_scheduler.alphas_cumprod[next_step]\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output\n        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir\n        return x_next, pred_x0\n    \n    # def invert(self, image, encoder_hidden_states):\n    #     noisy_latent = image\n    #     for i, t in enumerate(reversed(self.test_scheduler.timesteps)):\n    #         model_inputs = noisy_latent\n    #         noise_pred = self.unet2(model_inputs, t, encoder_hidden_states=encoder_hidden_states).sample\n    #         noisy_latent, pred_x0 = self.next_step(noise_pred, t, noisy_latent)\n    #         a = (self.vae.decode(pred_x0.detach()/ 0.18215).sample +1)*0.5\n    #         b = a.clamp(0,1).float().reshape(8, 4, 3, 512, 512).detach().to(torch.float).cpu().numpy()\n    #         c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3)\n    #         kiui.write_image(f'{i}_2.jpg', c1)\n    #     return noisy_latent\n    \n    @torch.no_grad()\n    def image2latent(self, image):\n        #DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n        # if type(image) is Image:\n        #     image = np.array(image)\n        #     image = torch.from_numpy(image).float() / 127.5 - 1\n        #     image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)\n        # input image density range [-1, 1]\n        latents = self.vae.encode(image)['latent_dist'].mean\n        latents = latents * 0.18215\n        return latents\n    \n    @torch.no_grad()\n    def invert(\n        self,\n        image: torch.Tensor,\n        prompt=\"\",\n        # num_inference_steps=50,\n        # guidance_scale=7.5,\n        # eta=0.0,\n        # return_intermediates=False,\n   
     **kwds):\n        \"\"\"\n        invert a real image into noise map with determinisc DDIM inversion\n        \"\"\"\n        DEVICE = image.device\n        batch_size = image.shape[0]\n        # if isinstance(prompt, list):\n        #     if batch_size == 1:\n        #         image = image.expand(len(prompt), -1, -1, -1)\n        # elif isinstance(prompt, str):\n        #     if batch_size > 1:\n        #         prompt = [prompt] * batch_size\n        prompt = [prompt] * batch_size\n        # text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            return_tensors=\"pt\"\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]\n        print(\"input text embeddings :\", text_embeddings.shape)\n        # define initial latents\n        latents = self.image2latent(image)\n        start_latents = latents\n        # print(latents)\n        # exit()\n        # unconditional embedding for classifier free guidance\n        # if guidance_scale > 1.:\n        #     max_length = text_input.input_ids.shape[-1]\n        #     unconditional_input = self.tokenizer(\n        #         [\"\"] * batch_size,\n        #         padding=\"max_length\",\n        #         max_length=77,\n        #         return_tensors=\"pt\"\n        #     )\n        #     unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]\n        #     text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)\n\n        # print(\"latents shape: \", latents.shape)\n        # interative sampling\n        #self.scheduler.set_timesteps(num_inference_steps)\n        print(\"Valid timesteps: \", reversed(self.test_scheduler.timesteps))\n        # print(\"attributes: \", self.scheduler.__dict__)\n        latents_list = [latents]\n        pred_x0_list = [latents]\n        for i, t in enumerate(reversed(self.test_scheduler.timesteps)):\n            # if guidance_scale > 1.:\n            #     model_inputs = torch.cat([latents] * 2)\n            model_inputs = latents\n\n            # predict the noise\n            noise_pred = self.unet3(model_inputs, t, encoder_hidden_states=text_embeddings).sample\n            # if guidance_scale > 1.:\n            #     noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)\n            #     noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)\n            # compute the previous noise sample x_t-1 -> x_t\n            latents, pred_x0 = self.next_step(noise_pred, t, latents)\n            # a = (self.vae.decode(pred_x0.detach()/ 0.18215).sample +1)*0.5\n            # b = a.clamp(0,1).float().reshape(8, 4, 3, 512, 512).detach().to(torch.float).cpu().numpy()\n            # c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3)\n\n            a = (self.vae.decode(pred_x0[:1].detach()/ 0.18215).sample +1)*0.5\n            b = a.clamp(0,1).float().reshape(1, 3, 512, 512).detach().to(torch.float).cpu().numpy()\n            c1 = b.transpose(0, 2, 3, 1)\n            kiui.write_image(f'{self.opt.workspace}/{i}_7.jpg', c1)\n\n            latents_list.append(latents)\n            pred_x0_list.append(pred_x0)\n\n        # if return_intermediates:\n        #     # return the intermediate laters during inversion\n        #     # pred_x0_list = [self.latent2image(img, return_type=\"pt\") for img in pred_x0_list]\n        #     return latents, latents_list\n        return 
latents\n    \n    def validate(self, data, num_inference_steps=30, single_image=True):\n        results = {}\n        self.test_scheduler.set_timesteps(num_inference_steps)\n        self.opt.weight_dtype = torch.bfloat16\n        data['input'] =  self.vae.encode(data['images2_output']*2 -1).latent_dist.mode().detach() *0.18215\n        data['input'] = data['input'].unsqueeze(0)\n        images = data['input'].to(self.opt.weight_dtype) # [B, 4, 9, h, W], input features\n        \n        self.masa_editor = MutualSelfAttention3DControl(self.steps, self.layer, total_steps=num_inference_steps)\n        self.masa_editor.reset()\n        regiter_attention_editor_diffusers(self.unet2, self.masa_editor)\n        #self.test_scheduler = self.test_scheduler.to(images.device)\n\n        num_views = images.shape[1]\n        #ray_embedding = images[:, :, 4:]\n        latents = images.flatten(0,1)\n        latent = latents[:,:4]\n        bsz, c, h, w = latent.shape\n        \n        gt_images = data['images2_output'].to(self.opt.weight_dtype)\n        \n        noise = torch.randn_like(latent).to(device=images.device)\n        data['noisy_latents'] = noise.reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype)\n        \n        prompt = ['']*images.shape[0]\n        uncon_encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n        uncon_encoder_hidden_states = uncon_encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n        uncon_encoder_hidden_states = uncon_encoder_hidden_states.flatten(0,1)\n        \n        prompt = data[\"prompt\"]*(images.shape[0])\n            \n        encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n        encoder_hidden_states = encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n        encoder_hidden_states = encoder_hidden_states.flatten(0,1)\n        encoder_hidden_states[4:] = uncon_encoder_hidden_states[4:]\n\n        img_latent =  self.vae.encode(gt_images*2 -1).latent_dist.mode().detach() *0.18215\n        img = (self.vae.decode(img_latent.to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5\n        #img = gt_images\n        data['noisy_latents'] = self.invert(img*2-1).reshape(bsz // num_views, num_views, c, h, w)\n\n        # timesteps\n        # timesteps = torch.ones((bsz // num_views,), device=images.device)* 481\n        # timesteps_pred = timesteps.repeat_interleave(self.opt.num_views)\n        # timesteps = timesteps.repeat_interleave(num_views)\n        # timesteps = timesteps.long()\n        # # timesteps[::num_views] = 0\n        # # timesteps_pred[::self.opt.num_views] = 0\n        # # add noise\n        # noise = torch.randn_like(latent).to(device=images.device)\n        # noisy_latents = self.test_scheduler.add_noise(latent, noise, timesteps).to(device=images.device)\n        # data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w)\n        # data['timesteps'] = timesteps.reshape(bsz // num_views, num_views)\n        \n        if single_image is True:\n            data['noisy_latents'][:, :1] = images[:, :1]\n        for i, t in enumerate(self.test_scheduler.timesteps):\n            \n            print(i, t)\n            # timesteps = torch.ones((bsz // num_views,), device=images.device)* t\n            # timesteps_pred = timesteps.repeat_interleave(self.opt.num_views)\n            # timesteps = timesteps.repeat_interleave(num_views)\n            # timesteps = timesteps.long()\n            # # timesteps[::num_views] = 
0\n            # # timesteps_pred[::self.opt.num_views] = 0\n            # # add noise\n            # noise = torch.randn_like(latent).to(device=images.device)\n            # noisy_latents = self.test_scheduler.add_noise(latent, noise, timesteps).to(device=images.device)\n            # data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w)\n            \n            timesteps = t.repeat(bsz // num_views)\n            #timesteps_pred = timesteps.repeat_interleave(self.opt.num_views)\n            timesteps = timesteps.repeat_interleave(num_views)\n            timesteps = timesteps.long()\n            #if(random.random() < 0.9):\n            if single_image is True:\n                timesteps[::num_views] = 0\n                #timesteps_pred[::self.opt.num_views] = 0\n\n            # add noise\n            # noise = torch.randn_like(latent).to(device=images.device)\n            # noisy_latents = self.scheduler.add_noise(latent, noise, timesteps).to(device=images.device)\n            # data['noisy_latents'] = noisy_latents.reshape(bsz // num_views, num_views, c, h, w)\n            data['timesteps'] = timesteps.reshape(bsz // num_views, num_views).to(device=images.device)\n            timesteps_cpu = timesteps.reshape(bsz // num_views, num_views)\n            ### FIXME\n            #timesteps_pred = torch.cat([data[\"timesteps\"], 300 * torch.ones(self.opt.batch_size, self.opt.num_views-data['timesteps'].shape[1]).long().to(timesteps.device)],dim=1).flatten(0,1)\n            #snr = self.compute_snr(timesteps_pred)\n            #mse_loss_weights = torch.stack([snr, opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] / snr\n            #mse_loss_weights = torch.stack([snr, self.opt.snr_gamma * torch.ones_like(timesteps_pred)], dim=1).min(dim=1)[0] \n            # use the first view to predict gaussians\n            # prompt = ['']*images.shape[0]\n            # uncon_encoder_hidden_states = self.encode_prompt(prompt, images.device).to(images.dtype)\n            # uncon_encoder_hidden_states = uncon_encoder_hidden_states[:,None].repeat(1,images.shape[1], 1, 1)\n            # uncon_encoder_hidden_states = uncon_encoder_hidden_states.flatten(0,1)\n            # uncon_encoder_hidden_states = None\n             \n            \n            images = data['noisy_latents']\n            # img = (self.vae.decode(images.flatten(0,1).to(self.opt.weight_dtype) / 0.18215).sample +1)*0.5\n            # b = img.unsqueeze(0).clamp(0,1).detach().to(torch.float).cpu().numpy() \n            # c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3)\n            # kiui.write_image(f'{self.opt.workspace}/{i}_1_noise.jpg', c1)\n            gaussians, noise_images, masa_latent = self.forward_gaussians(images, encoder_hidden_states, data, uncon_encoder_hidden_states) # [B, N, 14]\n\n            results['gaussians'] = gaussians\n\n            bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)\n            # use the other views for rendering and supervision\n            results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color, scale_modifier=1)\n            # pred_images = results['image'] # [B, V, C, output_size, output_size]\n            pred_alphas = results['alpha'].to(self.opt.weight_dtype)\n            pred_images = results['image'].to(self.opt.weight_dtype)\n            #pred_images = pred_images + self.white_latent.to(pred_images.device)*(1-pred_alphas)\n            \n            
\n            #data['noisy_latents'] = self.step((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, timesteps_cpu.flatten(0, 1), data['noisy_latents'].flatten(0, 1)).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype)\n            data['noisy_latents'] = torch.cat([self.step((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, timesteps_cpu[:,:4].flatten(0, 1), data['noisy_latents'][:,:4].flatten(0, 1)), masa_latent]).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype)\n            # if t > 0:\n            #     #data['noisy_latents'] = self.step((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, timesteps_cpu.flatten(0, 1), noise).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype)\n            #     data['noisy_latents'] = self.scheduler.add_noise((self.vae.encode(pred_images[:,:4].flatten(0, 1)*2 -1).latent_dist.mode().detach())*0.18215, noise, timesteps-1).reshape(bsz // num_views, num_views, c, h, w).to(self.opt.weight_dtype)\n\n            if single_image is True:\n                data['noisy_latents'][:, :1] = images[:, :1]\n            \n            #a = (self.vae.decode(pred_images.detach().to(dtype=torch.bfloat16).flatten(0,1)/ 0.18215).sample +1)*0.5\n            b = pred_images.detach().to(torch.float).cpu().numpy() \n            c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3)\n            kiui.write_image(f'{self.opt.workspace}/{i}_2.jpg', c1)\n            \n            #a = (self.vae.decode(data['noisy_latents'].detach().flatten(0,1)/ 0.18215).sample +1)*0.5\n            b = noise_images.clamp(0,1).float().reshape(1, 4, 3, 512, 512).detach().to(torch.float).cpu().numpy()\n            c1 = b.transpose(0, 3, 1, 4, 2).reshape(-1, b.shape[1] * b.shape[3], 3)\n            kiui.write_image(f'{self.opt.workspace}/{i}_2_noise.jpg', c1)\n            \n\n        return results, gaussians\n    "
  },
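In the model file above, `validate` re-encodes the rendered Gaussian views with the VAE and feeds them to `step` as the predicted clean latent, so the DDIM update is written in the x0-parameterization rather than the usual epsilon form. A minimal standalone sketch of that update (function and argument names here are illustrative, not the repository's API; `t_prev` is assumed already clamped to be non-negative, as `step` does):

```python
import torch

def ddim_step_from_x0(x_t, pred_x0, t, t_prev, alphas_cumprod, final_alpha=1.0):
    """One deterministic DDIM (eta=0) update given a prediction of the clean latent x0.

    x_t:       noisy latent at timestep t, shape [B, C, H, W]
    pred_x0:   predicted clean latent (e.g. re-encoded renderings), same shape
    t, t_prev: per-sample integer timesteps, shape [B], with t_prev >= 0
    """
    a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
    a_prev = torch.where(
        t_prev > 0,
        alphas_cumprod[t_prev],
        torch.full_like(alphas_cumprod[t_prev], final_alpha),
    ).view(-1, 1, 1, 1)

    # recover the noise implied by (x_t, pred_x0): x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps
    pred_noise = (x_t - a_t.sqrt() * pred_x0) / (1 - a_t).sqrt()

    # deterministic transition to the previous timestep
    return a_prev.sqrt() * pred_x0 + (1 - a_prev).sqrt() * pred_noise
```

With `eta=0` this is the standard DDIM transition; the repository's `step` additionally forces `prev_timestep` to 0 for samples whose current timestep is already 0 (the clean reference view).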
  {
    "path": "core/options_latents_diffusion.py",
    "content": "import tyro\nfrom dataclasses import dataclass\nfrom typing import Tuple, Literal, Dict, Optional\n\n\n@dataclass\nclass Options:\n    ### model\n    # Unet image input size\n    gradient_checkpointing: bool = False\n    enable_xformers_memory_efficient_attention: bool = False\n    pretrained_model_name_or_path: str = \"/remote-home1/yeyang/aigc/model/stable-diffusion-v1-5\"\n    input_size: int = 256\n    input_ray_size: int = 256\n    # Unet definition\n    down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)\n    down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)\n    mid_attention: bool = True\n    up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)\n    up_attention: Tuple[bool, ...] = (True, True, True, False)\n    # Unet output size, dependent on the input_size and U-Net structure!\n    splat_size: int = 64\n    # gaussian render size\n    output_size: int = 256\n\n    ### dataset\n    # data mode (only support s3 now)\n    data_mode: Literal['s3'] = 's3'\n    data_path: str = '/remote-home1/yeyang/aigc/dataset2'\n    json_path:str = '/remote-home1/yeyang/aigc/dataset1'\n    # fovy of the dataset\n    fovy: float = 39.6\n    # camera near plane\n    znear: float = 0.01\n    # camera far plane\n    zfar: float = 1000\n    # number of all views (input + output)\n    num_views: int = 12\n    # number of views\n    num_input_views: int = 4\n    # camera radius\n    cam_radius: float = 1.5 # to better use [-1, 1]^3 space\n    # num workers\n    num_workers: int = 16\n    snr_gamma: int = 5\n    ### training\n    # workspace\n    workspace: str = './workspace'\n    workspace1: Optional[str] = None\n    # resume\n    resume: Optional[str] = None\n    # batch size (per-GPU)\n    batch_size: int = 8\n    # gradient accumulation\n    gradient_accumulation_steps: int = 1\n    # training epochs\n    num_epochs: int = 30\n    # lpips loss weight\n    lambda_lpips: float = 1.0 ##TZY\n    # gradient clip\n    gradient_clip: float = 1.0\n    # mixed precision\n    mixed_precision: str = 'bf16'\n    # learning rate\n    lr: float = 5e-5\n    # augmentation prob for grid distortion\n    prob_grid_distortion: float = 0.5\n    # augmentation prob for camera jitter\n    prob_cam_jitter: float = 0.5\n\n    ### testing\n    # test image path\n    test_path: Optional[str] = None\n\n    ### misc\n    # nvdiffrast backend setting\n    force_cuda_rast: bool = False\n    # render fancy video with gaussian scaling effect\n    fancy_video: bool = False\n    checkpoints_total_limit: int = 3\n\n# all the default settings\nconfig_defaults: Dict[str, Options] = {}\nconfig_doc: Dict[str, str] = {}\n\n# config_doc['lrm'] = 'the default settings for LGM'\n# config_defaults['lrm'] = Options()\n\nconfig_doc['small'] = 'small model with lower resolution Gaussians'\nconfig_defaults['small'] = Options(\n    input_size=256,\n    splat_size=64,\n    output_size=256,\n    batch_size=8,\n    gradient_accumulation_steps=1,\n    mixed_precision='bf16',\n)\n\nconfig_doc['big'] = 'big model with higher resolution Gaussians'\nconfig_defaults['big'] = Options(\n    input_size=64,\n    up_channels=(1024, 1024, 512, 256, 128), # one more decoder\n    up_attention=(True, True, True, False, False),\n    splat_size=32,\n    output_size=64, # render & supervise Gaussians at a higher resolution.\n    batch_size=96,\n    num_views=10,\n    gradient_accumulation_steps=1,\n    mixed_precision='bf16',\n)\n\nconfig_doc['big_latent'] = 'big model with higher resolution 
Gaussians'\nconfig_defaults['big_latent'] = Options(\n    input_size=64,\n    down_channels=(256, 512, 1024, 1024),\n    down_attention=(True, True, True, False),\n    up_channels=(1024, 1024, 512, 256),\n    up_attention=(False, True, True, True),\n    splat_size=64,\n    output_size=64, # render & supervise Gaussians at a higher resolution.\n    batch_size = 2, # 2\n    num_views= 8,\n    gradient_accumulation_steps= 6, # 16\n    mixed_precision='bf16',\n)\n\nconfig_doc['big_latent_sd'] = 'big model with higher resolution Gaussians'\nconfig_defaults['big_latent_sd'] = Options(\n    gradient_checkpointing = True,\n    enable_xformers_memory_efficient_attention = True,\n    lr = 1e-4, \n    #lambda_lpips = 0.5,\n    lambda_lpips = 2,\n    input_size=64,\n    down_channels=(256, 512, 1024, 1024),\n    down_attention=(True, True, True, False),\n    up_channels=(1024, 1024, 512, 256),\n    up_attention=(False, True, True, True),\n    splat_size=64,\n    output_size=64, # render & supervise Gaussians at a higher resolution.\n    batch_size = 2, # 2\n    num_views= 8,\n    gradient_accumulation_steps= 6, # 16\n    mixed_precision='bf16',\n)\nconfig_defaults['big_latent_sd_diffusion'] = Options(\n    gradient_checkpointing = True,\n    enable_xformers_memory_efficient_attention = True,\n    lr = 1e-4, \n    lambda_lpips = 0.5,\n    #lambda_lpips = 2,\n    input_size=64,\n    down_channels=(256, 512, 1024, 1024),\n    down_attention=(True, True, True, False),\n    up_channels=(1024, 1024, 512, 256),\n    up_attention=(False, True, True, True),\n    splat_size=64,\n    output_size=64, # render & supervise Gaussians at a higher resolution.\n    batch_size = 2, # 2\n    num_views= 8,\n    gradient_accumulation_steps= 2, # 16\n    mixed_precision='bf16',\n    num_epochs = 50,\n)\n\nconfig_defaults['big_latent_sd_diffusion_insert'] = Options(\n    gradient_checkpointing = True,\n    enable_xformers_memory_efficient_attention = True,\n    lr = 1e-4, \n    lambda_lpips = 0.5,\n    #lambda_lpips = 2,\n    input_size=64,\n    #resume= \"/remote-home1/yeyang/aigc/models/models--ashawkey--LGM/snapshots/1c28a2fd3bb1982414f722503ae862bdbb82636c/model_fp16_fixrot.safetensors\",\n    resume= 'workspace_1e-4_latent_diffusion_unet_LGM_insert3/model.safetensors',\n    up_channels=(1024, 1024, 512, 256, 128),\n    up_attention=(True, True, True, False, False),\n    splat_size=128,\n    output_size= 512, # render & supervise Gaussians at a higher resolution.\n    batch_size = 8, # 2\n    num_views= 8,\n    gradient_accumulation_steps= 1, # 16\n    mixed_precision='bf16',\n)\n\nconfig_defaults['big_latent_sd_diffusion_compose'] = Options(\n    gradient_checkpointing = True,\n    enable_xformers_memory_efficient_attention = True,\n    lr = 1e-4, \n    lambda_lpips = 0.5,\n    #lambda_lpips = 2,\n    input_size=64,\n    resume= \"/remote-home1/yeyang/aigc/models/models--ashawkey--LGM/snapshots/1c28a2fd3bb1982414f722503ae862bdbb82636c/model_fp16_fixrot.safetensors\",\n    #resume= 'workspace_1e-4_latent_diffusion_unet_LGM_compose_text/model.safetensors',\n    up_channels=(1024, 1024, 512, 256, 128),\n    up_attention=(True, True, True, False, False),\n    splat_size=128,\n    output_size= 512, # render & supervise Gaussians at a higher resolution.\n    batch_size = 8, # 2\n    num_views= 8,\n    gradient_accumulation_steps= 1, # 16\n    mixed_precision='bf16',\n)\n\nconfig_doc['big_latent_lpips'] = 'big model with higher resolution Gaussians'\nconfig_defaults['big_latent_lpips'] = Options(\n    input_size=64,\n    
down_channels=(256, 512, 1024, 1024),\n    down_attention=(True, True, True, False),\n    up_channels=(1024, 1024, 512, 256),\n    up_attention=(False, True, True, True),\n    splat_size=64,\n    output_size=64, # render & supervise Gaussians at a higher resolution.\n    batch_size=6, # 2\n    num_views=10,\n    gradient_accumulation_steps=16, # 16\n    mixed_precision='bf16',\n)\n\n# config_doc['big_latent'] = 'big model with higher resolution Gaussians'\n# config_defaults['big_latent'] = Options(\n#     input_size=64,\n#     down_channels=(256, 512, 1024, 1024),\n#     down_attention=(True, True, True, False),\n#     up_channels=(1024, 1024, 512, 256),\n#     up_attention=(False, True, True, True),\n#     splat_size=64,\n#     output_size=64, # render & supervise Gaussians at a higher resolution.\n#     batch_size=15, # 2\n#     num_views=10,\n#     gradient_accumulation_steps=4, # 16\n#     mixed_precision='bf16',\n# )\n\nconfig_doc['tiny'] = 'tiny model for ablation'\nconfig_defaults['tiny'] = Options(\n    input_size=256, \n    down_channels=(32, 64, 128, 256, 512),\n    down_attention=(False, False, False, False, True),\n    up_channels=(512, 256, 128),\n    up_attention=(True, False, False, False),\n    splat_size=64,\n    output_size=256,\n    batch_size=16,\n    num_views=8,\n    gradient_accumulation_steps=1,\n    mixed_precision='bf16',\n)\n\nAllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)\n"
  },
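`core/options_latents_diffusion.py` exposes its presets as tyro subcommands, so each `config_defaults` entry becomes a CLI subcommand whose fields can be overridden by flags. A condensed sketch of the same pattern with a hypothetical, much smaller `Options` (only to show the mechanism, not the repository's full option set):

```python
import tyro
from dataclasses import dataclass
from typing import Dict

@dataclass
class Options:
    input_size: int = 256
    splat_size: int = 64
    batch_size: int = 8

config_defaults: Dict[str, Options] = {
    'small': Options(input_size=256, splat_size=64, batch_size=8),
    'big':   Options(input_size=64, splat_size=32, batch_size=96),
}
config_doc: Dict[str, str] = {
    'small': 'small model with lower resolution Gaussians',
    'big':   'big model with higher resolution Gaussians',
}

AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)

if __name__ == '__main__':
    # e.g. `python options_sketch.py big --batch-size 16`
    opt = tyro.cli(AllConfigs)
    print(opt)
```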
  {
    "path": "core/provider_Gobjaverse_latent_diffusion_insert.py",
    "content": "import os\nimport cv2\nimport random\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms.functional as TF\nfrom torch.utils.data import Dataset\nimport json\nimport kiui\nfrom core.options_latents_diffusion import Options\nfrom core.utils import get_rays, grid_distortion, orbit_camera_jitter\nimport tyro\nfrom core.options import AllConfigs\n# import debugpy; debugpy.connect((\"localhost\", 5677)) \nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n\n\nclass GobjaverseDataset(Dataset):\n\n    def _warn(self):\n        raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')\n\n    def __init__(self, opt: Options, training=True):\n        \n        self.opt = opt\n        self.training = training\n\n        # TODO: remove this barrier\n        # self._warn()\n\n        # TODO: load the list of objects for training\n        self.items = []\n        with open('/remote-home1/yeyang/aigc/gobj_merged.json', 'r') as f:\n            self.items = json.load(f)\n\n        with open('/remote-home1/yeyang/aigc/text_captions_cap3d.json', 'r') as cap:\n            self.captions = json.load(cap)\n\n        # naive split\n        if self.training:\n            self.items = self.items[:-self.opt.batch_size]\n        else:\n            self.items = self.items[-self.opt.batch_size:]\n            #self.items = self.items[:self.opt.batch_size]\n        # default camera intrinsics\n        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))\n        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)\n        self.proj_matrix[0, 0] = 1 / self.tan_half_fov\n        self.proj_matrix[1, 1] = 1 / self.tan_half_fov\n        self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)\n        self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)\n        self.proj_matrix[2, 3] = 1\n\n\n    def __len__(self):\n        return len(self.items)\n        #return 250\n    \n    def __getitem__(self, idx):\n\n        uid = self.items[idx]\n        results = {}\n        results[\"prompt\"] = [self.captions[uid]] *self.opt.num_input_views\n        # load num_views images\n        images = []\n        images2 = []\n        masks = []\n        cam_poses = []\n        \n        vid_cnt = 0\n\n        # TODO: choose views, based on your rendering settings\n        if self.training:\n            # input views are in (36, 72), other views are randomly selected\n            #input = np.random.permutation(np.arange(27, 39))[:self.opt.num_input_views].tolist()\n            input_1 = np.random.permutation(np.arange(27, 30))[:1].tolist()\n            input_2 = np.random.permutation(np.arange(30, 33))[:1].tolist()\n            input_3 = np.random.permutation(np.arange(33, 36))[:1].tolist()\n            input_4 = np.random.permutation(np.arange(36, 39))[:1].tolist()\n            render = np.random.permutation(np.append(np.arange(1, 25), np.arange(27, 39))).tolist()\n            #vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist()’\n            vids = input_1 + input_2 + input_3 + input_4 + render\n        else:\n            # fixed views\n            vids = np.arange(27, 39, 4).tolist() + np.arange(1, 39).tolist()\n            #vids = [27, 30, 33, 36] + 
np.random.permutation(np.append(np.arange(1, 25), np.arange(27, 39))).tolist()\n            #vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()\n        \n        for vid in vids:\n            #if not os.path.exists(os.path.join(self.opt.data_path, uid, f'{vid:05d}', f'{vid:05d}.pt')):\n            #uid = \"1/15039\"\n            image_path = os.path.join(self.opt.data_path, uid, f'{vid:05d}', f'{vid:05d}.pt')\n            #mask_path = os.path.join(self.opt.data_path, uid, f'{vid:05d}', f'{vid:05d}_mask.pt')\n            camera_path = os.path.join(self.opt.json_path, uid, f'{vid:05d}', f'{vid:05d}.json')\n            image2_path = os.path.join(self.opt.json_path, uid, f'{vid:05d}', f'{vid:05d}.png')\n          \n            try:\n                # TODO: load data (modify self.client here)\n                image2 = torch.from_numpy(cv2.imread(image2_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255)\n                image = torch.load(image_path)\n                #mask = torch.load(mask_path)\n                with open(camera_path, 'r', encoding='utf8') as f:\n                    meta = json.load(f)\n            except Exception as e:\n                print(f'[WARN] dataset {uid} {vid}: {e}')\n                continue\n            \n            # TODO: you may have a different camera system\n            # blender world + opencv cam --> opengl world & cam\n            c2w = np.eye(4)\n            c2w[:3, 0] = np.array(meta['x'])\n            c2w[:3, 1] = np.array(meta['y'])\n            c2w[:3, 2] = np.array(meta['z'])\n            c2w[:3, 3] = np.array(meta['origin'])\n            c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)\n\n            c2w[1] *= -1\n            c2w[[1, 2]] = c2w[[2, 1]]\n            c2w[:3, 1:3] *= -1 # invert up and forward direction\n\n            # scale up radius to fully use the [-1, 1]^3 space!\n            #c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale\n          \n            image2 = image2.permute(2, 0, 1) # [4, 512, 512]\n            mask2 = image2[3:4] # [1, 512, 512]\n            image2 = image2[:3] * mask2 + (1 - mask2) # [3, 512, 512], to white bg\n            image2 = image2[[2,1,0]].contiguous() # bgr to rgb\n\n            images.append(image.squeeze(0).float()* 0.18215)\n            images2.append(image2)\n            masks.append(mask2.squeeze(0))\n            #masks.append(mask.squeeze(0).squeeze(0).to(image.dtype))\n            cam_poses.append(c2w)\n\n            vid_cnt += 1\n            if vid_cnt == self.opt.num_views:\n                break\n\n        if vid_cnt < self.opt.num_views:\n            print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')\n            n = self.opt.num_views - vid_cnt\n            images = images + [images[-1]] * n\n            images2 = images2 + [images2[-1]] * n\n            masks = masks + [masks[-1]] * n\n            cam_poses = cam_poses + [cam_poses[-1]] * n\n          \n        images = torch.stack(images, dim=0) # [V, C, H, W]\n        images2 = torch.stack(images2, dim=0) # [V, C, H, W]\n        masks = torch.stack(masks, dim=0) # [V, H, W]\n\n        # images = torch.randn(self.opt.num_views, 4, 64, 64).to(images.device)\n        # masks = torch.randn(self.opt.num_views, 64, 64).to(masks.device)\n\n        cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]\n        \n        radius = torch.norm(cam_poses[0, :3, 3])\n        cam_poses[:, :3, 3] *= self.opt.cam_radius / radius\n        # normalized camera feats as in paper 
(transform the first pose to a fixed position)\n        transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])\n        cam_poses = transform.unsqueeze(0) @ cam_poses  # [V, 4, 4]\n\n        images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]\n        cam_poses_input = cam_poses[:self.opt.num_input_views].clone()\n\n        # data augmentation\n        # if self.training:\n        #     # apply random grid distortion to simulate 3D inconsistency\n        #     if random.random() < self.opt.prob_grid_distortion:\n        #         images_input[1:] = grid_distortion(images_input[1:])\n        #     # apply camera jittering (only to input!)\n        #     if random.random() < self.opt.prob_cam_jitter:\n        #         cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:])\n\n        # images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)\n\n        # resize render ground-truth images, range still in [0, 1]\n        results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]\n        results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(512, 512), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]\n        results['images2_output'] = F.interpolate(images2, size=(512, 512), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]\n        \n        # build rays for input views\n        rays_embeddings = []\n        for i in range(self.opt.num_input_views):\n            rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_ray_size, self.opt.input_ray_size, self.opt.fovy) # [h, w, 3]\n            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]\n            rays_embeddings.append(rays_plucker)\n\n     \n        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]\n        #final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]\n        #results['input'] = final_input\n        results['input'] = images_input\n        results['ray'] = rays_embeddings\n        # opengl to colmap camera for gaussian renderer\n        cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction\n        \n        # cameras needed by gaussian rasterizer\n        cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]\n        cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]\n        cam_pos = - cam_poses[:, :3, 3] # [V, 3]\n        \n        results['cam_view'] = cam_view\n        results['cam_view_proj'] = cam_view_proj\n        results['cam_pos'] = cam_pos\n\n        return results\n    \nif __name__==\"__main__\":\n    opt = tyro.cli(AllConfigs)\n    GobjaverseDataset(opt, training=True)"
  },
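The Gobjaverse dataset above rescales every sample and re-orients it so that the first camera sits at a canonical position before Plücker ray embeddings are built. A small sketch of just that normalization, assuming OpenGL-style 4x4 camera-to-world matrices (the function name is illustrative):

```python
import torch

def normalize_to_first_view(cam_poses, cam_radius=1.5):
    """Rescale and re-orient [V, 4, 4] camera-to-world poses so that view 0
    sits at (0, 0, cam_radius) with identity orientation, as in the loader above."""
    poses = cam_poses.clone()

    # rescale so the first camera lies on a sphere of radius cam_radius
    radius = torch.norm(poses[0, :3, 3])
    poses[:, :3, 3] *= cam_radius / radius

    # rigid transform that sends pose 0 to the canonical pose, applied to all views
    canonical = torch.tensor([[1, 0, 0, 0],
                              [0, 1, 0, 0],
                              [0, 0, 1, cam_radius],
                              [0, 0, 0, 1]], dtype=torch.float32)
    transform = canonical @ torch.inverse(poses[0])
    return transform.unsqueeze(0) @ poses  # [V, 4, 4]
```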
  {
    "path": "core/unet_LGM_compos.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport numpy as np\nfrom typing import Tuple, Literal\nfrom functools import partial\n\nfrom core.attention import MemEffAttention, MemEffCrossAttention\n\nclass MVAttention(nn.Module):\n    def __init__(\n        self, \n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n        groups: int = 32,\n        eps: float = 1e-5,\n        residual: bool = True,\n        skip_scale: float = 1,\n        num_frames: int = 4, # WARN: hardcoded!\n    ):\n        super().__init__()\n\n        self.residual = residual\n        self.skip_scale = skip_scale\n        self.num_frames = num_frames\n\n        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)\n        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)\n\n    def forward(self, x):\n        # x: [B*V, C, H, W]\n        BV, C, H, W = x.shape\n        B = BV // self.num_frames # assert BV % self.num_frames == 0\n\n        res = x\n        x = self.norm(x)\n\n        x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)\n        x = self.attn(x)\n        x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)\n\n        if self.residual:\n            x = (x + res) * self.skip_scale\n        return x\n\nclass UnetAttention(nn.Module):\n    def __init__(\n        self, \n        dim: int,\n        dim_kv: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n        groups: int = 32,\n        eps: float = 1e-5,\n        residual: bool = True,\n        #skip_scale: float = 1,\n        num_frames: int = 4, # WARN: hardcoded!\n    ):\n        super().__init__()\n\n        self.residual = residual\n        self.skip_scale = 1\n        self.num_frames = num_frames\n\n        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)\n        self.attn = MemEffCrossAttention(dim, dim, dim_kv, dim_kv, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)\n        \n        self.post_init()\n\n    def post_init(self):\n        nn.init.zeros_(self.attn.proj.weight.data)\n        nn.init.zeros_(self.attn.proj.bias.data)\n\n    def forward(self, x, unet_x):\n        # x: [B*V, C, H, W]\n        BV, C, H, W = x.shape\n        #B = BV // self.num_frames # assert BV % self.num_frames == 0\n\n        res = x\n        x = self.norm(x)\n\n        x = x.permute(0, 2, 3, 1).reshape(BV, -1, C)\n        unet_x = unet_x.permute(0, 2, 3, 1).reshape(BV, H*W, -1)\n        x = self.attn(x, unet_x, unet_x)\n        x = x.reshape(BV, H, W, C).permute(0, 3, 1, 2)\n\n        if self.residual:\n            x = (x + res) * self.skip_scale\n        return x\n    \nclass ResnetBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        resample: Literal['default', 'up', 'down'] = 'default',\n        groups: int = 32,\n        eps: float = 1e-5,\n        skip_scale: float = 1, # multiplied to output\n        temb_channels: int = 1280,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.skip_scale = skip_scale\n\n        self.norm1 = nn.GroupNorm(num_groups=groups, 
num_channels=in_channels, eps=eps, affine=True)\n        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)\n        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        self.act = F.silu\n\n        self.resample = None\n        if resample == 'up':\n            self.resample = partial(F.interpolate, scale_factor=2.0, mode=\"nearest\")\n        elif resample == 'down':\n            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)\n        \n        self.shortcut = nn.Identity()\n        if self.in_channels != self.out_channels:\n            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)\n        \n        self.time_emb_proj = nn.Linear(temb_channels, out_channels)\n        self.nolinearity = F.silu\n\n    def post_init(self):\n        nn.init.zeros_(self.time_emb_proj.weight.data)\n        nn.init.zeros_(self.time_emb_proj.bias.data)\n    \n    def forward(self, x, temb=None):\n        res = x\n\n        x = self.norm1(x)\n        x = self.act(x)\n\n        if self.resample:\n            res = self.resample(res)\n            x = self.resample(x)\n        \n        x = self.conv1(x)\n        if temb is not None:\n            temb = self.nolinearity(temb)\n            temb = self.time_emb_proj(temb)[:, :, None, None]\n            x = x + temb\n        x = self.norm2(x)\n        x = self.act(x)\n        x = self.conv2(x)\n\n        x = (x + self.shortcut(res)) * self.skip_scale\n\n        return x\n\nclass DownBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        unet_out_channels: int,\n        unet_out_next_channels: int,\n        num_layers: int = 1,\n        downsample: bool = True,\n        attention: bool = True,\n        unet_attention: bool = False,\n        attention_heads: int = 16,\n        skip_scale: float = 1,\n    ):\n        super().__init__()\n \n        nets = []\n        attns = []\n        unet_attns = []\n        self.unet_attention = unet_attention\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))\n            if attention:\n                attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale))\n            else:\n                attns.append(None)\n            if unet_attention:\n                unet_attns.append(UnetAttention(out_channels, unet_out_channels))\n            else:\n                unet_attns.append(None)\n\n        if unet_attention and downsample:\n            self.down_unet_attns = UnetAttention(out_channels, unet_out_next_channels)\n\n        self.nets = nn.ModuleList(nets)\n        self.attns = nn.ModuleList(attns)\n        self.unet_attns = nn.ModuleList(unet_attns)\n\n        self.downsample = None\n        if downsample:\n            self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)\n\n    def forward(self, x, unet_xs=None, temb=None):\n        xs = []\n\n        for attn, unet_attn, net in zip(self.attns, self.unet_attns, self.nets):\n            x = net(x, temb)\n            if attn:\n                x = attn(x)\n            if unet_attn:\n                unet_x = unet_xs[0]\n                unet_xs = unet_xs[1:]\n                x = unet_attn(x, unet_x)\n      
      xs.append(x)\n\n        if self.downsample:\n            x = self.downsample(x)\n            if unet_attn:\n                unet_x = unet_xs[0]\n                unet_xs = unet_xs[1:]\n                x = self.down_unet_attns(x, unet_x)\n            xs.append(x)\n  \n        return x, xs\n\n\nclass MidBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        num_layers: int = 1,\n        attention: bool = True,\n        attention_heads: int = 16,\n        skip_scale: float = 1,\n    ):\n        super().__init__()\n\n        nets = []\n        attns = []\n        # first layer\n        nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))\n        # more layers\n        for i in range(num_layers):\n            nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))\n            if attention:\n                attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale))\n            else:\n                attns.append(None)\n        self.nets = nn.ModuleList(nets)\n        self.attns = nn.ModuleList(attns)\n        \n    def forward(self, x, temb=None):\n        x = self.nets[0](x, temb)\n        for attn, net in zip(self.attns, self.nets[1:]):\n            if attn:\n                x = attn(x)\n            x = net(x, temb)\n        return x\n\n\nclass UpBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_out_channels: int,\n        out_channels: int,\n        num_layers: int = 1,\n        upsample: bool = True,\n        attention: bool = True,\n        attention_heads: int = 16,\n        skip_scale: float = 1,\n    ):\n        super().__init__()\n\n        nets = []\n        attns = []\n        for i in range(num_layers):\n            cin = in_channels if i == 0 else out_channels\n            cskip = prev_out_channels if (i == num_layers - 1) else out_channels\n\n            nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))\n            if attention:\n                attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale))\n            else:\n                attns.append(None)\n        self.nets = nn.ModuleList(nets)\n        self.attns = nn.ModuleList(attns)\n\n        self.upsample = None\n        if upsample:\n            self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n    def forward(self, x, xs, temb=None):\n\n        for attn, net in zip(self.attns, self.nets):\n            res_x = xs[-1]\n            xs = xs[:-1]\n            x = torch.cat([x, res_x], dim=1)\n            x = net(x, temb)\n            if attn:\n                x = attn(x)\n            \n        if self.upsample:\n            x = F.interpolate(x, scale_factor=2.0, mode='nearest')\n            x = self.upsample(x)\n\n        return x\n\n\n# it could be asymmetric!\nclass UNet(nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),\n        down_unet_channels: Tuple[int, ...] = (320, 320, 320, 640, 1280, 1280),\n        down_attention: Tuple[bool, ...] = (False, False, False, True, True),\n        down_unet_attention : Tuple[bool, ...] = (False, False, True, True, True, True),\n        mid_attention: bool = True,\n        #mid_unet_attention: bool = True,\n        up_channels: Tuple[int, ...] = (1024, 512, 256),\n        #up_unet_channels: Tuple[int, ...] 
= (1280, 1280, 640, 320, 320),\n        up_attention: Tuple[bool, ...] = (True, True, False),\n        #up_unet_attention: Tuple[bool, ...] = (True, True, True, True, False),\n        #up_last_unet_attention: Tuple[bool, ...] = (False, False, False, True, False),\n        layers_per_block: int = 2,\n        skip_scale: float = np.sqrt(0.5),\n    ):\n        super().__init__()\n\n        # first\n        self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)\n\n        # down\n        down_blocks = []\n        cout = down_channels[0]\n        for i in range(len(down_channels)):\n            cin = cout\n            cout = down_channels[i]\n            unet_cout = down_unet_channels[i]\n            unet_next_cout = down_unet_channels[i+1] if i != len(down_channels) - 1 else down_unet_channels[i]\n            down_blocks.append(DownBlock(\n                cin, cout, unet_cout, unet_next_cout,\n                num_layers=layers_per_block, \n                downsample=(i != len(down_channels) - 1), # not final layer\n                attention=down_attention[i],\n                unet_attention = down_unet_attention[i],\n                skip_scale=skip_scale,\n            ))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n        # mid\n        self.mid_block = MidBlock(down_channels[-1],  attention=mid_attention, skip_scale=skip_scale)\n\n        # up\n        up_blocks = []\n        cout = up_channels[0]\n        for i in range(len(up_channels)):\n            cin = cout\n            cout = up_channels[i]\n            cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric\n            #unet_cout = up_unet_channels[i]\n\n            up_blocks.append(UpBlock(\n                cin, cskip, cout, \n                num_layers=layers_per_block + 1, # one more layer for up\n                upsample=(i != len(up_channels) - 1), # not final layer\n                attention=up_attention[i],\n                #unet_attention = up_unet_attention[i],\n                #last_unet_attention = up_last_unet_attention[i],\n                skip_scale=skip_scale,\n            ))\n        self.up_blocks = nn.ModuleList(up_blocks)\n\n        # last\n        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)\n        self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)\n\n\n    def forward(self, x, unet_xss=None, temb=None):\n        # x: [B, Cin, H, W]\n\n        # first\n        x = self.conv_in(x)\n        \n        # down\n        xss = [x]\n        unet_xss = unet_xss[::-1]\n        for block in self.down_blocks:\n            if block.unet_attention == True:\n                length = len(block.unet_attns) + 1 if block.downsample else len(block.unet_attns)\n                unet_xs = unet_xss[:length]\n                unet_xss = unet_xss[length:]\n                x, xs= block(x, unet_xs, temb)\n            else:\n                x, xs = block(x, temb)\n            xss.extend(xs)\n        \n        # mid\n        # if self.mid_block.unet_attention == True:\n        #     unet_xs = unet_xss[0]\n        #     unet_xss = unet_xss[1:]\n        #     x = self.mid_block(x, unet_xs, temb)\n        #else:\n        x = self.mid_block(x, temb)\n\n        # up\n        for block in self.up_blocks:\n            xs = xss[-len(block.nets):]\n            xss = xss[:-len(block.nets)]\n            # if block.unet_attention == True:\n            #     length = len(block.unet_attns) + 1 if 
block.upsample else len(block.unet_attns)\n            #     unet_xs = unet_xss[:length]\n            #     unet_xss = unet_xss[length:]\n            #     x = block(x, xs, unet_xs, temb)\n            #else:\n            x = block(x, xs, temb)\n\n        # last\n        x = self.norm_out(x)\n        x = F.silu(x)\n        x = self.conv_out(x) # [B, Cout, H', W']\n\n        return x\n"
  },
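The `MVAttention` block in `core/unet_LGM_compos.py` makes the U-Net multi-view aware by folding the view axis into the token axis, so attention runs jointly over all V*H*W positions. A minimal sketch of the reshape round-trip, using plain `scaled_dot_product_attention` with q=k=v instead of the learned memory-efficient attention used in the repository (so this is only an illustration of the data layout):

```python
import torch
import torch.nn.functional as F

def multiview_self_attention(x, num_frames=4, num_heads=8):
    """x: [B*V, C, H, W] -> joint self-attention over all V*H*W tokens (C divisible by num_heads)."""
    BV, C, H, W = x.shape
    B = BV // num_frames

    # fold views into the sequence: [B, V*H*W, C]
    tokens = x.reshape(B, num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)

    # q = k = v self-attention without learned projections (sketch only)
    d = C // num_heads
    qkv = tokens.reshape(B, -1, num_heads, d).transpose(1, 2)   # [B, heads, N, d]
    out = F.scaled_dot_product_attention(qkv, qkv, qkv)         # [B, heads, N, d]
    tokens = out.transpose(1, 2).reshape(B, -1, C)

    # unfold back to [B*V, C, H, W]
    return tokens.reshape(B, num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)
```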
  {
    "path": "core/utils.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport roma\nfrom kiui.op import safe_normalize\n\ndef get_rays(pose, h, w, fovy, opengl=True):\n\n    x, y = torch.meshgrid(\n        torch.arange(w, device=pose.device),\n        torch.arange(h, device=pose.device),\n        indexing=\"xy\",\n    )\n    x = x.flatten()\n    y = y.flatten()\n\n    cx = w * 0.5\n    cy = h * 0.5\n\n    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))\n\n    camera_dirs = F.pad(\n        torch.stack(\n            [\n                (x - cx + 0.5) / focal,\n                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),\n            ],\n            dim=-1,\n        ),\n        (0, 1),\n        value=(-1.0 if opengl else 1.0),\n    )  # [hw, 3]\n\n    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]\n    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]\n\n    rays_o = rays_o.view(h, w, 3)\n    rays_d = safe_normalize(rays_d).view(h, w, 3)\n\n    return rays_o, rays_d\n\ndef orbit_camera_jitter(poses, strength=0.1):\n    # poses: [B, 4, 4], assume orbit camera in opengl format\n    # random orbital rotate\n\n    B = poses.shape[0]\n    rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1)\n    rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1)\n\n    rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)\n    R = rot @ poses[:, :3, :3]\n    T = rot @ poses[:, :3, 3:]\n\n    new_poses = poses.clone()\n    new_poses[:, :3, :3] = R\n    new_poses[:, :3, 3:] = T\n    \n    return new_poses\n\ndef grid_distortion(images, strength=0.5):\n    # images: [B, C, H, W]\n    # num_steps: int, grid resolution for distortion\n    # strength: float in [0, 1], strength of distortion\n\n    B, C, H, W = images.shape\n\n    num_steps = np.random.randint(8, 17)\n    grid_steps = torch.linspace(-1, 1, num_steps)\n\n    # have to loop batch...\n    grids = []\n    for b in range(B):\n        # construct displacement\n        x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive\n        x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb\n        x_steps = (x_steps * W).long() # [num_steps]\n        x_steps[0] = 0\n        x_steps[-1] = W\n        xs = []\n        for i in range(num_steps - 1):\n            xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i]))\n        xs = torch.cat(xs, dim=0) # [W]\n\n        y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive\n        y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb\n        y_steps = (y_steps * H).long() # [num_steps]\n        y_steps[0] = 0\n        y_steps[-1] = H\n        ys = []\n        for i in range(num_steps - 1):\n            ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i]))\n        ys = torch.cat(ys, dim=0) # [H]\n\n        # construct grid\n        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]\n        grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2]\n\n        grids.append(grid)\n    \n    grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2]\n\n    # grid sample\n    images = F.grid_sample(images, grids, align_corners=False)\n\n    return images\n\n"
  },
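`get_rays` in `core/utils.py` is what the data loader and inference script use to turn camera poses into per-pixel Plücker ray embeddings. A small usage sketch, assuming a hypothetical camera with identity orientation placed 1.5 units along +z (OpenGL convention) and the dataset's default fovy:

```python
import torch
from core.utils import get_rays

# hypothetical camera-to-world pose: identity rotation, positioned at (0, 0, 1.5)
pose = torch.eye(4)
pose[2, 3] = 1.5

rays_o, rays_d = get_rays(pose, 256, 256, fovy=39.6)  # each [h, w, 3]

# Plücker embedding: (o x d, d), as built in the provider before feeding the model
plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
plucker = plucker.permute(2, 0, 1)                                          # [6, h, w]
```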
  {
    "path": "infer_ours_masa.py",
    "content": "\nimport os\nimport tyro\nimport glob\nimport imageio\nimport numpy as np\nimport tqdm\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms.functional as TF\nfrom safetensors.torch import load_file\nimport rembg\n\nimport kiui\nfrom kiui.op import recenter\nfrom kiui.cam import orbit_camera\n\nfrom core.options_latents_diffusion import AllConfigs, Options\nfrom core.models_LGM_compos_diffusion_validate_inversion_2_masa import LGM\nimport cv2\nfrom mvdream.pipeline_mvdream import MVDreamPipeline\n#import debugpy; debugpy.connect((\"localhost\", 5999)) \nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n\nopt = tyro.cli(AllConfigs)\nopt.weight_dtype = torch.bfloat16\n# model\nmodel = LGM(opt)\n\n# resume pretrained checkpoint\nif opt.resume is not None:\n    if opt.resume.endswith('safetensors'):\n        ckpt = load_file(opt.resume, device='cpu')\n    else:\n        ckpt = torch.load(opt.resume, map_location='cpu')\n    model.load_state_dict(ckpt, strict=False)\n    print(f'[INFO] Loaded checkpoint from {opt.resume}')\nelse:\n    print(f'[WARN] model randomly initialized, are you sure?')\n\n# device\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = model.half().to(device)\nmodel.eval()\n\n\ntan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))\nproj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)\nproj_matrix[0, 0] = 1 / tan_half_fov\nproj_matrix[1, 1] = 1 / tan_half_fov\nproj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)\nproj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)\nproj_matrix[2, 3] = 1\n\nrays_embeddings, input_cam_view,  input_cam_view_proj, input_cam_pos,= model.prepare_default_rays(device, proj_matrix=proj_matrix)\n\n# load image dream\npipe = MVDreamPipeline.from_pretrained(\n    \"/remote-home1/yeyang/aigc/models/models--ashawkey--imagedream-ipmv-diffusers/snapshots/73a034178e748421506492e91790cc62d6aefef5\", # remote weights\n    torch_dtype=torch.float16,\n    trust_remote_code=True,\n    # local_files_only=True,\n)\npipe = pipe.to(device)\n\n# load rembg\nbg_remover = rembg.new_session()\n\n# process function\ndef process(opt: Options, path):\n    name = os.path.splitext(os.path.basename(path))[0]\n    print(f'[INFO] Processing {path} --> {name}')\n    os.makedirs(opt.workspace, exist_ok=True)\n\n    input_image = kiui.read_image(path, mode='uint8')\n\n    # bg removal\n    carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]\n    mask = carved_image[..., -1] > 0\n    # carved_image = input_image \n    # mask = carved_image[..., -1] > 0\n    # recenter\n    image = recenter(carved_image, mask, border_ratio=0.2)\n    \n    # generate mv\n    image = image.astype(np.float32) / 255.0\n\n    # rgba to rgb white bg\n    if image.shape[-1] == 4:\n        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])\n\n    mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0)\n    # grid = np.concatenate(\n    #     [\n    #         np.concatenate([mv_image[0], mv_image[2]], axis=0),\n    #         np.concatenate([mv_image[1], mv_image[3]], axis=0),\n    #     ],\n    #     axis=1,\n    # )\n    #kiui.write_image(os.path.join(opt.workspace, 'sparrow1.jpg'), image)\n    # image_2 = kiui.read_image('workspace_test_LGM_ours_shoe_masa_cfg3/mv_image.jpg', mode='uint8')\n    # top_left = image_2[:256, :256, :]/255\n    # top_right = image_2[:256, 256:, 
:]/255\n    # bottom_left = image_2[256:, :256, :]/255\n    # bottom_right = image_2[256:, 256:, :]/255\n    # mv_image = np.stack([top_left, top_right, bottom_left, bottom_right], axis=0)\n    grid = np.concatenate(\n        [\n            np.concatenate([mv_image[0], mv_image[2]], axis=0),\n            np.concatenate([mv_image[1], mv_image[3]], axis=0),\n        ],\n        axis=1,\n    )\n    kiui.write_image(os.path.join(opt.workspace, 'mv_image.jpg'), grid)\n    #kiui.write_image(os.path.join('data_test2', 'helmet.png'), mv_image[1])\n    #image_2 = cv2.resize(image, (256, 256))\n    image_2 = cv2.resize(image, (512, 512))\n    #kiui.write_image(os.path.join(opt.workspace, 'dragon.png'), image_2)\n    image_2 = torch.from_numpy(image_2).unsqueeze(0).permute(0, 3, 1, 2).float().to(device)\n    #mv_image = np.stack([image_2, mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32\n    mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3]], axis=0)\n\n    ref_image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).float().to(device)\n    ref_image = F.interpolate(ref_image, size=(512, 512), mode='bilinear', align_corners=False)\n    # generate gaussians\n    input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]\n    input_image = F.interpolate(input_image, size=(512, 512), mode='bilinear', align_corners=False)\n    #input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)\n    input_image = torch.cat([image_2, input_image, ref_image])\n    #input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]\n    \n    with torch.no_grad():\n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n            # pass\n            # generate gaussians\n            data={}\n            data['images2_output'] = input_image\n            data['cam_view'] = input_cam_view.unsqueeze(0)\n            data['cam_view_proj'] = input_cam_view_proj.unsqueeze(0)\n            data['cam_pos'] = input_cam_pos.unsqueeze(0)\n            data['ray'] = rays_embeddings\n            data['prompt'] = \"a photo of a shoe\"\n            #gaussians = model.forward_gaussians(input_image)\n            results, gaussians = model.validate(data)\n        # save gaussians\n        model.gs.save_ply(gaussians, os.path.join(opt.workspace, name + '.ply'))\n        #gaussians = model.gs.load_ply(\"workspace_test_LGM_ours_5/anya_rgba.ply\").unsqueeze(0)\n        # render 360 video \n        images = []\n        elevation = 0\n\n        if opt.fancy_video:\n\n            azimuth = np.arange(0, 720, 4, dtype=np.int32)\n            for azi in tqdm.tqdm(azimuth):\n                \n                cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)\n\n                cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction\n                \n                # cameras needed by gaussian rasterizer\n                cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]\n                cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]\n                cam_pos = - cam_poses[:, :3, 3] # [V, 3]\n\n                scale = min(azi / 360, 1)\n\n                image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']\n                images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 
255).astype(np.uint8))\n        else:\n            azimuth = np.arange(0, 360, 2, dtype=np.int32)\n            for azi in tqdm.tqdm(azimuth):\n                \n                cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)\n\n                cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction\n                \n                # cameras needed by gaussian rasterizer\n                cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]\n                cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]\n                cam_pos = - cam_poses[:, :3, 3] # [V, 3]\n\n                image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']\n                images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))\n                imageio.imsave(os.path.join(opt.workspace, f'{azi}' + '.png'), images[-1][0])\n                \n        images = np.concatenate(images, axis=0)\n        imageio.mimwrite(os.path.join(opt.workspace, name + '.mp4'), images, fps=30)\n \n\nassert opt.test_path is not None\nif os.path.isdir(opt.test_path):\n    file_paths = glob.glob(os.path.join(opt.test_path, \"*\"))\nelse:\n    file_paths = [opt.test_path]\nfor path in file_paths:\n    process(opt, path)\n"
  },
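The inference script above assembles the perspective projection matrix and the per-view camera tensors inline before rendering the turntable video. Below is a minimal, self-contained sketch of that camera preparation; the field of view, near/far planes and orbit radius are placeholder values standing in for `opt.fovy`, `opt.znear`, `opt.zfar` and `opt.cam_radius`.

```python
import numpy as np
import torch
from kiui.cam import orbit_camera

def build_proj_matrix(fovy_deg: float, znear: float, zfar: float) -> torch.Tensor:
    # Same layout as in the script above (row-vector convention: note proj[2, 3] = 1).
    tan_half_fov = np.tan(0.5 * np.deg2rad(fovy_deg))
    proj = torch.zeros(4, 4, dtype=torch.float32)
    proj[0, 0] = 1 / tan_half_fov
    proj[1, 1] = 1 / tan_half_fov
    proj[2, 2] = (zfar + znear) / (zfar - znear)
    proj[3, 2] = -(zfar * znear) / (zfar - znear)
    proj[2, 3] = 1
    return proj

# One frame of the 360-degree orbit at elevation 0, azimuth 90 (placeholder intrinsics).
proj_matrix = build_proj_matrix(fovy_deg=49.1, znear=0.5, zfar=2.5)
cam_pose = torch.from_numpy(orbit_camera(0, 90, radius=1.5, opengl=True)).unsqueeze(0)
cam_pose[:, :3, 1:3] *= -1                          # invert up & forward, as in the script
cam_view = torch.inverse(cam_pose).transpose(1, 2)  # [V, 4, 4]
cam_view_proj = cam_view @ proj_matrix              # [V, 4, 4]
cam_pos = -cam_pose[:, :3, 3]                       # [V, 3]
```

These tensors correspond to the arguments passed to `model.gs.render` for each frame of the orbit in the script.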
  {
    "path": "main_resume_compose.py",
    "content": "import tyro\nimport time\nimport random\n\nimport torch\nfrom core.options_latents_diffusion import AllConfigs\nfrom core.models_LGM_compos_diffusion import LGM\nfrom accelerate import Accelerator, DistributedDataParallelKwargs\nfrom safetensors.torch import load_file\nfrom torch.utils.tensorboard import SummaryWriter\nimport kiui\nfrom diffusers.utils.import_utils import is_xformers_available\nimport os\nimport shutil\n\n\ndef main():    \n    opt = tyro.cli(AllConfigs)\n    \n    writer = SummaryWriter(f'{opt.workspace}/runs')\n    # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    print(opt.pretrained_model_name_or_path)\n    accelerator = Accelerator(\n        mixed_precision=opt.mixed_precision,\n        gradient_accumulation_steps=opt.gradient_accumulation_steps,\n        # kwargs_handlers=[ddp_kwargs],\n    )\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n        opt.mixed_precision = accelerator.mixed_precision\n        opt.weight_dtype = weight_dtype\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        opt.weight_dtype = weight_dtype\n        opt.mixed_precision = accelerator.mixed_precision\n\n    # model\n    model = LGM(opt)\n    # vae = model.vae\n    # text_encoder = model.text_encoder\n    # text_encoder.requires_grad_(False)\n    # vae.requires_grad_(False)\n    \n    unet = model.unet\n    conv = model.conv\n    unet.requires_grad_(True)\n    conv.requires_grad_(True)\n    unet2 = model.unet2\n\n    if opt.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n            unet2.enable_xformers_memory_efficient_attention()\n    \n            \n    if opt.gradient_checkpointing:\n        unet2.enable_gradient_checkpointing()\n    \n    # resume\n    if opt.resume is not None:\n        if opt.resume.endswith('safetensors'):\n            ckpt = load_file(opt.resume, device='cpu')\n        else:\n            ckpt = torch.load(opt.resume, map_location='cpu')\n        \n        # tolerant load (only load matching shapes)\n        # model.load_state_dict(ckpt, strict=False)\n        state_dict = model.state_dict()\n        for k, v in ckpt.items():\n            if k in state_dict: \n                if state_dict[k].shape == v.shape:\n                    state_dict[k].copy_(v)\n                else:\n                    accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')\n            else:\n                accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')\n    \n    # data\n    if opt.data_mode == 's3':\n        from core.provider_Gobjaverse_latent_diffusion_insert import GobjaverseDataset as Dataset\n    else:\n        raise NotImplementedError\n\n    train_dataset = Dataset(opt, training=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=opt.batch_size,\n        shuffle=True,\n        num_workers=opt.num_workers,\n        pin_memory=True,\n        drop_last=True,\n    )\n\n    test_dataset = Dataset(opt, training=False)\n    test_dataloader = torch.utils.data.DataLoader(\n        test_dataset,\n        batch_size=opt.batch_size,\n        shuffle=False,\n        num_workers=0,\n        pin_memory=True,\n        drop_last=False,\n    )\n    \n    # if opt.gradient_checkpointing:\n    #     model.enable_gradient_checkpointing()\n    \n    
#params = []\n    # for name, param in unet.named_parameters():\n    #     #if name.startswith(tuple(('up_blocks', 'mid_block', 'conv_out'))):\n    #     params.append(param) \n    # for name, param in conv.named_parameters():\n    #     params.append(param) \n    params = []\n    for name, param in model.named_parameters():\n        if name.startswith('unet.'):\n            #print(name)\n            params.append(param) \n        elif not name.startswith(tuple(('unet2', 'vae', 'tokenizer', 'text_encoder', 'scheduler', 'lpips'))):\n            #print(name)\n            params.append(param)\n\n    # optimizer\n    optimizer = torch.optim.AdamW(params, lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))\n    # optimizer\n    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))\n\n    # scheduler (per-iteration)\n    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=1e-6)\n    total_steps = opt.num_epochs * len(train_dataloader)\n    pct_start = 3000 / total_steps\n    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)\n\n    # accelerate\n    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, test_dataloader, scheduler\n    )\n\n    # loop\n    for epoch in range(opt.num_epochs):\n        # train\n        model.train()\n        total_loss = 0\n        total_psnr = 0\n\n        for i, data in enumerate(train_dataloader):\n            with accelerator.accumulate(model):\n\n                optimizer.zero_grad()\n\n                step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs\n\n                out = model(data, step_ratio)\n                loss = out['loss']\n                psnr = out['psnr']\n                accelerator.backward(loss)\n                \n                writer.add_scalar('loss', loss.item(), epoch*len(train_dataloader)+i)\n                #writer.add_scalar('loss_mse', out['loss_mse'].item(), epoch*len(train_dataloader)+i)\n                writer.add_scalar('loss_mse_image', out['loss_mse_image'].item(), epoch*len(train_dataloader)+i)\n                writer.add_scalar('loss_mse_alpha', out['loss_mse_alpha'].item(), epoch*len(train_dataloader)+i)\n                if step_ratio> 0:\n                    writer.add_scalar('loss_lpips', out['loss_lpips'].item(), epoch*len(train_dataloader)+i)\n                writer.add_scalar('psnr', psnr.item(), epoch*len(train_dataloader)+i)\n                writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch*len(train_dataloader)+i)\n                # gradient clipping\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)\n\n                optimizer.step()\n                scheduler.step()\n\n                total_loss += loss.detach()\n                total_psnr += psnr.detach()\n\n            if accelerator.is_main_process:\n                # logging\n                if i % 100 == 0:\n                    mem_free, mem_total = torch.cuda.mem_get_info()    \n                    print(f\"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f} loss_mse: {out['loss_mse_image']:.6f}\")\n\n                # save log images\n                if i % 200 == 0:\n                    
## FIXME\n                    ## 3 ------>4 \n                    with torch.no_grad():\n                        # gt_images = (vae.decode(data['images_output'][0, :8].detach().to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5\n                        # gt_images = gt_images.clamp(0,1).float().unsqueeze(0).detach().cpu().numpy() \n                        # #gt_images = data['images_output'][:1].detach().cpu().numpy() # [B, V, 3, output_size, output_size]\n                        # gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]\n                        # kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)\n\n                        gt_alphas = data['masks_output'].clamp(0,1).float().detach().cpu().numpy() # [B, V, 1, output_size, output_size]\n                        gt_alphas = gt_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, gt_alphas.shape[1] * gt_alphas.shape[3], 1)\n                        kiui.write_image(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', gt_alphas)\n                        \n                        # gt_images_ori = (vae.decode((data['images_output'].detach()*data['masks_output']+out['white_latent'].detach()*(1-data['masks_output']))[0, :8].to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5\n                        # gt_images_ori = gt_images_ori.clamp(0,1).float().unsqueeze(0).detach().cpu().numpy() \n                        # gt_images_ori = gt_images_ori.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images_ori.shape[1] * gt_images_ori.shape[3], 3) # [B*output_size, V*output_size, 3]\n                        # kiui.write_image(f'{opt.workspace}/train_gt_images_ori_{epoch}_{i}.jpg', gt_images_ori)\n                        \n                        gt_noise_images = out[\"gt_noise\"].clamp(0,1).float().detach().cpu().numpy()\n                        #gt_noise_images = gt_noise_images.transpose(0, 2, 3, 1).reshape(-1, gt_noise_images.shape[2], 3)\n                        gt_noise_images = gt_noise_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_noise_images.shape[1] * gt_noise_images.shape[3], 3)\n                        kiui.write_image(f'{opt.workspace}/train_gt_noise_images_{epoch}_{i}.jpg', gt_noise_images)\n\n                        gt_images = data['images2_output'].clamp(0,1).float().detach().cpu().numpy() # [B, V, 3, output_size, output_size]\n                        gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]\n                    \n                        # data['images_output'] = (vae.decode(data['images_output'][0, :4].to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5\n                        # gt_images = data['images_output'].clamp(0,1).float().unsqueeze(0).detach().cpu().numpy() \n                        #gt_images = data['images_output'][:1].detach().cpu().numpy() # [B, V, 3, output_size, output_size]\n                        # gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]\n                        kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)\n                        \n                        # out['images_pred'] = (vae.decode(out['images_pred'][0, :8].detach().to(dtype=torch.bfloat16)/ 0.18215).sample +1)*0.5\n                        pred_images = out['images_pred'].clamp(0,1).float().detach().cpu().numpy() \n                 
       #pred_images = out['images_pred'].reshape(data['images_output'].shape[0],data['images_output'].shape[1], *out['images_pred'].shape[1:]).detach().cpu().numpy() \n\n                        #pred_images = out['images_pred'][:1].detach().cpu().numpy() # [B, V, 3, output_size, output_size]\n                        pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)\n                        kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images)\n  \n                        # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]\n                        # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1)\n                        # kiui.write_image(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', pred_alphas)\n\n        total_loss = accelerator.gather_for_metrics(total_loss).mean()\n        total_psnr = accelerator.gather_for_metrics(total_psnr).mean()\n        if accelerator.is_main_process:\n            total_loss /= len(train_dataloader)\n            total_psnr /= len(train_dataloader)\n            accelerator.print(f\"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}\")\n            \n        # checkpoint\n        if epoch % 1 == 0 or epoch == opt.num_epochs - 1:\n            accelerator.wait_for_everyone()\n            accelerator.save_model(model, opt.workspace)\n        \n        \n        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n        # if opt.checkpoints_total_limit is not None:\n\n        #     checkpoints = os.listdir(opt.workspace)\n        #     checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n        #     checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n        #     # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n        #     if len(checkpoints) >= opt.checkpoints_total_limit:\n        #         num_to_remove = len(checkpoints) - opt.checkpoints_total_limit + 1\n        #         removing_checkpoints = checkpoints[0:num_to_remove]\n\n        #         print(\n        #             f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n        #         )\n        #         print(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n        #         for removing_checkpoint in removing_checkpoints:\n        #             removing_checkpoint = os.path.join(opt.workspace, removing_checkpoint)\n        #             shutil.rmtree(removing_checkpoint)\n\n        # save_path = os.path.join(opt.workspace, f\"checkpoint-{epoch}\")\n        # accelerator.save_state(save_path)\n        #print(f\"Saved state to {save_path}\")\n        \n        # eval\n        with torch.no_grad():\n            model.eval()\n            total_psnr = 0\n            for i, data in enumerate(test_dataloader):\n\n                out = model(data)\n    \n                psnr = out['psnr']\n                total_psnr += psnr.detach()\n                \n                # save some images\n                if accelerator.is_main_process:\n                    gt_images = data['images2_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]\n                    gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) 
# [B*output_size, V*output_size, 3]\n                    kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images)\n\n                    pred_images = out['images_pred'].clamp(0,1).float().detach().cpu().numpy() # [B, V, 3, output_size, output_size]\n                    pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)\n                    kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images)\n\n                    # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]\n                    # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1)\n                    # kiui.write_image(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', pred_alphas)\n\n            torch.cuda.empty_cache()\n\n            total_psnr = accelerator.gather_for_metrics(total_psnr).mean()\n            if accelerator.is_main_process:\n                total_psnr /= len(test_dataloader)\n                accelerator.print(f\"[eval] epoch: {epoch} psnr: {total_psnr.item():.4f}\")\n    \n    writer.close()\n\nif __name__ == \"__main__\":\n    main()\n"
  },
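main_resume_compose.py first collects the trainable parameters by name (everything except the frozen `unet2`, `vae`, `tokenizer`, `text_encoder`, `scheduler` and `lpips` submodules), then builds an AdamW optimizer with a per-iteration OneCycleLR schedule whose warm-up fraction corresponds to roughly 3000 steps; the second AdamW constructed over `model.parameters()` immediately afterwards overrides the filtered one. A minimal sketch of the filtered variant, assuming a generic `model`, `train_dataloader` and an `opt` namespace with `.lr` and `.num_epochs`:

```python
import torch

def make_optimizer_and_scheduler(model, train_dataloader, opt, warmup_steps=3000):
    # Skip frozen submodules by name prefix, mirroring the filter in the script.
    frozen = ("unet2", "vae", "tokenizer", "text_encoder", "scheduler", "lpips")
    params = [p for n, p in model.named_parameters() if not n.startswith(frozen)]

    optimizer = torch.optim.AdamW(params, lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))

    # Per-iteration schedule: total steps = epochs * batches, warm-up of ~3000 steps.
    total_steps = opt.num_epochs * len(train_dataloader)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=warmup_steps / total_steps
    )
    return optimizer, scheduler
```

Both objects are then passed through `accelerator.prepare(...)` together with the model and the dataloaders, as in the training loop above.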
  {
    "path": "mvdream/mv_unet.py",
    "content": "import math\nimport numpy as np\nfrom inspect import isfunction\nfrom typing import Optional, Any, List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nfrom diffusers.configuration_utils import ConfigMixin\nfrom diffusers.models.modeling_utils import ModelMixin\n\n# require xformers!\nimport xformers\nimport xformers.ops\n\nfrom kiui.cam import orbit_camera\n\ndef get_camera(\n    num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,\n):\n    angle_gap = azimuth_span / num_frames\n    cameras = []\n    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):\n        \n        pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4]\n\n        # opengl to blender\n        if blender_coord:\n            pose[2] *= -1\n            pose[[1, 2]] = pose[[2, 1]]\n\n        cameras.append(pose.flatten())\n\n    if extra_view:\n        cameras.append(np.zeros_like(cameras[0]))\n\n    return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period)\n            * torch.arange(start=0, end=half, dtype=torch.float32)\n            / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None] * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat(\n                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1\n            )\n    else:\n        embedding = repeat(timesteps, \"b -> b d\", d=dim)\n    # import pdb; pdb.set_trace()\n    return embedding\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef default(val, d):\n    if val is not None:\n        return val\n    return d() if isfunction(d) else d\n\n\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = (\n            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())\n            if not glu\n            else GEGLU(dim, inner_dim)\n        )\n\n        self.net = nn.Sequential(\n            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\nclass MemoryEfficientCrossAttention(nn.Module):\n    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223\n    def __init__(\n            self, \n            query_dim, \n            context_dim=None, \n            heads=8, \n            dim_head=64, \n            dropout=0.0,\n            ip_dim=0,\n            ip_weight=1,\n        ):\n        super().__init__()\n        \n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.heads = heads\n        self.dim_head = dim_head\n\n        self.ip_dim = ip_dim\n        self.ip_weight = ip_weight\n\n        if self.ip_dim > 0:\n            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)\n            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)\n        )\n        self.attention_op: Optional[Any] = None\n\n    def forward(self, x, context=None):\n        q = self.to_q(x)\n        context = default(context, x)\n\n        if self.ip_dim > 0:\n            # context： [B, 77 + 16(ip), 1024]\n            token_len = context.shape[1]\n            context_ip = context[:, -self.ip_dim :, :]\n            k_ip = self.to_k_ip(context_ip)\n            v_ip = self.to_v_ip(context_ip)\n            context = context[:, : (token_len - self.ip_dim), :]\n\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        b, _, _ = q.shape\n        q, k, v = map(\n            lambda t: t.unsqueeze(3)\n            .reshape(b, t.shape[1], self.heads, self.dim_head)\n            .permute(0, 2, 1, 3)\n            .reshape(b * self.heads, t.shape[1], self.dim_head)\n            .contiguous(),\n            (q, k, v),\n        )\n\n        # actually compute the attention, what we cannot get enough of\n        out = xformers.ops.memory_efficient_attention(\n            q, k, v, attn_bias=None, op=self.attention_op\n        )\n\n        if self.ip_dim > 0:\n            k_ip, v_ip = map(\n                lambda t: t.unsqueeze(3)\n                .reshape(b, t.shape[1], self.heads, self.dim_head)\n                .permute(0, 2, 1, 3)\n                .reshape(b * self.heads, t.shape[1], self.dim_head)\n                .contiguous(),\n                (k_ip, v_ip),\n            )\n            # actually compute the attention, what we cannot get enough of\n            out_ip = xformers.ops.memory_efficient_attention(\n                q, k_ip, v_ip, attn_bias=None, op=self.attention_op\n            )\n            out = out + self.ip_weight * out_ip\n\n        out = (\n            out.unsqueeze(0)\n            .reshape(b, self.heads, out.shape[1], self.dim_head)\n            .permute(0, 2, 1, 3)\n  
          .reshape(b, out.shape[1], self.heads * self.dim_head)\n        )\n        return self.to_out(out)\n\n\nclass BasicTransformerBlock3D(nn.Module):\n    \n    def __init__(\n        self,\n        dim,\n        n_heads,\n        d_head,\n        context_dim,\n        dropout=0.0,\n        gated_ff=True,\n        ip_dim=0,\n        ip_weight=1,\n    ):\n        super().__init__()\n\n        self.attn1 = MemoryEfficientCrossAttention(\n            query_dim=dim,\n            context_dim=None, # self-attention\n            heads=n_heads,\n            dim_head=d_head,\n            dropout=dropout,\n        )\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = MemoryEfficientCrossAttention(\n            query_dim=dim,\n            context_dim=context_dim,\n            heads=n_heads,\n            dim_head=d_head,\n            dropout=dropout,\n            # ip only applies to cross-attention\n            ip_dim=ip_dim,\n            ip_weight=ip_weight,\n        ) \n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n\n    def forward(self, x, context=None, num_frames=1):\n        x = rearrange(x, \"(b f) l c -> b (f l) c\", f=num_frames).contiguous()\n        x = self.attn1(self.norm1(x), context=None) + x\n        x = rearrange(x, \"b (f l) c -> (b f) l c\", f=num_frames).contiguous()\n        x = self.attn2(self.norm2(x), context=context) + x\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer3D(nn.Module):\n\n    def __init__(\n        self,\n        in_channels,\n        n_heads,\n        d_head,\n        context_dim, # cross attention input dim\n        depth=1,\n        dropout=0.0,\n        ip_dim=0,\n        ip_weight=1,\n    ):\n        super().__init__()\n\n        if not isinstance(context_dim, list):\n            context_dim = [context_dim]\n\n        self.in_channels = in_channels\n\n        inner_dim = n_heads * d_head\n        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)\n        self.proj_in = nn.Linear(in_channels, inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock3D(\n                    inner_dim,\n                    n_heads,\n                    d_head,\n                    context_dim=context_dim[d],\n                    dropout=dropout,\n                    ip_dim=ip_dim,\n                    ip_weight=ip_weight,\n                )\n                for d in range(depth)\n            ]\n        )\n        \n        self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))\n        \n\n    def forward(self, x, context=None, num_frames=1):\n        # note: if no context is given, cross-attention defaults to self-attention\n        if not isinstance(context, list):\n            context = [context]\n        b, c, h, w = x.shape\n        x_in = x\n        x = self.norm(x)\n        x = rearrange(x, \"b c h w -> b (h w) c\").contiguous()\n        x = self.proj_in(x)\n        for i, block in enumerate(self.transformer_blocks):\n            x = block(x, context=context[i], num_frames=num_frames)\n        x = self.proj_out(x)\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w).contiguous()\n        \n        return x + x_in\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.dim_head = 
dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latent (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q, k, v = map(\n            lambda t: t.reshape(b, t.shape[1], self.heads, -1)\n            .transpose(1, 2)\n            .reshape(b, self.heads, t.shape[1], -1)\n            .contiguous(),\n            (q, k, v),\n        )\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass Resampler(nn.Module):\n    def __init__(\n        self,\n        dim=1024,\n        depth=8,\n        dim_head=64,\n        heads=16,\n        num_queries=8,\n        embedding_dim=768,\n        output_dim=1024,\n        ff_mult=4,\n    ):\n        super().__init__()\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)\n        self.proj_in = nn.Linear(embedding_dim, dim)\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                        nn.Sequential(\n                            nn.LayerNorm(dim),\n                            nn.Linear(dim, dim * ff_mult, bias=False),\n                            nn.GELU(),\n                            nn.Linear(dim * ff_mult, dim, bias=False),\n                        )\n                    ]\n                )\n            )\n\n    def forward(self, x):\n        latents = self.latents.repeat(x.size(0), 1, 1)\n        x = self.proj_in(x)\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n\nclass CondSequential(nn.Sequential):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb, context=None, num_frames=1):\n        for layer in self:\n            if isinstance(layer, ResBlock):\n                x = layer(x, emb)\n            elif isinstance(layer, SpatialTransformer3D):\n                x = layer(x, context, num_frames=num_frames)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    
:param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        if use_conv:\n            self.conv = conv_nd(\n                dims, self.channels, self.out_channels, 3, padding=padding\n            )\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(\n                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode=\"nearest\"\n            )\n        else:\n            x = F.interpolate(x, scale_factor=2, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(\n                dims,\n                self.channels,\n                self.out_channels,\n                3,\n                stride=stride,\n                padding=padding,\n            )\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(nn.Module):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use this block for downsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        emb_channels,\n        dropout,\n        out_channels=None,\n        use_conv=False,\n        use_scale_shift_norm=False,\n        dims=2,\n        up=False,\n        down=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_scale_shift_norm = use_scale_shift_norm\n\n        self.in_layers = 
nn.Sequential(\n            nn.GroupNorm(32, channels),\n            nn.SiLU(),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims)\n            self.x_upd = Upsample(channels, False, dims)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims)\n            self.x_upd = Downsample(channels, False, dims)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.emb_layers = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(\n                emb_channels,\n                2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            nn.GroupNorm(32, self.out_channels),\n            nn.SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(\n                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)\n            ),\n        )\n\n        if self.out_channels == channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 3, padding=1\n            )\n        else:\n            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)\n\n    def forward(self, x, emb):\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = torch.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass MultiViewUNetModel(ModelMixin, ConfigMixin):\n    \"\"\"\n    The full multi-view UNet model with attention, timestep embedding and camera embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. 
May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_heads_channels: if specified, ignore num_heads and instead use\n                               a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling. Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    :param use_new_attention_order: use a different attention pattern for potentially\n                                    increased efficiency.\n    :param camera_dim: dimensionality of camera input.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        num_classes=None,\n        num_heads=-1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        transformer_depth=1,\n        context_dim=None,\n        n_embed=None,\n        num_attention_blocks=None,\n        adm_in_channels=None,\n        camera_dim=None,\n        ip_dim=0, # imagedream uses ip_dim > 0\n        ip_weight=1.0,\n        **kwargs,\n    ):\n        super().__init__()\n        assert context_dim is not None\n        \n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert (\n                num_head_channels != -1\n            ), \"Either num_heads or num_head_channels has to be set\"\n\n        if num_head_channels == -1:\n            assert (\n                num_heads != -1\n            ), \"Either num_heads or num_head_channels has to be set\"\n\n        self.image_size = image_size\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        if isinstance(num_res_blocks, int):\n            self.num_res_blocks = len(channel_mult) * [num_res_blocks]\n        else:\n            if len(num_res_blocks) != len(channel_mult):\n                raise ValueError(\n                    \"provide num_res_blocks either as an int (globally constant) or \"\n                    \"as a list/tuple (per-level) with the same length as channel_mult\"\n                )\n            self.num_res_blocks = num_res_blocks\n        \n        if num_attention_blocks is not None:\n            assert len(num_attention_blocks) == len(self.num_res_blocks)\n            assert all(\n                map(\n                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],\n                    range(len(num_attention_blocks)),\n                )\n            )\n            print(\n                f\"Constructor 
of UNetModel received num_attention_blocks={num_attention_blocks}. \"\n                f\"This option has LESS priority than attention_resolutions {attention_resolutions}, \"\n                f\"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, \"\n                f\"attention will still not be set.\"\n            )\n\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n        self.predict_codebook_ids = n_embed is not None\n\n        self.ip_dim = ip_dim\n        self.ip_weight = ip_weight\n\n        if self.ip_dim > 0:\n            self.image_embed = Resampler(\n                dim=context_dim,\n                depth=4,\n                dim_head=64,\n                heads=12,\n                num_queries=ip_dim,  # num token\n                embedding_dim=1280,\n                output_dim=context_dim,\n                ff_mult=4,\n            )\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            nn.Linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            nn.Linear(time_embed_dim, time_embed_dim),\n        )\n\n        if camera_dim is not None:\n            time_embed_dim = model_channels * 4\n            self.camera_embed = nn.Sequential(\n                nn.Linear(camera_dim, time_embed_dim),\n                nn.SiLU(),\n                nn.Linear(time_embed_dim, time_embed_dim),\n            )\n\n        if self.num_classes is not None:\n            if isinstance(self.num_classes, int):\n                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)\n            elif self.num_classes == \"continuous\":\n                # print(\"setting up linear c_adm embedding layer\")\n                self.label_emb = nn.Linear(1, time_embed_dim)\n            elif self.num_classes == \"sequential\":\n                assert adm_in_channels is not None\n                self.label_emb = nn.Sequential(\n                    nn.Sequential(\n                        nn.Linear(adm_in_channels, time_embed_dim),\n                        nn.SiLU(),\n                        nn.Linear(time_embed_dim, time_embed_dim),\n                    )\n                )\n            else:\n                raise ValueError()\n\n        self.input_blocks = nn.ModuleList(\n            [\n                CondSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for nr in range(self.num_res_blocks[level]):\n                layers: List[Any] = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n        
                dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n\n                    if num_attention_blocks is None or nr < num_attention_blocks[level]:\n                        layers.append(\n                            SpatialTransformer3D(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                context_dim=context_dim,\n                                depth=transformer_depth,\n                                ip_dim=self.ip_dim,\n                                ip_weight=self.ip_weight,\n                            )\n                        )\n                self.input_blocks.append(CondSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    CondSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n        \n        self.middle_block = CondSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            SpatialTransformer3D(\n                ch,\n                num_heads,\n                dim_head,\n                context_dim=context_dim,\n                depth=transformer_depth,\n                ip_dim=self.ip_dim,\n                ip_weight=self.ip_weight,\n            ), \n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(self.num_res_blocks[level] + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if 
num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n\n                    if num_attention_blocks is None or i < num_attention_blocks[level]:\n                        layers.append(\n                            SpatialTransformer3D(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                context_dim=context_dim,\n                                depth=transformer_depth,\n                                ip_dim=self.ip_dim,\n                                ip_weight=self.ip_weight,\n                            )\n                        )\n                if level and i == self.num_res_blocks[level]:\n                    out_ch = ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                    ds //= 2\n                self.output_blocks.append(CondSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            nn.GroupNorm(32, ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),\n        )\n        if self.predict_codebook_ids:\n            self.id_predictor = nn.Sequential(\n                nn.GroupNorm(32, ch),\n                conv_nd(dims, model_channels, n_embed, 1),\n                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits\n            )\n\n    def forward(\n        self,\n        x,\n        timesteps=None,\n        context=None,\n        y=None,\n        camera=None,\n        num_frames=1,\n        ip=None,\n        ip_img=None,\n        **kwargs,\n    ):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :param num_frames: a integer indicating number of frames for tensor reshaping.\n        :return: an [(N x F) x C x ...] Tensor of outputs. 
F is the number of frames (views).\n        \"\"\"\n        assert (\n            x.shape[0] % num_frames == 0\n        ), \"input batch size must be dividable by num_frames!\"\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional\"\n\n        hs = []\n\n        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)\n\n        emb = self.time_embed(t_emb)\n\n        if self.num_classes is not None:\n            assert y is not None\n            assert y.shape[0] == x.shape[0]\n            emb = emb + self.label_emb(y)\n\n        # Add camera embeddings\n        if camera is not None:\n            emb = emb + self.camera_embed(camera)\n        \n        # imagedream variant\n        if self.ip_dim > 0:\n            x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]\n            ip_emb = self.image_embed(ip)\n            context = torch.cat((context, ip_emb), 1)\n\n        h = x\n        for module in self.input_blocks:\n            h = module(h, emb, context, num_frames=num_frames)\n            hs.append(h)\n        h = self.middle_block(h, emb, context, num_frames=num_frames)\n        for module in self.output_blocks:\n            h = torch.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, context, num_frames=num_frames)\n        h = h.type(x.dtype)\n        if self.predict_codebook_ids:\n            return self.id_predictor(h)\n        else:\n            return self.out(h)"
  },
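mvdream/mv_unet.py exposes two small helpers that the pipeline relies on: `get_camera`, which flattens orbiting camera poses into the `[num_frames, 16]` conditioning consumed by `camera_embed`, and `timestep_embedding`, the standard sinusoidal embedding fed into `time_embed`. A short usage sketch follows; note that the module imports `xformers` unconditionally, so it must be installed, and the embedding width of 320 is only an illustrative `model_channels` value.

```python
import torch
from mvdream.mv_unet import get_camera, timestep_embedding  # requires xformers to be installed

# Four views, 90 degrees apart at elevation 0 -> flattened 4x4 poses, shape [4, 16].
cams = get_camera(num_frames=4, elevation=0, azimuth_start=0, azimuth_span=360)
print(cams.shape)  # torch.Size([4, 16])

# Sinusoidal embedding of diffusion timesteps; width 320 is an illustrative value.
t = torch.tensor([999, 500, 10, 0])
emb = timestep_embedding(t, dim=320)
print(emb.shape)  # torch.Size([4, 320])
```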
  {
    "path": "mvdream/pipeline_mvdream.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport inspect\nimport numpy as np\nfrom typing import Callable, List, Optional, Union\nfrom transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor\nfrom diffusers import AutoencoderKL, DiffusionPipeline\nfrom diffusers.utils import (\n    deprecate,\n    is_accelerate_available,\n    is_accelerate_version,\n    logging,\n)\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.schedulers import DDIMScheduler\nfrom diffusers.utils.torch_utils import randn_tensor\n\nfrom mvdream.mv_unet import MultiViewUNetModel, get_camera\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass MVDreamPipeline(DiffusionPipeline):\n\n    _optional_components = [\"feature_extractor\", \"image_encoder\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        unet: MultiViewUNetModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder: CLIPTextModel,\n        scheduler: DDIMScheduler,\n        # imagedream variant\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModel,\n        requires_safety_checker: bool = False,\n    ):\n        super().__init__()\n\n        if hasattr(scheduler.config, \"steps_offset\") and scheduler.config.steps_offset != 1:  # type: ignore\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"  # type: ignore\n                \"to update the config accordingly as leaving `steps_offset` might led to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\n                \"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False\n            )\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if hasattr(scheduler.config, \"clip_sample\") and scheduler.config.clip_sample is True:  # type: ignore\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\n                \"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False\n            )\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            unet=unet,\n            scheduler=scheduler,\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding.\n\n        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several\n        steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in\n        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.\n        \"\"\"\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_tiling()\n\n    def enable_sequential_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,\n        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a\n        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.\n        Note that offloading happens on a submodule basis. 
Memory savings are higher than with\n        `enable_model_cpu_offload`, but performance is lower.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.14.0\"):\n            from accelerate import cpu_offload\n        else:\n            raise ImportError(\n                \"`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher\"\n            )\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\", silence_dtype_warnings=True)\n            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)\n\n        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:\n            cpu_offload(cpu_offloaded_model, device)\n\n    def enable_model_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared\n        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`\n        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with\n        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate import cpu_offload_with_hook\n        else:\n            raise ImportError(\n                \"`enable_model_offload` requires `accelerate v0.17.0` or higher.\"\n            )\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\", silence_dtype_warnings=True)\n            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)\n\n        hook = None\n        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:\n            _, hook = cpu_offload_with_hook(\n                cpu_offloaded_model, device, prev_module_hook=hook\n            )\n\n        # We'll offload the last model manually.\n        self.final_offload_hook = hook\n\n    @property\n    def _execution_device(self):\n        r\"\"\"\n        Returns the device on which the pipeline's models will be executed. 
After calling\n        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module\n        hooks.\n        \"\"\"\n        if not hasattr(self.unet, \"_hf_hook\"):\n            return self.device\n        for module in self.unet.modules():\n            if (\n                hasattr(module, \"_hf_hook\")\n                and hasattr(module._hf_hook, \"execution_device\")\n                and module._hf_hook.execution_device is not None\n            ):\n                return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance: bool,\n        negative_prompt=None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(\n                f\"`prompt` should be either a string or a list of strings, but got {type(prompt)}.\"\n            )\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(\n            prompt, padding=\"longest\", return_tensors=\"pt\"\n        ).input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n            text_input_ids, untruncated_ids\n        ):\n            removed_text = self.tokenizer.batch_decode(\n                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n            )\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n\n        if (\n            hasattr(self.text_encoder.config, \"use_attention_mask\")\n            and self.text_encoder.config.use_attention_mask\n        ):\n            attention_mask = text_inputs.attention_mask.to(device)\n        else:\n            attention_mask = None\n\n        prompt_embeds = self.text_encoder(\n            text_input_ids.to(device),\n            attention_mask=attention_mask,\n        )\n        prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(\n            bs_embed * num_images_per_prompt, seq_len, -1\n        )\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if (\n                hasattr(self.text_encoder.config, \"use_attention_mask\")\n                and self.text_encoder.config.use_attention_mask\n            ):\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(\n                dtype=self.text_encoder.dtype, device=device\n            )\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(\n                1, num_images_per_prompt, 1\n            )\n            negative_prompt_embeds = negative_prompt_embeds.view(\n                batch_size * num_images_per_prompt, seq_len, -1\n            )\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(\n            inspect.signature(self.scheduler.step).parameters.keys()\n        )\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(\n            inspect.signature(self.scheduler.step).parameters.keys()\n        )\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        shape = (\n            batch_size,\n            
num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(\n                shape, generator=generator, device=device, dtype=dtype\n            )\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def encode_image(self, image, device, num_images_per_prompt):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if image.dtype == np.float32:\n            image = (image * 255).astype(np.uint8)\n            \n        image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n        image = image.to(device=device, dtype=dtype)\n        \n        image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n\n        return torch.zeros_like(image_embeds), image_embeds\n\n    def encode_image_latents(self, image, device, num_images_per_prompt):\n        \n        dtype = next(self.image_encoder.parameters()).dtype\n\n        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]\n        image = 2 * image - 1\n        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)\n        image = image.to(dtype=dtype)\n\n        posterior = self.vae.encode(image).latent_dist\n        latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]\n        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)\n\n        return torch.zeros_like(latents), latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: str = \"\",\n        image: Optional[np.ndarray] = None,\n        height: int = 256,\n        width: int = 256,\n        elevation: float = 0,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.0,\n        negative_prompt: str = \"\",\n        num_images_per_prompt: int = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        output_type: Optional[str] = \"numpy\", # pil, numpy, latents\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        num_frames: int = 4,\n        device=torch.device(\"cuda:0\"),\n    ):\n        self.unet = self.unet.to(device=device)\n        self.vae = self.vae.to(device=device)\n        self.text_encoder = self.text_encoder.to(device=device)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # imagedream variant\n        if image is not None:\n            assert isinstance(image, np.ndarray) and image.dtype == np.float32\n            self.image_encoder = self.image_encoder.to(device=device)\n            image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)\n            image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)\n            \n        _prompt_embeds = self._encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n        )  # type: ignore\n        prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)\n\n        # Prepare latent variables\n        actual_num_frames = num_frames if image is None else num_frames + 1\n        latents: torch.Tensor = self.prepare_latents(\n            actual_num_frames * num_images_per_prompt,\n            4,\n            height,\n            width,\n            prompt_embeds_pos.dtype,\n            device,\n            generator,\n            None,\n        )\n\n        if image is not None:\n            camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device)\n        else:\n            camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device)\n        camera = camera.repeat_interleave(num_images_per_prompt, dim=0)\n\n        # Prepare extra step kwargs.\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                multiplier = 2 if do_classifier_free_guidance else 1\n                latent_model_input = torch.cat([latents] * multiplier)\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                unet_inputs = {\n                    'x': latent_model_input,\n                    'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),\n                    'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),\n                    'num_frames': actual_num_frames,\n                    'camera': torch.cat([camera] * multiplier),\n                }\n\n                if image is not None:\n                    unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)\n                    unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat\n                \n                # predict the noise residual\n                noise_pred = self.unet.forward(**unet_inputs)\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (\n                        noise_pred_text - noise_pred_uncond\n                    )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents: torch.Tensor = self.scheduler.step(\n                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False\n                )[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or (\n                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0\n                ):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)  # type: ignore\n\n        # Post-processing\n        if output_type == \"latent\":\n            image = latents\n        elif output_type == \"pil\":\n            image = self.decode_latents(latents)\n            image = self.numpy_to_pil(image)\n        else: # numpy\n            image = self.decode_latents(latents)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        return image"
  },
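  {
    "path": "examples/run_pipeline_sketch.py",
    "content": "# Illustrative usage sketch, not part of the original release. It shows one way\n# to drive MVDreamPipeline for image-to-multi-view generation, following the\n# __call__ signature in mvdream/pipeline_mvdream.py. The checkpoint id, the input\n# image path, and this file name are assumptions: substitute the converted\n# MVDream/ImageDream weights and reference image you actually use.\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom mvdream.pipeline_mvdream import MVDreamPipeline\n\n# assumed checkpoint id; any diffusers-format ImageDream/MVDream repo should work\npipe = MVDreamPipeline.from_pretrained(\n    \"ashawkey/imagedream-ipmv-diffusers\",\n    torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\n\n# the pipeline expects an RGB float32 array in [0, 1] with shape (H, W, 3);\n# inputs usually work best with the object centered on a plain background\nimage = Image.open(\"input.png\").convert(\"RGB\").resize((256, 256))\nimage = np.asarray(image).astype(np.float32) / 255.0\n\n# returns a numpy array of views in [0, 1]; when an input image is given, the\n# extra reference view is appended, so num_frames + 1 views come back\nviews = pipe(\n    prompt=\"\",\n    image=image,\n    elevation=0,\n    num_inference_steps=30,\n    guidance_scale=5.0,\n    num_frames=4,\n)\n\nfor i, view in enumerate(views):\n    Image.fromarray((view * 255).astype(np.uint8)).save(f\"view_{i}.png\")\n"
  }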
]