Repository: smthemex/ComfyUI_DiffuEraser
Branch: main
Commit: c8461876ad09
Files: 77
Total size: 2.5 MB

Directory structure:
gitextract_svppjk7y/

├── LICENSE
├── README.md
├── __init__.py
├── diffueraser_node.py
├── example_workflows/
│   └── differaser.json
├── libs/
│   ├── __init__.py
│   ├── brushnet_CA.py
│   ├── diffueraser.py
│   ├── pipeline_diffueraser.py
│   ├── transformer_temporal.py
│   ├── unet_2d_blocks.py
│   ├── unet_2d_condition.py
│   ├── unet_3d_blocks.py
│   ├── unet_motion_model.py
│   └── v1-inference.yaml
├── node_utils.py
├── propainter/
│   ├── RAFT/
│   │   ├── __init__.py
│   │   ├── corr.py
│   │   ├── datasets.py
│   │   ├── demo.py
│   │   ├── extractor.py
│   │   ├── raft.py
│   │   ├── update.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── flow_viz_pt.py
│   │       ├── frame_utils.py
│   │       └── utils.py
│   ├── core/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── dist.py
│   │   ├── loss.py
│   │   ├── lr_scheduler.py
│   │   ├── metrics.py
│   │   ├── prefetch_dataloader.py
│   │   ├── trainer.py
│   │   ├── trainer_flow_w_edge.py
│   │   └── utils.py
│   ├── inference.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── canny/
│   │   │   ├── __init__.py
│   │   │   ├── canny_filter.py
│   │   │   ├── filter.py
│   │   │   ├── gaussian.py
│   │   │   ├── kernels.py
│   │   │   └── sobel.py
│   │   ├── misc.py
│   │   ├── modules/
│   │   │   ├── __init__.py
│   │   │   ├── base_module.py
│   │   │   ├── deformconv.py
│   │   │   ├── flow_comp_raft.py
│   │   │   ├── flow_loss_utils.py
│   │   │   ├── sparse_transformer.py
│   │   │   └── spectral_norm.py
│   │   ├── propainter.py
│   │   ├── recurrent_flow_completion.py
│   │   └── vgg_arch.py
│   └── utils/
│       ├── __init__.py
│       ├── download_util.py
│       ├── file_client.py
│       ├── flow_util.py
│       └── img_util.py
├── pyproject.toml
├── requirements.txt
├── run_diffueraser.py
└── sd15_repo/
    ├── feature_extractor/
    │   └── preprocessor_config.json
    ├── model_index.json
    ├── safety_checker/
    │   └── config.json
    ├── scheduler/
    │   └── scheduler_config.json
    ├── text_encoder/
    │   └── config.json
    ├── tokenizer/
    │   ├── merges.txt
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── vocab.json
    ├── unet/
    │   └── config.json
    └── vae/
        └── config.json

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2025 smthemex

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# ComfyUI_DiffuEraser
[DiffuEraser](https://github.com/lixiaowen-xw/DiffuEraser) is a diffusion model for video inpainting; these nodes let you use it in ComfyUI.

# Update
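* Use the officially recommended VAE file and switch clip-l to ComfyUI's default one, which removes the need for the SD1.5 base checkpoint.
* For watermarks, lower `mask_dilation_iter` (the number of mask dilation iterations), e.g. to 2. For routine use, ProPainter sampling alone is often enough (its longest side is 960, so upscale the result afterwards).
* Switched to ComfyUI v3 mode, fixed many bugs and added support for newer diffusers; 1280*720 now runs on 12 GB VRAM. The `blended` switch on the DiffuEraser sampler node supports two outputs: off reduces flicker, on uses compositing. Models are no longer reloaded repeatedly when the workflow runs in a loop.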
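A minimal sketch of the frame-size arithmetic, assuming Python 3.9+: `DiffuEraser_PreData` rounds height and width down to multiples of 8, and the note above says ProPainter's longest side is 960. The helper name `fit_dims` and its optional longest-side cap are illustrative assumptions, not the node's actual resize code.

```python
from typing import Optional, Tuple


def fit_dims(width: int, height: int, max_side: Optional[int] = None) -> Tuple[int, int]:
    """Hypothetical helper: optionally cap the longest side, then round down to multiples of 8."""
    if max_side is not None and max(width, height) > max_side:
        # Scale both sides by the same factor so the aspect ratio is roughly kept.
        scale = max_side / max(width, height)
        width, height = int(width * scale), int(height * scale)
    # Same rounding as DiffuEraser_PreData: drop the remainder modulo 8.
    return width - width % 8, height - height % 8


print(fit_dims(1283, 722))                 # (1280, 720)
print(fit_dims(1920, 1080, max_side=960))  # (960, 536)
```

Rounding down keeps the result within the source size, which is why a 1080-pixel side capped to 540 ends up at 536.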


# 1. Installation

In the `./ComfyUI/custom_nodes` directory, run the following:   
```
git clone https://github.com/smthemex/ComfyUI_DiffuEraser.git
```
---

# 2. Requirements  
* Usually nothing extra is needed, since it is based on SD1.5, but some environments may be missing a library. There are no special libraries in the list; the file was simply left in place.
```
pip install -r requirements.txt
```
# 3. Models
* vae [link](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
* clip-l: the standard ComfyUI clip_l
* PCM SD1.5 LoRA [address](https://huggingface.co/wangfuyun/PCM_Weights/tree/main/sd15), e.g. `pcm_sd15_smallcfg_2step_converted.safetensors`
* ProPainter [address](https://github.com/sczhou/ProPainter/releases/tag/v0.1.0) (see the layout below)
* unet and brushnet [address](https://huggingface.co/lixiaowen/diffuEraser/tree/main) (see the layout below)

```
--  ComfyUI/models/vae
    |-- sd-vae-ft-mse.safetensors #vae
--  ComfyUI/models/clip
    |-- clip_l.safetensors # standard ComfyUI clip-l
--  ComfyUI/models/loras
    |-- pcm_sd15_smallcfg_2step_converted.safetensors # PCM LoRA
--  ComfyUI/models/DiffuEraser
     |--brushnet
        |-- config.json
        |-- diffusion_pytorch_model.safetensors
     |--unet_main
        |-- config.json
        |-- diffusion_pytorch_model.safetensors
     |--propainter
        |-- ProPainter.pth
        |-- raft-things.pth
        |-- recurrent_flow_completion.pth
```
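A quick pre-flight check can catch a misplaced weight before the loader nodes fail. This sketch only mirrors the tree above; `missing_weights` and `REQUIRED_FILES` are hypothetical names, and `models_root` is assumed to be your `ComfyUI/models` directory. Since `Propainter_Loader` lets you pick any file in the DiffuEraser folder, adjust the list if your file names differ.

```python
from pathlib import Path
from typing import List

# Relative paths under ComfyUI/models/DiffuEraser, copied from the layout above.
REQUIRED_FILES = [
    "brushnet/config.json",
    "brushnet/diffusion_pytorch_model.safetensors",
    "unet_main/config.json",
    "unet_main/diffusion_pytorch_model.safetensors",
    "propainter/ProPainter.pth",
    "propainter/raft-things.pth",
    "propainter/recurrent_flow_completion.pth",
]


def missing_weights(models_root: str) -> List[str]:
    """Return the expected DiffuEraser files that are not present under models_root."""
    root = Path(models_root) / "DiffuEraser"
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    for rel in missing_weights("ComfyUI/models"):
        print("missing:", rel)
```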
* To generate masks from the video, you can use an RMBG or BiRefNet model to remove the background:
```
-- any/path/briaai/RMBG-2.0   # or auto download 
        |--config.json
        |--model.safetensors
        |--birefnet.py
        |--BiRefNet_config.py
Or
-- any/path/ZhengPeng7/BiRefNet   # or auto download 
        |--config.json
        |--model.safetensors
        |--birefnet.py
        |--BiRefNet_config.py
        |--handler.py
```
  
# 4. Example
![](https://github.com/smthemex/ComfyUI_DiffuEraser/blob/main/example_workflows/example.png)
* Using a single mask
![](https://github.com/smthemex/ComfyUI_DiffuEraser/blob/main/example_workflows/example1.png)
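With a single mask, `DiffuEraser_PreData` repeats it for every frame; a longer mask list is truncated and a shorter one is reused from the start. A minimal, framework-free sketch of that rule follows; the name `match_mask_count` is hypothetical, and plain list items stand in for the PIL images the node actually uses.

```python
from typing import List, TypeVar

T = TypeVar("T")


def match_mask_count(masks: List[T], num_frames: int) -> List[T]:
    """Return exactly num_frames masks: repeat one, truncate extras, or cycle a short list."""
    if not masks:
        raise ValueError("at least one mask is required")
    if len(masks) == 1:
        # One mask inpaints every frame.
        return masks * num_frames
    # Covers both cases: i < len(masks) truncates, i >= len(masks) wraps around.
    return [masks[i % len(masks)] for i in range(num_frames)]


print(match_mask_count(["m0"], 3))        # ['m0', 'm0', 'm0']
print(match_mask_count(["m0", "m1"], 5))  # ['m0', 'm1', 'm0', 'm1', 'm0']
```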

# 5. Citation
```
@misc{li2025diffueraserdiffusionmodelvideo,
   title={DiffuEraser: A Diffusion Model for Video Inpainting}, 
   author={Xiaowen Li and Haolan Xue and Peiran Ren and Liefeng Bo},
   year={2025},
   eprint={2501.10018},
   archivePrefix={arXiv},
   primaryClass={cs.CV},
   url={https://arxiv.org/abs/2501.10018}, 
}
```
```
@inproceedings{zhou2023propainter,
   title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting},
   author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change},
   booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)},
   year={2023}
}
```
```
@misc{ju2024brushnet,
  title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion}, 
  author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu},
  year={2024},
  eprint={2403.06976},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
```
@article{BiRefNet,
  title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
  author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
  journal={CAAI Artificial Intelligence Research},
  year={2024}
}

```


================================================
FILE: __init__.py
================================================

from .diffueraser_node import *

================================================
FILE: diffueraser_node.py
================================================
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
import os
import torch
import gc
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import nodes
import comfy.model_management as mm
from .node_utils import  load_images,tensor2pil_list,image2masks,nomarl_upscale
import folder_paths
from .run_diffueraser import load_diffueraser,load_propainter
from diffusers.hooks import apply_group_offloading
import copy

MAX_SEED = np.iinfo(np.int32).max
current_node_path = os.path.dirname(os.path.abspath(__file__))
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# add checkpoints dir
DiffuEraser_weigths_path = os.path.join(folder_paths.models_dir, "DiffuEraser")
if not os.path.exists(DiffuEraser_weigths_path):
    os.makedirs(DiffuEraser_weigths_path)
folder_paths.add_model_folder_path("DiffuEraser", DiffuEraser_weigths_path)



class Propainter_Loader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Propainter_Loader",
            display_name="Propainter_Loader",
            category="DiffuEraser",
            inputs=[
                io.Combo.Input("propainter",options= ["none"] + folder_paths.get_filename_list("DiffuEraser") ),
                io.Combo.Input("flow",options= ["none"] + folder_paths.get_filename_list("DiffuEraser") ),
                io.Combo.Input("fix_raft",options= ["none"] + folder_paths.get_filename_list("DiffuEraser") ),
                io.Combo.Input("device",options= ["cpu","cuda","mps"] ),
            ],
            outputs=[
                io.Custom("Propainter_Loader").Output(display_name="model"),
                ],
            )
    @classmethod
    def execute(cls, propainter,flow,fix_raft,device) -> io.NodeOutput:
        ProPainter_path=folder_paths.get_full_path("DiffuEraser",propainter) if propainter!="none" else None
        flow_path=folder_paths.get_full_path("DiffuEraser",flow) if flow!="none" else None
        fix_raft_path=folder_paths.get_full_path("DiffuEraser",fix_raft) if fix_raft!="none" else None
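        if fix_raft_path is None or flow_path is None or ProPainter_path is None:
            # raising a plain string is a TypeError in Python 3, so raise a real exception
            raise ValueError("Propainter_Loader needs all three models: propainter, flow and fix_raft")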
        model=load_propainter(fix_raft_path,flow_path,ProPainter_path,device=device)
        return io.NodeOutput(model)


class DiffuEraser_Loader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DiffuEraser_Loader",
            display_name="DiffuEraser_Loader",
            category="DiffuEraser",
            inputs=[
                io.Combo.Input("vae",options= ["none"] + folder_paths.get_filename_list("vae") ),
                io.Combo.Input("lora",options= ["none"] + folder_paths.get_filename_list("loras") ),
            ],
            outputs=[
                io.Custom("DiffuEraser_Loader").Output(display_name="model"),
                ],
            )
    @classmethod
    def execute(cls, vae,lora) -> io.NodeOutput:
        ckpt_path=folder_paths.get_full_path("vae",vae) if vae!="none" else None
        pcm_lora_path=folder_paths.get_full_path("loras",lora) if lora!="none" else None
        #print("load lora model from:",pcm_lora_path)
        model=load_diffueraser(os.path.join(current_node_path,"sd15_repo"),DiffuEraser_weigths_path, ckpt_path,os.path.join(current_node_path,"libs/v1-inference.yaml"),pcm_lora_path,device)
        gc.collect()
        torch.cuda.empty_cache()
        return io.NodeOutput(model)


class DiffuEraser_PreData(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DiffuEraser_PreData",
            display_name="DiffuEraser_PreData",
            category="DiffuEraser",
            inputs=[
                io.Image.Input("images"),
                io.String.Input("seg_repo",default="briaai/RMBG-2.0"),
                io.Image.Input("video_mask_image",optional=True),
                io.Mask.Input("video_mask",optional=True),
            ],
            outputs=[
                 io.Conditioning.Output(display_name="conditioning"),
                ],
            )
    @classmethod
    def execute(cls, images,seg_repo,video_mask_image=None,video_mask=None) -> io.NodeOutput:
        _,height,width,_  = images.size()
        height,width=(height-height%8, width-width%8)
        video_image=tensor2pil_list(images,width,height)
       
        if video_mask is None and video_mask_image is None and seg_repo:    # use rmbg or BiRefNet to make video to masks
            print("*********** Use input video and repo to make masks **************")
            init_mask=image2masks(seg_repo,video_image)
        elif video_mask_image is not None:

            if not isinstance(video_mask_image, torch.Tensor):
                raise TypeError("video_mask_image is not a normal ComfyUI image tensor; expected shape (b, h, w, c)")
            else:
                init_mask=tensor2pil_list(video_mask_image,width,height)
                
        elif video_mask is not None:
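            # a ComfyUI mask is (b, h, w) or (h, w); anything with more dimensions is rejected
            if not isinstance(video_mask, torch.Tensor) or video_mask.dim() > 3:
                raise TypeError("video_mask is not a normal ComfyUI mask tensor; expected shape (b, h, w)")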
            init_mask=tensor2pil_list( video_mask.reshape((-1, 1, video_mask.shape[-2], video_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) ,width,height)
        else:
            raise ValueError("no video mask: link video_mask or video_mask_image from another node, or fill seg_repo with an RMBG or BiRefNet repo to generate masks from the images")
        
        if len(init_mask)!=len(video_image) :
            if  len(init_mask)==1:
                init_mask=init_mask*len(video_image) # if use one mask to inpaint all frames
            else:
                if len(init_mask)>len(video_image):  
                    init_mask=init_mask[:len(video_image)]
                    print("init_mask length:",len(init_mask),"video_image length:",len(video_image))
                else:
                    # reuse masks from the start until every frame has one (works for any length ratio)
                    init_mask=[init_mask[i % len(init_mask)] for i in range(len(video_image))]
                    print("init_mask length:",len(init_mask),"video_image length:",len(video_image))
        cond={"init_mask":init_mask,"video_image":video_image,"height":height,"width":width}
        return io.NodeOutput(cond)


class Propainter_Sampler(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Propainter_Sampler",
            display_name="Propainter_Sampler",
            category="DiffuEraser",
            inputs=[
                io.Custom("Propainter_Loader").Input("model"),
                io.Conditioning.Input("conditioning"),
                io.Float.Input("fps", force_input=True),
                io.Int.Input("video_length", default=10, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Int.Input("mask_dilation_iter", default=2, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Int.Input("ref_stride", default=10, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Int.Input("neighbor_length", default=10, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Int.Input("subvideo_length", default=50, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
            ], 
            outputs=[
                io.Conditioning.Output(display_name="conditioning"),
                io.Image.Output(display_name="images"),
                ],
            )
    
    @classmethod
    def execute(cls, model,conditioning,fps,video_length,mask_dilation_iter,ref_stride,neighbor_length,subvideo_length) -> io.NodeOutput:
        

        model.to(device)
        conditioning["fps"]=fps
        conditioning["video_length"]=video_length
        conditioning["mask_dilation_iter"]=mask_dilation_iter
       
        Propainter_img=model.forward(copy.deepcopy(conditioning["video_image"]), copy.deepcopy(conditioning["init_mask"]),load_videobypath=False,video_length=video_length, height= conditioning["height"],width=conditioning["width"],
                        ref_stride=ref_stride, neighbor_length=neighbor_length, subvideo_length = subvideo_length,
                        mask_dilation = mask_dilation_iter,save_fps=fps) 
        conditioning["prioris"]=Propainter_img
        model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()
        images=load_images(Propainter_img)
        return io.NodeOutput(conditioning,images)

class DiffuEraser_Sampler(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DiffuEraser_Sampler",
            display_name="DiffuEraser_Sampler",
            category="DiffuEraser",
            inputs=[
                io.Custom("DiffuEraser_Loader").Input("model"),
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("conditioning"),
                io.Int.Input("steps", default=2, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Int.Input("seed", default=0, min=0, max=MAX_SEED,display_mode=io.NumberDisplay.number),
                io.Boolean.Input("save_result_video", default=False),
                io.Int.Input("unet_group", default=5, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Int.Input("brush_group", default=5, min=1, max=1024,step=1,display_mode=io.NumberDisplay.number),
                io.Boolean.Input("blended", default=False),
            ], 
            outputs=[
                io.Image.Output(display_name="image"),
               
                ],
             )
    @classmethod
    def execute(cls, model,positive,conditioning,steps,seed,save_result_video,unet_group,brush_group,blended) -> io.NodeOutput:
        # gc cf model
        cf_models=mm.loaded_models()
        try:
            for pipe in cf_models:   
                pipe.unpatch_model(device_to=torch.device("cpu"))
                print(f"Unpatching models.{pipe}")
        except Exception:
            pass
        mm.soft_empty_cache()
        torch.cuda.empty_cache()
        max_gpu_memory = torch.cuda.max_memory_allocated()
        print(f"After Max GPU memory allocated: {max_gpu_memory / 1000 ** 3:.2f} GB")
        
        max_img_size=1920
        model.to(device) 
        model.pipeline.enable_xformers_memory_efficient_attention()
        apply_group_offloading(model.pipeline.unet, onload_device=torch.device("cuda"), offload_type="block_level", num_blocks_per_group=unet_group)
        apply_group_offloading(model.pipeline.brushnet, onload_device=torch.device("cuda"), offload_type="block_level", num_blocks_per_group=brush_group)
        image_list=model.forward( copy.deepcopy(conditioning["video_image"]), copy.deepcopy(conditioning["init_mask"]), copy.deepcopy(conditioning["prioris"]),folder_paths.get_output_directory(),positive,load_videobypath=False,
                                max_img_size = max_img_size, video_length=conditioning["video_length"], mask_dilation_iter=conditioning["mask_dilation_iter"],seed=seed,blended=blended,
                               num_inference_steps=steps,fps=conditioning["fps"],img_size=(conditioning["width"],conditioning["height"]),if_save_video=save_result_video)
        
        #model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()
        images=load_images(image_list)

        return io.NodeOutput(images)


from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/DiffuEraser_SM_Extension")
async def get_hello(request):
    return web.json_response("DiffuEraser_SM_Extension")

class DiffuEraser_SM_Extension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            Propainter_Loader,
            DiffuEraser_Loader,
            DiffuEraser_PreData,
            Propainter_Sampler,
            DiffuEraser_Sampler,
        ]
async def comfy_entrypoint() -> DiffuEraser_SM_Extension:  # ComfyUI calls this to load your extension and its nodes.
    return DiffuEraser_SM_Extension()



================================================
FILE: example_workflows/differaser.json
================================================
{
  "id": "3da45669-6ef0-4ec2-a292-abe74e953ca2",
  "revision": 0,
  "last_node_id": 24,
  "last_link_id": 28,
  "nodes": [
    {
      "id": 17,
      "type": "SaveVideo",
      "pos": [
        4039.588713354701,
        1439.7671022987934
      ],
      "size": [
        478,
        420.6194116210936
      ],
      "flags": {},
      "order": 12,
      "mode": 0,
      "inputs": [
        {
          "name": "video",
          "type": "VIDEO",
          "link": 14
        }
      ],
      "outputs": [],
      "properties": {},
      "widgets_values": [
        "video/ComfyUI",
        "auto",
        "auto"
      ]
    },
    {
      "id": 6,
      "type": "DiffuEraser_PreData",
      "pos": [
        3639.4483963547445,
        1166.3094503184818
      ],
      "size": [
        272.4375,
        98
      ],
      "flags": {},
      "order": 5,
      "mode": 0,
      "inputs": [
        {
          "name": "images",
          "type": "IMAGE",
          "link": 18
        },
        {
          "name": "video_mask_image",
          "shape": 7,
          "type": "IMAGE",
          "link": 19
        },
        {
          "name": "video_mask",
          "shape": 7,
          "type": "MASK",
          "link": null
        }
      ],
      "outputs": [
        {
          "name": "conditioning",
          "type": "CONDITIONING",
          "links": [
            1
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "DiffuEraser_PreData"
      },
      "widgets_values": [
        "briaai/RMBG-2.0"
      ]
    },
    {
      "id": 10,
      "type": "CreateVideo",
      "pos": [
        4707.826211690106,
        1050.8378342212904
      ],
      "size": [
        270,
        78
      ],
      "flags": {},
      "order": 11,
      "mode": 0,
      "inputs": [
        {
          "name": "images",
          "type": "IMAGE",
          "link": 26
        },
        {
          "name": "audio",
          "shape": 7,
          "type": "AUDIO",
          "link": null
        },
        {
          "name": "fps",
          "type": "FLOAT",
          "widget": {
            "name": "fps"
          },
          "link": 23
        }
      ],
      "outputs": [
        {
          "name": "VIDEO",
          "type": "VIDEO",
          "links": [
            5
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "CreateVideo"
      },
      "widgets_values": [
        30
      ]
    },
    {
      "id": 19,
      "type": "VHS_LoadVideo",
      "pos": [
        3372.3171964640787,
        1474.4346575533937
      ],
      "size": [
        261.6533203125,
        310
      ],
      "flags": {},
      "order": 0,
      "mode": 0,
      "inputs": [
        {
          "name": "meta_batch",
          "shape": 7,
          "type": "VHS_BatchManager",
          "link": null
        },
        {
          "name": "vae",
          "shape": 7,
          "type": "VAE",
          "link": null
        }
      ],
      "outputs": [
        {
          "name": "IMAGE",
          "type": "IMAGE",
          "links": [
            19
          ]
        },
        {
          "name": "frame_count",
          "type": "INT",
          "links": null
        },
        {
          "name": "audio",
          "type": "AUDIO",
          "links": null
        },
        {
          "name": "video_info",
          "type": "VHS_VIDEOINFO",
          "links": null
        }
      ],
      "properties": {
        "Node name for S&R": "VHS_LoadVideo"
      },
      "widgets_values": {
        "video": "mask.mp4",
        "force_rate": 0,
        "custom_width": 0,
        "custom_height": 0,
        "frame_load_cap": 44,
        "skip_first_frames": 0,
        "select_every_nth": 1,
        "format": "AnimateDiff",
        "videopreview": {
          "hidden": false,
          "paused": false,
          "params": {
            "filename": "mask.mp4",
            "type": "input",
            "format": "video/mp4",
            "force_rate": 0,
            "custom_width": 0,
            "custom_height": 0,
            "frame_load_cap": 44,
            "skip_first_frames": 0,
            "select_every_nth": 1
          }
        }
      }
    },
    {
      "id": 20,
      "type": "VHS_VideoInfo",
      "pos": [
        3361.206167411345,
        1183.434691122729
      ],
      "size": [
        234.931640625,
        206
      ],
      "flags": {},
      "order": 6,
      "mode": 0,
      "inputs": [
        {
          "name": "video_info",
          "type": "VHS_VIDEOINFO",
          "link": 20
        }
      ],
      "outputs": [
        {
          "name": "source_fps🟨",
          "type": "FLOAT",
          "links": [
            21,
            22,
            23
          ]
        },
        {
          "name": "source_frame_count🟨",
          "type": "INT",
          "links": null
        },
        {
          "name": "source_duration🟨",
          "type": "FLOAT",
          "links": null
        },
        {
          "name": "source_width🟨",
          "type": "INT",
          "links": null
        },
        {
          "name": "source_height🟨",
          "type": "INT",
          "links": null
        },
        {
          "name": "loaded_fps🟦",
          "type": "FLOAT",
          "links": null
        },
        {
          "name": "loaded_frame_count🟦",
          "type": "INT",
          "links": null
        },
        {
          "name": "loaded_duration🟦",
          "type": "FLOAT",
          "links": null
        },
        {
          "name": "loaded_width🟦",
          "type": "INT",
          "links": null
        },
        {
          "name": "loaded_height🟦",
          "type": "INT",
          "links": null
        }
      ],
      "properties": {
        "Node name for S&R": "VHS_VideoInfo"
      },
      "widgets_values": {}
    },
    {
      "id": 18,
      "type": "VHS_LoadVideo",
      "pos": [
        3043.1471860575843,
        1001.7635755450315
      ],
      "size": [
        261.6533203125,
        459.82388026932085
      ],
      "flags": {},
      "order": 1,
      "mode": 0,
      "inputs": [
        {
          "name": "meta_batch",
          "shape": 7,
          "type": "VHS_BatchManager",
          "link": null
        },
        {
          "name": "vae",
          "shape": 7,
          "type": "VAE",
          "link": null
        }
      ],
      "outputs": [
        {
          "name": "IMAGE",
          "type": "IMAGE",
          "links": [
            18
          ]
        },
        {
          "name": "frame_count",
          "type": "INT",
          "links": null
        },
        {
          "name": "audio",
          "type": "AUDIO",
          "links": null
        },
        {
          "name": "video_info",
          "type": "VHS_VIDEOINFO",
          "links": [
            20
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "VHS_LoadVideo"
      },
      "widgets_values": {
        "video": "video.mp4",
        "force_rate": 0,
        "custom_width": 0,
        "custom_height": 0,
        "frame_load_cap": 44,
        "skip_first_frames": 0,
        "select_every_nth": 1,
        "format": "AnimateDiff",
        "videopreview": {
          "hidden": false,
          "paused": false,
          "params": {
            "filename": "video.mp4",
            "type": "input",
            "format": "video/mp4",
            "force_rate": 0,
            "custom_width": 0,
            "custom_height": 0,
            "frame_load_cap": 44,
            "skip_first_frames": 0,
            "select_every_nth": 1
          }
        }
      }
    },
    {
      "id": 11,
      "type": "SaveVideo",
      "pos": [
        4672.625436743371,
        1369.3215236855117
      ],
      "size": [
        474.35558837890585,
        445.77460030033916
      ],
      "flags": {},
      "order": 13,
      "mode": 0,
      "inputs": [
        {
          "name": "video",
          "type": "VIDEO",
          "link": 5
        }
      ],
      "outputs": [],
      "properties": {},
      "widgets_values": [
        "video/ComfyUI",
        "auto",
        "auto"
      ]
    },
    {
      "id": 7,
      "type": "Propainter_Loader",
      "pos": [
        3545.684605688686,
        944.4438779502053
      ],
      "size": [
        353.1111633300784,
        130
      ],
      "flags": {},
      "order": 2,
      "mode": 0,
      "inputs": [],
      "outputs": [
        {
          "name": "model",
          "type": "Propainter_Loader",
          "links": [
            2
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "Propainter_Loader"
      },
      "widgets_values": [
        "propainter\\ProPainter.pth",
        "propainter\\recurrent_flow_completion.pth",
        "propainter\\raft-things.pth",
        "cpu"
      ]
    },
    {
      "id": 9,
      "type": "Propainter_Sampler",
      "pos": [
        4019.9112536057933,
        1191.105716114276
      ],
      "size": [
        270,
        194
      ],
      "flags": {},
      "order": 8,
      "mode": 0,
      "inputs": [
        {
          "name": "model",
          "type": "Propainter_Loader",
          "link": 2
        },
        {
          "name": "conditioning",
          "type": "CONDITIONING",
          "link": 1
        },
        {
          "name": "fps",
          "type": "FLOAT",
          "link": 21
        }
      ],
      "outputs": [
        {
          "name": "conditioning",
          "type": "CONDITIONING",
          "links": [
            24
          ]
        },
        {
          "name": "images",
          "type": "IMAGE",
          "links": [
            15
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "Propainter_Sampler"
      },
      "widgets_values": [
        10,
        8,
        10,
        10,
        50
      ]
    },
    {
      "id": 21,
      "type": "DiffuEraser_Sampler",
      "pos": [
        4347.663465704763,
        1122.872132522008
      ],
      "size": [
        270,
        242
      ],
      "flags": {},
      "order": 9,
      "mode": 0,
      "inputs": [
        {
          "name": "model",
          "type": "DiffuEraser_Loader",
          "link": 25
        },
        {
          "name": "positive",
          "type": "CONDITIONING",
          "link": 27
        },
        {
          "name": "conditioning",
          "type": "CONDITIONING",
          "link": 24
        }
      ],
      "outputs": [
        {
          "name": "image",
          "type": "IMAGE",
          "links": [
            26
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "DiffuEraser_Sampler"
      },
      "widgets_values": [
        2,
        2145759323,
        "randomize",
        false,
        5,
        5,
        false
      ]
    },
    {
      "id": 22,
      "type": "CLIPTextEncode",
      "pos": [
        4152.07287977595,
        1005.1234869572979
      ],
      "size": [
        400,
        200
      ],
      "flags": {
        "collapsed": true
      },
      "order": 7,
      "mode": 0,
      "inputs": [
        {
          "name": "clip",
          "type": "CLIP",
          "link": 28
        }
      ],
      "outputs": [
        {
          "name": "CONDITIONING",
          "type": "CONDITIONING",
          "links": [
            27
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "CLIPTextEncode"
      },
      "widgets_values": [
        ""
      ]
    },
    {
      "id": 23,
      "type": "CLIPLoader",
      "pos": [
        3944.6667672230137,
        744.9330057919422
      ],
      "size": [
        270,
        106
      ],
      "flags": {},
      "order": 3,
      "mode": 0,
      "inputs": [],
      "outputs": [
        {
          "name": "CLIP",
          "type": "CLIP",
          "links": [
            28
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "CLIPLoader"
      },
      "widgets_values": [
        "clip_l.safetensors",
        "stable_diffusion",
        "default"
      ]
    },
    {
      "id": 16,
      "type": "CreateVideo",
      "pos": [
        3700.395868604978,
        1394.0939422767228
      ],
      "size": [
        270,
        78
      ],
      "flags": {},
      "order": 10,
      "mode": 0,
      "inputs": [
        {
          "name": "images",
          "type": "IMAGE",
          "link": 15
        },
        {
          "name": "audio",
          "shape": 7,
          "type": "AUDIO",
          "link": null
        },
        {
          "name": "fps",
          "type": "FLOAT",
          "widget": {
            "name": "fps"
          },
          "link": 22
        }
      ],
      "outputs": [
        {
          "name": "VIDEO",
          "type": "VIDEO",
          "links": [
            14
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "CreateVideo"
      },
      "widgets_values": [
        1
      ]
    },
    {
      "id": 3,
      "type": "DiffuEraser_Loader",
      "pos": [
        4256.633308684001,
        860.7258349094477
      ],
      "size": [
        395.03338256835923,
        86.033327178955
      ],
      "flags": {},
      "order": 4,
      "mode": 0,
      "inputs": [],
      "outputs": [
        {
          "name": "model",
          "type": "DiffuEraser_Loader",
          "links": [
            25
          ]
        }
      ],
      "properties": {
        "Node name for S&R": "DiffuEraser_Loader"
      },
      "widgets_values": [
        "sd-vae-ft-mse.safetensors",
        "pcm_sd15_smallcfg_2step_converted.safetensors"
      ]
    }
  ],
  "links": [
    [
      1,
      6,
      0,
      9,
      1,
      "CONDITIONING"
    ],
    [
      2,
      7,
      0,
      9,
      0,
      "Propainter_Loader"
    ],
    [
      5,
      10,
      0,
      11,
      0,
      "VIDEO"
    ],
    [
      14,
      16,
      0,
      17,
      0,
      "VIDEO"
    ],
    [
      15,
      9,
      1,
      16,
      0,
      "IMAGE"
    ],
    [
      18,
      18,
      0,
      6,
      0,
      "IMAGE"
    ],
    [
      19,
      19,
      0,
      6,
      1,
      "IMAGE"
    ],
    [
      20,
      18,
      3,
      20,
      0,
      "VHS_VIDEOINFO"
    ],
    [
      21,
      20,
      0,
      9,
      2,
      "FLOAT"
    ],
    [
      22,
      20,
      0,
      16,
      2,
      "FLOAT"
    ],
    [
      23,
      20,
      0,
      10,
      2,
      "FLOAT"
    ],
    [
      24,
      9,
      0,
      21,
      2,
      "CONDITIONING"
    ],
    [
      25,
      3,
      0,
      21,
      0,
      "DiffuEraser_Loader"
    ],
    [
      26,
      21,
      0,
      10,
      0,
      "IMAGE"
    ],
    [
      27,
      22,
      0,
      21,
      1,
      "CONDITIONING"
    ],
    [
      28,
      23,
      0,
      22,
      0,
      "CLIP"
    ]
  ],
  "groups": [],
  "config": {},
  "extra": {
    "ds": {
      "scale": 0.7203953397191334,
      "offset": [
        -2763.3971806441527,
        -566.8610555303999
      ]
    },
    "workflowRendererVersion": "LG",
    "frontendVersion": "1.41.21",
    "VHS_latentpreview": false,
    "VHS_latentpreviewrate": 0,
    "VHS_MetadataImage": true,
    "VHS_KeepIntermediate": true
  },
  "version": 0.4
}

================================================
FILE: libs/__init__.py
================================================



================================================
FILE: libs/brushnet_CA.py
================================================
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
    get_mid_block,
    get_up_block,
    MidBlock2D
)

# from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from .unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class BrushNetOutput(BaseOutput):
    """
    The output of [`BrushNetModel`].

    Args:
        up_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's upsampling activations.
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
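
    Example:
        A minimal sketch in plain Python, with no model required, of the `(channels, height)` of each residual. It
        assumes the default `block_out_channels` and `layers_per_block`, a square 64x64 latent, and that each zero
        convolution built in [`BrushNetModel`] receives exactly one activation, in construction order.
        `brushnet_residual_shapes` is a hypothetical helper and is not part of this module.

        ```python
        def brushnet_residual_shapes(block_out_channels=(320, 640, 1280, 1280), layers_per_block=2, latent_hw=64):
            # Hypothetical helper mirroring the zero-convolution loops in `BrushNetModel.__init__`.
            n = len(block_out_channels)
            hw = latent_hw
            # Down path: one zero conv after `conv_in_condition`, one per resnet layer,
            # and one per downsampler (every block except the last halves the resolution).
            down = [(block_out_channels[0], hw)]
            for i, ch in enumerate(block_out_channels):
                down += [(ch, hw)] * layers_per_block
                if i < n - 1:
                    hw //= 2
                    down.append((ch, hw))
            # Mid block: last channel count at the lowest resolution.
            mid = (block_out_channels[-1], hw)
            # Up path: `layers_per_block + 1` resnet layers per block, plus one per
            # upsampler (every block except the last doubles the resolution).
            up = []
            for i, ch in enumerate(reversed(block_out_channels)):
                up += [(ch, hw)] * (layers_per_block + 1)
                if i < n - 1:
                    hw *= 2
                    up.append((ch, hw))
            return down, mid, up


        down, mid, up = brushnet_residual_shapes()
        print(len(down), mid, len(up))  # 12 (1280, 8) 15
        ```

        With the defaults, this gives 12 down residuals, one mid residual of `(1280, 8)`, and 15 up residuals. These
        counts match the lengths of `brushnet_down_blocks` and `brushnet_up_blocks`.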
    """

    up_block_res_samples: Tuple[torch.Tensor]
    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class BrushNetModel(ModelMixin, ConfigMixin):
    """
    A BrushNet model.

    Args:
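        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        conditioning_channels (`int`, defaults to 5):
            The number of channels in the conditioning input. It is concatenated with the input sample before
            `conv_in_condition` (by default, 4 masked-image latent channels plus 1 mask channel).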
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
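        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
            Whether the transformer blocks use cross-attention in place of self-attention. A single bool is applied
            to every block.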
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers are skipped
            in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
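        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads. If `num_attention_heads` is not set, this value is also used as the
            number of attention heads (a legacy naming quirk; see the comment in `__init__`).
        num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to `None`):
            The number of attention heads. Falls back to `attention_head_dim` when `None`.
        use_linear_projection (`bool`, defaults to `False`):
            Whether the transformer blocks use a linear layer instead of a 1x1 convolution for their input and
            output projections.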
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
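        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
            `"text"`, `"text_image"`, or `"text_time"`. `"text"` will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            The dimension of the sinusoidal time projection used when `addition_embed_type="text_time"`.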
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
            Whether to upcast the attention computation to float32.
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of the conditioning image. It is converted to `rgb` if it is `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block of a conditioning-embedding layer. The value is stored in
            the config only; `__init__` does not build such a layer.
        global_pool_conditions (`bool`, defaults to `False`):
            TODO(Patrick) - unused parameter.
        addition_embed_type_num_heads (`int`, defaults to 64):
            The number of heads to use for the `TextTimeEmbedding` layer.
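
    Example:
        A minimal sketch of how `__init__` normalizes the per-block arguments. `num_attention_heads` falls back to
        `attention_head_dim`, a per-block tuple of the wrong length is rejected, and scalar values are broadcast to
        one entry per down block. `normalize_block_config` is a hypothetical helper that mirrors that logic; it is
        not part of this module.

        ```python
        def normalize_block_config(down_block_types, attention_head_dim=8, num_attention_heads=None,
                                   only_cross_attention=False, transformer_layers_per_block=1):
            # Hypothetical helper mirroring the argument handling in `BrushNetModel.__init__`.
            n = len(down_block_types)
            # Legacy naming: without `num_attention_heads`, `attention_head_dim` doubles as the head count.
            num_attention_heads = num_attention_heads or attention_head_dim
            if not isinstance(num_attention_heads, int) and len(num_attention_heads) != n:
                raise ValueError("`num_attention_heads` must have one entry per down block.")
            # Broadcast scalars to one value per block.
            if isinstance(only_cross_attention, bool):
                only_cross_attention = [only_cross_attention] * n
            if isinstance(attention_head_dim, int):
                attention_head_dim = (attention_head_dim,) * n
            if isinstance(num_attention_heads, int):
                num_attention_heads = (num_attention_heads,) * n
            if isinstance(transformer_layers_per_block, int):
                transformer_layers_per_block = [transformer_layers_per_block] * n
            return num_attention_heads, attention_head_dim, only_cross_attention, transformer_layers_per_block


        heads, head_dims, cross_only, layers = normalize_block_config(
            ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
        )
        print(heads, layers)  # (8, 8, 8, 8) [1, 1, 1, 1]
        ```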
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 5,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str, ...] = (
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        brushnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in_condition = nn.Conv2d(
            in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )

        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'.")

        self.down_blocks = nn.ModuleList([])
        self.brushnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        brushnet_block = zero_module(brushnet_block)
        self.brushnet_down_blocks.append(brushnet_block)  # zero convolution

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_down_blocks.append(brushnet_block)  # zero convolution

            if not is_final_block:
                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_down_blocks.append(brushnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        brushnet_block = zero_module(brushnet_block)
        self.brushnet_mid_block = brushnet_block

        self.mid_block = get_mid_block(
            mid_block_type,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        
        self.up_blocks = nn.ModuleList([])
        self.brushnet_up_blocks = nn.ModuleList([])

        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block+1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resolution_idx=i,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=reversed_num_attention_heads[i],
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
            
            for _ in range(layers_per_block+1):
                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_up_blocks.append(brushnet_block)

            if not is_final_block:
                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_up_blocks.append(brushnet_block)


    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        brushnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
        conditioning_channels: int = 5,
    ):
        r"""
        Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].

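        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
                where applicable; the down, mid, and up block types are instead fixed to the defaults of `__init__`.
            brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
                Passed through to the [`BrushNetModel`] config.
            conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
                Passed through to the [`BrushNetModel`] config.
            load_weights_from_unet (`bool`, defaults to `True`):
                Whether to initialize `conv_in_condition`, the time and class embeddings, and the down, mid, and up
                blocks from the UNet's weights (block weights are loaded with `strict=False`).
            conditioning_channels (`int`, defaults to 5):
                The number of conditioning channels added to the UNet's input channels in `conv_in_condition`.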
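
        Example:
            A minimal NumPy sketch of the `conv_in` weight inflation performed when `load_weights_from_unet=True`.
            It assumes a 4-channel UNet input and the default 5 conditioning channels (9 input channels in total).
            Reading channels 4-7 as the masked-image latent and channel 8 as the mask is an assumption based on the
            BrushNet design, not something this method checks. `inflate_conv_in` is a hypothetical helper.

            ```python
            import numpy as np


            def inflate_conv_in(unet_weight, conditioning_channels=5):
                # Hypothetical helper mirroring the slicing in `from_unet`.
                out_ch, in_ch, kh, kw = unet_weight.shape
                assert in_ch == 4, "from_unet hardcodes the [:4] and [4:8] slices"
                new_weight = np.zeros((out_ch, in_ch + conditioning_channels, kh, kw), dtype=unet_weight.dtype)
                new_weight[:, :4] = unet_weight   # copy for the UNet's own latent channels
                new_weight[:, 4:8] = unet_weight  # second copy for the next four input channels
                return new_weight                 # any remaining channel (index 8 here) stays zero


            w = np.random.rand(320, 4, 3, 3).astype(np.float32)
            print(inflate_conv_in(w).shape)  # (320, 9, 3, 3)
            ```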
        """
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
        addition_time_embed_dim = (
            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
        )

        brushnet = cls(
            in_channels=unet.config.in_channels,
            conditioning_channels=conditioning_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
            down_block_types=[
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ],
            # mid_block_type='MidBlock2D',
            mid_block_type="UNetMidBlock2DCrossAttn",
            # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
            up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
        )

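        if load_weights_from_unet:
            # Widen the UNet's `conv_in` kernel to the BrushNet input: copy it into input channels 0-3 and 4-7,
            # and leave any remaining conditioning channels zero-initialized.
            conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
            conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
            conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
            brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
            # The bias Parameter is reused from the UNet as-is, so the two modules share it.
            brushnet.conv_in_condition.bias=unet.conv_in.bias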

            brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if brushnet.class_embedding:
                brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
            brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
            brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)

        return brushnet

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
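
        Example:
            A minimal pure-Python sketch of the recursive walk this property performs: every module exposing a
            processor is recorded under its dotted path plus `.processor`. `Node` and `collect_processors` are
            hypothetical stand-ins for `torch.nn.Module` and the nested `fn_recursive_add_processors`. A plain
            attribute replaces the `get_processor` method here.

            ```python
            class Node:
                # Hypothetical stand-in for an nn.Module: an optional processor plus named children.
                def __init__(self, processor=None, **children):
                    self.processor = processor
                    self.children = children


            def collect_processors(name, node, processors):
                # Record this node's processor under "<dotted path>.processor", then recurse into children.
                if node.processor is not None:
                    processors[f"{name}.processor"] = node.processor
                for sub_name, child in node.children.items():
                    collect_processors(f"{name}.{sub_name}", child, processors)
                return processors


            model = Node(
                down_blocks=Node(**{"0": Node(attn1=Node("AttnProcessor"))}),
                mid_block=Node(attn=Node("AttnProcessor2_0")),
            )
            processors = {}
            for name, child in model.children.items():
                collect_processors(name, child, processors)
            print(processors)
            # {'down_blocks.0.attn1.processor': 'AttnProcessor', 'mid_block.attn.processor': 'AttnProcessor2_0'}
            ```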
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices to compute attention in
        several steps. This saves some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")

        # Recursively walk through all the children.
        # Any child that exposes the set_attention_slice method
        # receives the next slice size.
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        brushnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
        """
        The [`BrushNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            brushnet_cond (`torch.FloatTensor`):
                The conditioning tensor of shape `(batch_size, conditioning_channels, height, width)`. It is
                concatenated with `sample` along the channel dimension before `conv_in_condition`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for BrushNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
                embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            added_cond_kwargs (`dict`):
                Additional conditions for the Stable Diffusion XL UNet.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
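                In this mode, the BrushNet encoder tries to recognize the content of the input even if you remove
                all prompts, and (unless `global_pool_conditions` is set) the residuals are additionally scaled by
                log-spaced factors from 0.1 to 1.0. A `guidance_scale` between 3.0 and 5.0 is recommended.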
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """
        # check channel order
        channel_order = self.config.brushnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            brushnet_cond = torch.flip(brushnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process
        brushnet_cond = torch.concat([sample, brushnet_cond], dim=1)
        sample = self.conv_in_condition(brushnet_cond)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. BrushNet down blocks
        brushnet_down_block_res_samples = ()
        for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
            down_block_res_sample = brushnet_down_block(down_block_res_sample)
            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)

        # 5. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)

        # 6. BrushNet mid blocks
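        brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)

        # 7. up
        # Bind upsample_size up front so it is defined even when the first up block is also the final one.
        upsample_size = None
        up_block_res_samples = ()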
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample, up_res_samples = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    return_res_samples=True
                )
            else:
                sample, up_res_samples = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    return_res_samples=True
                )

            up_block_res_samples += up_res_samples

        # 8. BrushNet up blocks
        brushnet_up_block_res_samples = ()
        for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
            up_block_res_sample = brushnet_up_block(up_block_res_sample)
            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)

        # 9. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale

            brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
            brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
        else:
            brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
            brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]


        if self.config.global_pool_conditions:
            brushnet_down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
            ]
            brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
            brushnet_up_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
            ]

        if not return_dict:
            return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)

        return BrushNetOutput(
            down_block_res_samples=brushnet_down_block_res_samples, 
            mid_block_res_sample=brushnet_mid_block_res_sample,
            up_block_res_samples=brushnet_up_block_res_samples
        )
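
# Sketch (hypothetical helper, not part of the original BrushNet code): in guess
# mode (with `global_pool_conditions` off), `BrushNetModel.forward` above scales
# its residuals with `torch.logspace(-1, 0, ...)`, i.e. from 0.1 for the first
# down-block residual up to 1.0 for the last up-block residual, all multiplied by
# `conditioning_scale`. This helper reproduces only that scale vector and its
# split into (down, mid, up), so the ordering can be checked without the model.
def _sketch_guess_mode_scales(n_down: int, n_up: int, conditioning_scale: float = 1.0):
    import torch

    # One factor per residual: n_down down-block outputs, 1 mid-block output, n_up up-block outputs.
    scales = torch.logspace(-1, 0, n_down + 1 + n_up) * conditioning_scale
    down_scales = scales[:n_down]
    mid_scale = scales[n_down]
    up_scales = scales[n_down + 1:]
    return down_scales, mid_scale, up_scales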


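# Sketch (hypothetical helper, not part of the original BrushNet code): how
# `set_attention_slice` above turns its `slice_size` argument into one slice size
# per sliceable attention layer. "auto" halves each layer's head dimension, "max"
# uses a slice of 1 everywhere, a list is used as given, and a single int is
# broadcast to every layer. The real method then hands the sizes out by popping
# from a reversed copy, so the first layer visited receives the first entry.
def _sketch_resolve_slice_sizes(slice_size, sliceable_head_dims):
    n = len(sliceable_head_dims)
    if slice_size == "auto":
        sizes = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        sizes = n * [1]
    elif isinstance(slice_size, list):
        sizes = slice_size
    else:
        sizes = n * [slice_size]

    if len(sizes) != n:
        raise ValueError(f"expected {n} slice sizes, got {len(sizes)}")
    # Each slice must fit inside the head dimension of its layer.
    for size, dim in zip(sizes, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
    return sizes
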
def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
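
# Sketch (hypothetical helper, not part of the original BrushNet code):
# `BrushNetModel.from_unet(..., load_weights_from_unet=True)` above builds
# `conv_in_condition` by zero-initializing its weight and copying the UNet's
# `conv_in` kernel into input channels 0-3 and 4-7; any further input channels
# keep zero weights. This standalone version uses plain `nn.Conv2d` layers; the
# 4-channel latent and the 9-channel total are assumptions for an SD1.5-style
# setup. Unlike the original, which assigns `unet.conv_in.bias` directly (so both
# modules share one Parameter), the sketch clones the bias.
def _sketch_init_conv_in_condition(unet_conv_in, in_channels_total: int = 9):
    import torch
    from torch import nn

    out_ch, in_ch, kh, kw = unet_conv_in.weight.shape
    conv = nn.Conv2d(in_channels_total, out_ch, kernel_size=(kh, kw), padding=unet_conv_in.padding)
    weight = torch.zeros_like(conv.weight)
    weight[:, :in_ch] = unet_conv_in.weight.detach()             # noisy-latent channels
    weight[:, in_ch:2 * in_ch] = unet_conv_in.weight.detach()    # second copy for the next latent block
    # Remaining channels (e.g. an extra mask channel) stay zero-initialized.
    conv.weight = nn.Parameter(weight)
    conv.bias = nn.Parameter(unet_conv_in.bias.detach().clone())
    return conv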


================================================
FILE: libs/diffueraser.py
================================================
import gc
import copy
import cv2
import os
import numpy as np
import torch
import torchvision
import re
import random
from einops import repeat
from PIL import Image, ImageFilter
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UniPCMultistepScheduler,
    LCMScheduler,
    StableDiffusionPipeline
)
from diffusers.schedulers import TCDScheduler
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoTokenizer, PretrainedConfig
from safetensors.torch import load_file
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .brushnet_CA import BrushNetModel
from .unet_2d_condition import UNet2DConditionModel
from .pipeline_diffueraser import StableDiffusionDiffuEraserPipeline


def extract_step_number(ckpt_name):
    # Use a regular expression to find the number in front of "-Step" (case-sensitive)
    match = re.search(r'(\d+)-Step', ckpt_name)
    if match:
        return int(match.group(1))
    else:
        return 2
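
# Sketch (hypothetical variant, not used by the node): `extract_step_number` only
# matches the capitalized "<N>-Step" form of the display names in `checkpoints`
# ("Normal CFG 8-Step" -> 8). A PCM file name built from the templates below,
# such as "pcm_<model>_smallcfg_16step_converted.safetensors", has no "-Step", so
# it falls back to the default of 2. Assuming file names should be accepted too,
# a case-insensitive pattern with an optional hyphen covers both forms.
def _sketch_extract_step_number(ckpt_name: str, default: int = 2) -> int:
    import re

    match = re.search(r"(\d+)-?step", ckpt_name, flags=re.IGNORECASE)
    return int(match.group(1)) if match else default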

checkpoints = {
    "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
    "LCM-Like LoRA": [
        "pcm_{}_lcmlike_lora_converted.safetensors",
        4,
        0.0,
    ],
}
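
# Sketch (hypothetical helper, not used by the node): the scheduler, step count
# and guidance scale that `DiffuEraser.load_model` derives from its `ckpt`
# string, rewritten as a pure function that returns scheduler names instead of
# scheduler objects. It mirrors the substring tests used there: "lcmlike" selects
# LCM with 4 steps, anything else selects TCD with a step count parsed from
# "<N>-Step" (default 2), and "normal" selects guidance 7.5, otherwise 0. Note
# that the display name "LCM-Like LoRA" lower-cases to "lcm-like lora", which
# does not contain "lcmlike", whereas the file-name template above does.
def _sketch_select_sampling(ckpt: str, default_steps: int = 2):
    import re

    name = ckpt.lower()
    if "lcmlike" in name:
        scheduler, steps = "LCMScheduler", 4
    else:
        match = re.search(r"(\d+)-Step", ckpt)
        steps = int(match.group(1)) if match else default_steps
        scheduler = "TCDScheduler"
    guidance_scale = 7.5 if "normal" in name else 0
    return scheduler, steps, guidance_scale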

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        try:
            # Only available in older diffusers versions.
            from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
            return RobertaSeriesModelWithTransformation
        except ImportError:
            print("Error: Could not import RobertaSeriesModelWithTransformation.")
            raise ValueError(f"{model_class} is not supported.")

    else:
        raise ValueError(f"{model_class} is not supported.")

def resize_frames(frames, size=None):
    if size is not None:
        out_size = size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        if out_size != process_size:
            frames = [f.resize(process_size) for f in frames]

    return frames
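
# Sketch (hypothetical helper): `resize_frames` and `read_video` round each side
# down to the nearest multiple of 8, assuming the SD1.5 VAE downsampling factor
# of 8 (`vae_scale_factor` in `load_model`), so that every side maps to a whole
# number of latent pixels. Only the size arithmetic is shown; the real function
# also resizes the PIL frames.
def _sketch_floor_to_multiple(size, multiple: int = 8):
    width, height = size
    return (width - width % multiple, height - height % multiple)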

def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter, frames):
    cap = cv2.VideoCapture(validation_mask)
    if not cap.isOpened():
        raise IOError(f"Could not open mask video: {validation_mask}")
    mask_fps = cap.get(cv2.CAP_PROP_FPS)
    if mask_fps != fps:
        cap.release()
        raise ValueError("The frame rate of all input videos needs to be consistent.")

    masks = []
    masked_images = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:  
            break
        if(idx >= n_total_frames):
            break
        mask = Image.fromarray(frame[...,::-1]).convert('L')
        if mask.size != img_size:
            mask = mask.resize(img_size, Image.NEAREST)
        mask = np.asarray(mask)
        m = np.array(mask > 0).astype(np.uint8)
        m = cv2.erode(m,
                    cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                    iterations=1)
        m = cv2.dilate(m,
                    cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                    iterations=mask_dilation_iter)

        mask = Image.fromarray(m * 255)
        masks.append(mask)

        masked_image = np.array(frames[idx])*(1-(np.array(mask)[:,:,np.newaxis].astype(np.float32)/255))
        masked_image = Image.fromarray(masked_image.astype(np.uint8))
        masked_images.append(masked_image)

        idx += 1
    cap.release()

    return masks, masked_images
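
# Sketch (hypothetical helper): the per-frame mask processing in `read_mask`
# (and in the in-memory branch of `DiffuEraser.forward`) binarizes the mask,
# erodes it once with a 3x3 rectangle to drop isolated specks, dilates it
# `mask_dilation_iter` times, and blanks the masked pixels with
# frame * (1 - mask / 255). This NumPy-only version replaces `cv2.erode` /
# `cv2.dilate` with 3x3 min/max filters and edge-replicated borders, an
# approximation of OpenCV's border handling.
def _sketch_mask_and_compose(frame, mask, dilation_iter: int = 4):
    import numpy as np

    def filter3x3(a, reduce):
        # 3x3 min (erosion) or max (dilation) filter over an edge-padded array.
        h, w = a.shape
        padded = np.pad(a, 1, mode="edge")
        windows = [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
        return reduce(np.stack(windows), axis=0)

    m = (np.asarray(mask) > 0).astype(np.uint8)   # any nonzero pixel marks a region to remove
    m = filter3x3(m, np.min)                       # one erosion pass drops 1-pixel specks
    for _ in range(dilation_iter):
        m = filter3x3(m, np.max)                   # each dilation pass grows the region by 1 pixel
    mask_255 = (m * 255).astype(np.uint8)
    # Blank out masked pixels, broadcasting the 2-D mask over the RGB channels.
    masked = np.asarray(frame) * (1 - mask_255[:, :, np.newaxis].astype(np.float32) / 255)
    return mask_255, masked.astype(np.uint8)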

def read_priori(priori, fps, n_total_frames, img_size):
    cap = cv2.VideoCapture(priori)
    if not cap.isOpened():
        raise IOError(f"Could not open priori video: {priori}")
    priori_fps = cap.get(cv2.CAP_PROP_FPS)

    if abs(priori_fps - fps) > 1e-8:
        print(f"priori fps: {priori_fps}, fps: {fps}")
        cap.release()
        raise ValueError("The frame rate of all input videos needs to be consistent.")

    prioris=[]
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret: 
            break
        if(idx >= n_total_frames):
            break
        img = Image.fromarray(frame[...,::-1])
        if img.size != img_size:
            img = img.resize(img_size)
        prioris.append(img)
        idx += 1
    cap.release()

    os.remove(priori) # remove priori 

    return prioris

def read_video(validation_image, video_length, nframes, max_img_size):
    vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length) # RGB
    fps = info['video_fps']
    n_total_frames = int(video_length * fps)
    n_clip = int(np.ceil(n_total_frames/nframes))

    frames = list(vframes.numpy())[:n_total_frames]
    frames = [Image.fromarray(f) for f in frames]
    max_size = max(frames[0].size)
    if(max_size<256):
        raise ValueError("The resolution of the uploaded video must be larger than 256x256.")
    if(max_size>4096):
        raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.")
    if max_size>max_img_size:
        ratio = max_size/max_img_size
        ratio_size = (int(frames[0].size[0]/ratio),int(frames[0].size[1]/ratio))
        img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8)
        resize_flag=True
    elif (frames[0].size[0]%8==0) and (frames[0].size[1]%8==0):
        img_size = frames[0].size
        resize_flag=False
    else:
        ratio_size = frames[0].size
        img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8)
        resize_flag=True
    if resize_flag:
        frames = resize_frames(frames, img_size)
        img_size = frames[0].size

    return frames, fps, img_size, n_clip, n_total_frames
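
# Sketch (hypothetical helper): the target-size rule in `read_video`, as a pure
# function on (width, height). Videos whose longer side is outside [256, 4096]
# are rejected; if the longer side exceeds `max_img_size`, both sides are scaled
# by the same ratio; finally both sides are rounded down to a multiple of 8.
def _sketch_target_size(width: int, height: int, max_img_size: int = 1280):
    max_size = max(width, height)
    if max_size < 256 or max_size > 4096:
        raise ValueError("The longer side of the video must be between 256 and 4096 pixels.")
    if max_size > max_img_size:
        ratio = max_size / max_img_size
        width, height = int(width / ratio), int(height / ratio)
    # Flooring is a no-op when both sides are already multiples of 8.
    return (width - width % 8, height - height % 8)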


class DiffuEraser:
    def __init__(self, device):
        self.device = device

    def load_model(self, repo, diffueraser_path, ckpt_path, original_config_file, ckpt="Normal CFG 4-Step"):
        self.noise_scheduler = DDPMScheduler.from_pretrained(repo, 
                subfolder="scheduler",
                prediction_type="v_prediction",
                timestep_spacing="trailing",
                rescale_betas_zero_snr=True
            )
        self.tokenizer = AutoTokenizer.from_pretrained(
                    repo,
                    subfolder="tokenizer",
                    use_fast=False,
                )
        vae_config=AutoencoderKL.load_config(os.path.join(repo,"vae/config.json"))
        self.vae=AutoencoderKL.from_config(vae_config)
        self.vae.load_state_dict(load_file(ckpt_path) if ckpt_path.endswith(".safetensors") else torch.load(ckpt_path,weights_only=False),strict=False)
        #self.vae=AutoencoderKL.from_single_file(ckpt_path,config=os.path.join(repo,"vae") )
        # try: 
        #     pipe = StableDiffusionPipeline.from_single_file(
        #     ckpt_path,config=repo, original_config=original_config_file)
        # except:
        #     pipe = StableDiffusionPipeline.from_single_file(
        #     ckpt_path, config=repo,original_config_file=original_config_file)

        # self.text_encoder = pipe.text_encoder
        #self.vae = pipe.vae
        #del pipe 
        gc.collect()
        torch.cuda.empty_cache()

        self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet")
        self.unet_main = UNetMotionModel.from_pretrained(
            diffueraser_path, subfolder="unet_main",
        )
        ## set pipeline
        self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained(
            repo,
            vae=self.vae,
            text_encoder=None,
            tokenizer=self.tokenizer,
            unet=self.unet_main,
            brushnet=self.brushnet,
            safety_checker=None,#no need 
        ).to(self.device, torch.float16)
        # self.vae=None
        # self.text_encoder=None
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.pipeline.set_progress_bar_config(disable=True)

        self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
        self.vae_scale_factor = 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        
        try:
            self.pipeline.load_lora_weights(pretrained_model_name_or_path_or_dict=ckpt)
            print("Loaded lora from", ckpt)
        except Exception as e:
            print(f"Failed to apply LoRA: {e}")

        if "lcmlike" in ckpt.lower():
            self.pipeline.scheduler = LCMScheduler()
            self.num_inference_steps= 4
        else:
            self.pipeline.scheduler = TCDScheduler(
                    num_train_timesteps=1000,
                    beta_start=0.00085,
                    beta_end=0.012,
                    beta_schedule="scaled_linear",
                    timestep_spacing="trailing",
                )
            self.num_inference_steps=extract_step_number(ckpt)


        #self.num_inference_steps = checkpoints[ckpt][1]
        
        if "normal" in ckpt.lower():
            self.guidance_scale = 7.5
        else:
            self.guidance_scale = 0
        #self.guidance_scale = 0


    def to(self, device):
        self.device=device
        self.pipeline.to(device)

    def forward(self, validation_image, validation_mask, prioris, output_path,positive,load_videobypath=False,
                max_img_size = 1280, video_length=2, mask_dilation_iter=4,
                nframes=22, seed=None, revision = None, guidance_scale=None, blended=True,num_inference_steps=None,fps=24,img_size=(512, 512),if_save_video=False):
        validation_prompt = ""
        guidance_scale_final = self.guidance_scale if guidance_scale is None else guidance_scale
        num_inference_steps_final = self.num_inference_steps if num_inference_steps is None else num_inference_steps

        if max_img_size < 256 or max_img_size > 1920:
            raise ValueError("max_img_size must be between 256 and 1920 (inclusive).")

        ################ read input video ################ 
        if load_videobypath:
            frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
        else:
            frames=validation_image
            n_total_frames=len(validation_image)
            n_clip = int(np.ceil(n_total_frames/nframes))
        video_len = len(frames)
        #frames[0].save("input0.png")
        ################     read mask    ################ 
        if load_videobypath:
            validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames)
        else:
            validation_masks_list=[i.convert('L') for i in validation_mask.copy()]
            validation_images_input=[]
            validation_masks_input=[]
            for idx ,mask in enumerate(validation_masks_list):
                #mask = Image.fromarray(i[...,::-1]).convert('L')
                if mask.size != img_size:
                    mask = mask.resize(img_size, Image.NEAREST)
                mask = np.asarray(mask)
                m = np.array(mask > 0).astype(np.uint8)
                m = cv2.erode(m,
                            cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                            iterations=1)
                m = cv2.dilate(m,
                            cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
                            iterations=mask_dilation_iter)

                mask = Image.fromarray(m * 255)
                validation_masks_input.append(mask)
                masked_image = np.array(frames[idx])*(1-(np.array(mask)[:,:,np.newaxis].astype(np.float32)/255))
                masked_image = Image.fromarray(masked_image.astype(np.uint8))
                validation_images_input.append(masked_image)
       
        ################    read priori   ################  
        #validation_images_input[0].save("input1.png")
        #prioris = read_priori(priori, fps, n_total_frames, img_size)
        if prioris[0].size != img_size:
            prioris = [img.resize(img_size) for img in prioris]
        ## recheck
        n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris))
        if n_total_frames < 22:
            raise ValueError("The effective video length is too short. Make sure the video, mask, and priori each have at least 22 frames.")
        validation_masks_input = validation_masks_input[:n_total_frames]
        validation_images_input = validation_images_input[:n_total_frames]
        frames = frames[:n_total_frames]
        prioris = prioris[:n_total_frames]

        prioris = resize_frames(prioris)
        validation_masks_input = resize_frames(validation_masks_input)
        validation_images_input = resize_frames(validation_images_input)
        resized_frames = resize_frames(frames)
        #resized_frames[0].save("input2.png")

        ##############################################
        # DiffuEraser inference
        ##############################################
        print("DiffuEraser inference...")
        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        ## random noise: a single block of `nframes` noise frames, reused for every clip
        real_video_length = len(validation_images_input)
        tar_width, tar_height = validation_images_input[0].size 
        shape = (
            nframes,
            4,
            tar_height//8,
            tar_width//8
        )

        if self.unet_main is not None:
            prompt_embeds_dtype = self.unet_main.dtype
        else:
            prompt_embeds_dtype = torch.float16
        noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator) 
        noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length,...]
        
        ################  prepare priori  ################
        images_preprocessed = []
        for image in prioris:
            image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32)
            image = image.to(device=torch.device(self.device), dtype=torch.float16)
            images_preprocessed.append(image)
        pixel_values = torch.cat(images_preprocessed)

        with torch.no_grad():
            pixel_values = pixel_values.to(dtype=torch.float16)
            latents = []
            num = 4  # VAE-encode the priori frames in chunks of 4 to limit peak memory
            for i in range(0, pixel_values.shape[0], num):
                latents.append(self.pipeline.vae.encode(pixel_values[i : i + num]).latent_dist.sample())      
            latents = torch.cat(latents, dim=0)  
        latents = latents * self.pipeline.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
        self.pipeline.vae.to("cpu")
        torch.cuda.empty_cache()  
        timesteps = torch.tensor([0], device=self.device)
        timesteps = timesteps.long()

        validation_masks_input_ori = copy.deepcopy(validation_masks_input)
        resized_frames_ori = copy.deepcopy(resized_frames)

        ################  Pre-inference  ################
        if n_total_frames > nframes*2: ## do pre-inference only when number of input frames is larger than nframes*2
            ## sample
            step = n_total_frames / nframes
            sample_index = [int(i * step) for i in range(nframes)]
            sample_index = sample_index[:22]
            validation_masks_input_pre = [validation_masks_input[i] for i in sample_index]
            validation_images_input_pre = [validation_images_input[i] for i in sample_index]
            latents_pre = torch.stack([latents[i] for i in sample_index])

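            ## add priori: noise the priori latents at timestep 0. For DDPM-style schedulers,
            ## add_noise computes sqrt(alpha_bar_0) * latents + sqrt(1 - alpha_bar_0) * noise,
            ## and alpha_bar_0 is close to 1, so the result stays close to the priori latents.
            noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps)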
            latents_pre = noisy_latents_pre

            with torch.no_grad():
                latents_pre_out = self.pipeline(
                    num_frames=nframes, 
                    prompt=None, 
                    images=validation_images_input_pre, 
                    masks=validation_masks_input_pre, 
                    prompt_embeds=positive[0][0], 
                    num_inference_steps=num_inference_steps_final, 
                    generator=generator,
                    guidance_scale=guidance_scale_final,
                    latents=latents_pre,
                ).latents
            torch.cuda.empty_cache()  

            def decode_latents(latents, weight_dtype):
                latents = 1 / self.pipeline.vae.config.scaling_factor * latents
                video = []
                for t in range(latents.shape[0]):
                    video.append(self.pipeline.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
                video = torch.concat(video, dim=0)
                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                video = video.float()
                return video
            with torch.no_grad():
                video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
                images_pre_out  = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
            torch.cuda.empty_cache()  

            ## replace input frames with updated frames
            black_image = Image.new('L', validation_masks_input[0].size, color=0)
            for i,index in enumerate(sample_index):
                latents[index] = latents_pre_out[i]
                validation_masks_input[index] = black_image
                validation_images_input[index] = images_pre_out[i]
                resized_frames[index] = images_pre_out[i]
          
        else:
            latents_pre_out=None
            sample_index=None
        gc.collect()
        torch.cuda.empty_cache()

        ################  Frame-by-frame inference  ################
        ## add priori
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) 
        latents = noisy_latents
        with torch.no_grad():
            images = self.pipeline(
                num_frames=nframes, 
                prompt=None, 
                images=validation_images_input, 
                masks=validation_masks_input,
                prompt_embeds=positive[0][0], 
                num_inference_steps=num_inference_steps_final, 
                generator=generator,
                guidance_scale=guidance_scale_final,
                latents=latents,
            ).frames
        images = images[:real_video_length]

        gc.collect()
        torch.cuda.empty_cache()

        ################ Compose ################
        binary_masks = validation_masks_input_ori
        mask_blurreds = []
        if blended:
            for i in range(len(binary_masks)):
                mask_array = np.array(binary_masks[i])
                mask_blurred = morphological_edge_blur(np.array(mask_array), sigma=2.0, edge_width=3)       
                #mask_blurred = cv2.GaussianBlur(np.array(binary_masks[i]), blur_kernel, 0)/255.
                binary_mask = 1-(1-mask_array/255.) * (1-mask_blurred)
                mask_blurreds.append(Image.fromarray((binary_mask*255).astype(np.uint8)))
            binary_masks = mask_blurreds

            comp_frames = []
            for i in range(len(images)):
                mask = np.expand_dims(np.array(binary_masks[i]),2).repeat(3, axis=2).astype(np.float32)/255.
                img = (np.array(images[i]).astype(np.uint8) * mask + np.array(resized_frames_ori[i]).astype(np.uint8) * (1 - mask)).astype(np.uint8)
                comp_frames.append(Image.fromarray(img))
        else:
            comp_frames = simple_flicker_smoothing(images, alpha=0.15)

        if if_save_video:
            default_fps = fps
            prefix = ''.join(random.choice("0123456789") for _ in range(6))
            priori_path = os.path.join(output_path, f"priori_{prefix}.mp4")        
            os.makedirs(os.path.dirname(priori_path), exist_ok=True)
            
            writer = cv2.VideoWriter(priori_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                default_fps, comp_frames[0].size)
            for f in range(real_video_length):
                img = np.array(comp_frames[f]).astype(np.uint8)
                # PIL frames are RGB; OpenCV's VideoWriter expects BGR.
                writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            writer.release()
        ################################

        return comp_frames
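
# Illustrative sketch (not part of the original DiffuEraser code): the two
# indexing tricks used in the forward method above, rewritten with plain NumPy
# so they can be inspected in isolation. Both helper names are hypothetical.
#  - sample_keyframe_indices mirrors the pre-inference sampling
#    `[int(i * step) for i in range(nframes)]` with `step = n_total_frames / nframes`.
#  - tile_clip_noise mirrors
#    `repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length]`,
#    i.e. every clip of `nframes` frames reuses the same noise sequence.
import numpy as np


def sample_keyframe_indices(n_total_frames, nframes):
    """Return `nframes` roughly evenly spaced frame indices in [0, n_total_frames)."""
    step = n_total_frames / nframes
    return [int(i * step) for i in range(nframes)]


def tile_clip_noise(noise_pre, n_total_frames):
    """Repeat a (t, c, h, w) noise block along time and crop it to n_total_frames."""
    nframes = noise_pre.shape[0]
    n_clip = int(np.ceil(n_total_frames / nframes))
    # np.tile along axis 0 gives the same clip-major order as einops "(repeat t)".
    return np.tile(noise_pre, (n_clip, 1, 1, 1))[:n_total_frames]

# Usage: sample_keyframe_indices(100, 22) starts with [0, 4, 9, 13] and ends
# at 95. tile_clip_noise(noise, 50) returns 50 frames, where frame k equals
# noise[k % nframes].
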
def simple_flicker_smoothing(frames, alpha=0.1):
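    """
    Lightweight flicker smoothing that minimizes the impact on dynamic content:
    a pixel whose value changed only slightly from the previous frame is blended
    with it, while pixels that changed more are left untouched.
    """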
    if len(frames) < 2:
        return frames
    
    smoothed_frames = [frames[0]]
    
    for i in range(1, len(frames)):
        current = np.array(frames[i]).astype(np.float32)
        previous = np.array(frames[i-1]).astype(np.float32)
        
        # Smooth only pixels that change very little (likely flicker)
        diff = np.abs(current - previous)
        static_mask = (diff < 10.0).astype(np.float32)  # threshold can be tuned as needed
        
        # Apply light smoothing only in static regions
        smoothed = previous * alpha * static_mask + current * (1 - alpha * static_mask)
        
        smoothed_frames.append(Image.fromarray(np.clip(smoothed, 0, 255).astype(np.uint8)))
    
    return smoothed_frames
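
# Illustrative sketch (hypothetical helper, not in the original code): the same
# rule as simple_flicker_smoothing, applied to NumPy arrays instead of PIL
# images so its effect is easy to check. Per channel, a pixel whose value
# changed by less than `threshold` from the previous input frame becomes
# previous * alpha + current * (1 - alpha); pixels that changed more are kept.
# As in the original, `previous` is the unsmoothed input frame, not the
# previous smoothed output.
import numpy as np


def flicker_smooth_arrays(frames, alpha=0.1, threshold=10.0):
    if len(frames) < 2:
        return list(frames)
    out = [frames[0]]
    for prev, cur in zip(frames[:-1], frames[1:]):
        prev = prev.astype(np.float32)
        cur = cur.astype(np.float32)
        static = (np.abs(cur - prev) < threshold).astype(np.float32)
        smoothed = prev * alpha * static + cur * (1 - alpha * static)
        out.append(np.clip(smoothed, 0, 255).astype(np.uint8))
    return out

# Usage: with alpha=0.5, a pixel going 100 -> 104 (static) becomes 102,
# while a pixel going 100 -> 200 (moving) stays 200.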

def morphological_edge_blur(mask, sigma=3.0, edge_width=5):
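    """
    Extract the mask edge with morphological operations and blur only that edge band.
    Values > 0.5 are treated as foreground; the blurred result is clipped to [0, 1].
    """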
    if mask.dtype != np.float32:
        mask_float = mask.astype(np.float32)
    else:
        mask_float = mask.copy()
    
    # Convert to a binary image
    binary_mask = (mask_float > 0.5).astype(np.uint8)
    
    if not np.any(binary_mask):
        return mask_float
    
    # Build the edge mask
    # Erode to shrink the mask
    kernel_size = max(3, edge_width)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    eroded = cv2.erode(binary_mask, kernel, iterations=1)
    
    # Edge = original mask - eroded mask
    edge_mask = binary_mask - eroded
    
    # Keep only the edge region for Gaussian blurring
    edge_region = mask_float * edge_mask.astype(np.float32)
    
    # Blur the edge region
    ksize = int(2 * np.ceil(3 * sigma) + 1)
    ksize = max(3, min(101, ksize if ksize % 2 == 1 else ksize + 1))
    blurred_edges = cv2.GaussianBlur(edge_region, (ksize, ksize), sigmaX=sigma, sigmaY=sigma)
    
    # Compose the result: the interior keeps its original values, the edge uses the blurred values
    inner_region = mask_float * eroded.astype(np.float32)
    result = inner_region + blurred_edges
    
    return np.clip(result, 0.0, 1.0)
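
# Illustrative sketch (hypothetical helpers, not in the original code) of the
# two ideas behind the "Compose" step of the forward method, in plain NumPy:
#  1. edge_band: the edge of a binary mask is mask - erode(mask). This sketch
#     uses a 3x3 square structuring element and treats any positive value as
#     foreground. morphological_edge_blur above uses cv2.MORPH_ELLIPSE with
#     size max(3, edge_width), so its band can differ.
#  2. feathered_composite: soft masks m and b in [0, 1] are merged with
#     1 - (1 - m) * (1 - b), the same union used in forward. The inpainted
#     frame is then alpha-blended over the original frame with that mask.
import numpy as np


def erode3x3(mask):
    """Binary erosion with a 3x3 square kernel.

    Pixels outside the image count as foreground (a choice of this sketch),
    so the frame border itself does not create an edge.
    """
    padded = np.pad(mask.astype(np.uint8), 1, constant_values=1)
    h, w = mask.shape
    out = np.ones((h, w), dtype=np.uint8)
    for dy in range(3):
        for dx in range(3):
            out &= padded[dy:dy + h, dx:dx + w]
    return out


def edge_band(mask):
    binary = (mask > 0).astype(np.uint8)
    return binary - erode3x3(binary)


def feathered_composite(generated, original, hard_mask, soft_edge):
    """generated/original: HxWx3 uint8; hard_mask, soft_edge: HxW floats in [0, 1]."""
    alpha = 1.0 - (1.0 - hard_mask) * (1.0 - soft_edge)
    alpha = alpha[..., None]  # broadcast over the colour channels
    blended = generated.astype(np.float32) * alpha + original.astype(np.float32) * (1.0 - alpha)
    return blended.astype(np.uint8)

# Usage: for a 5x5 square mask, edge_band keeps the 16 border pixels and drops
# the 3x3 interior.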


================================================
FILE: libs/pipeline_diffueraser.py
================================================
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
from einops import rearrange, repeat
from dataclasses import dataclass
import copy
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
    BaseOutput
)
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UniPCMultistepScheduler,
)

from .unet_2d_condition import UNet2DConditionModel
from .brushnet_CA import BrushNetModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
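
# Illustrative sketch (hypothetical stand-in classes, not diffusers schedulers):
# retrieve_timesteps decides whether custom `timesteps` are allowed by checking
# whether the scheduler's `set_timesteps` signature has a `timesteps`
# parameter. The stand-ins below show that check in isolation.
import inspect


class _FixedStepScheduler:
    """Stand-in whose set_timesteps only accepts a step count."""

    def set_timesteps(self, num_inference_steps, device=None):
        stride = 1000 // num_inference_steps
        self.timesteps = list(range(1000 - stride, -1, -stride))


class _CustomStepScheduler:
    """Stand-in whose set_timesteps also accepts an explicit schedule."""

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        self.timesteps = list(timesteps) if timesteps is not None else []


def accepts_custom_timesteps(scheduler):
    # Same test as retrieve_timesteps; a bound method's signature excludes `self`.
    return "timesteps" in inspect.signature(scheduler.set_timesteps).parameters

# Usage: retrieve_timesteps(_CustomStepScheduler(), timesteps=[900, 500, 100])
# returns ([900, 500, 100], 3). Passing `timesteps` with _FixedStepScheduler
# raises ValueError, while retrieve_timesteps(_FixedStepScheduler(), 4) works.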

def get_frames_context_swap(total_frames=192, overlap=4, num_frames_per_clip=24):
    if total_frames<num_frames_per_clip:
        num_frames_per_clip = total_frames
    context_list = []
    context_list_swap = []
    for i in range(1, 2):  # i=1
        sample_interval = np.array(range(0,total_frames,i))
        n = len(sample_interval)
        if n>num_frames_per_clip:
            ## [0,num_frames_per_clip-1], [num_frames_per_clip, 2*num_frames_per_clip-1]....
            for k in range(0,n-num_frames_per_clip,num_frames_per_clip-overlap):
                context_list.append(sample_interval[k:k+num_frames_per_clip])
            if k+num_frames_per_clip < n and i==1:
                context_list.append(sample_interval[n-num_frames_per_clip:n])
            context_list_swap.append(sample_interval[0:num_frames_per_clip])
            for k in range(num_frames_per_clip//2, n-num_frames_per_clip, num_frames_per_clip-overlap):
                context_list_swap.append(sample_interval[k:k+num_frames_per_clip])
            if k+num_frames_per_clip < n and i==1:
                context_list_swap.append(sample_interval[n-num_frames_per_clip:n])
        if n==num_frames_per_clip:
            context_list.append(sample_interval[n-num_frames_per_clip:n])
            context_list_swap.append(sample_interval[n-num_frames_per_clip:n])
    return context_list, context_list_swap
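
# Illustrative sketch (hypothetical helper, not in the original code): the window
# start indices produced by get_frames_context_swap, computed directly.
# - Windows are `num_frames_per_clip` frames long and advance by
#   `num_frames_per_clip - overlap`.
# - A final window aligned to the last frame is always appended.
# - The "swap" list starts with the window at 0, then begins its sliding
#   windows at num_frames_per_clip // 2. Its interior windows are therefore
#   shifted by half a clip relative to the first list.
def context_window_starts(total_frames, overlap=4, num_frames_per_clip=24):
    n = total_frames
    clip = min(num_frames_per_clip, n)
    if n == clip:
        return [0], [0]
    stride = clip - overlap
    tail = n - clip
    starts = list(range(0, tail, stride)) + [tail]
    starts_swap = [0] + list(range(clip // 2, tail, stride)) + [tail]
    return starts, starts_swap

# Usage: context_window_starts(50, overlap=4, num_frames_per_clip=22)
# returns ([0, 18, 28], [0, 11, 28]); each start s covers frames s..s+21.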

@dataclass
class DiffuEraserPipelineOutput(BaseOutput):
    frames: Union[torch.Tensor, np.ndarray]
    latents: Union[torch.Tensor, np.ndarray]

class StableDiffusionDiffuEraserPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    LoraLoaderMixin,
    IPAdapterMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for video inpainting using Video Diffusion Model with BrushNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        brushnet ([`BrushNetModel`]):
            Provides additional conditioning to the `unet` during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        brushnet: BrushNetModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            brushnet=brushnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
        

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            image_embeds = []
            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )
                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
                single_negative_image_embeds = torch.stack(
                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
                )

                if do_classifier_free_guidance:
                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                    single_image_embeds = single_image_embeds.to(device)

                image_embeds.append(single_image_embeds)
        else:
            repeat_dims = [1]
            image_embeds = []
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    single_image_embeds = single_image_embeds.repeat(
                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                    )
                    single_negative_image_embeds = single_negative_image_embeds.repeat(
                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
                    )
                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                else:
                    single_image_embeds = single_image_embeds.repeat(
                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                    )
                image_embeds.append(single_image_embeds)

        return image_embeds
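
    # Illustrative sketch (comments only, not executed): when `ip_adapter_image_embeds` are
    # passed in with classifier-free guidance enabled, each tensor is expected to hold the
    # negative half first and the positive half second along dim 0 (toy shapes assumed):
    #   emb = torch.cat([torch.zeros(1, 1, 4), torch.ones(1, 1, 4)])  # (2, 1, 4)
    #   neg, pos = emb.chunk(2)                                        # each (1, 1, 4)
    #   pos.repeat(3, *([1] * len(pos.shape[1:]))).shape               # torch.Size([3, 1, 4])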

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Adapted from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
    # (decodes frame by frame and moves the VAE to the latents' device when needed).
    def decode_latents(self, latents, weight_dtype):
        # The VAE may have been offloaded; move it to the same device as the latents.
        if self.vae.device != latents.device:
            self.vae.to(latents.device)
        latents = 1 / self.vae.config.scaling_factor * latents
        # Decode one frame at a time to limit peak memory; latents are (batch * frames, c, h, w).
        video = []
        for t in range(latents.shape[0]):
            video.append(self.vae.decode(latents[t:t + 1, ...].to(weight_dtype)).sample)
        video = torch.concat(video, dim=0)

        # Always cast to float32: this adds little overhead and is compatible with bfloat16.
        video = video.float()
        return video
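
    # Illustrative sketch (comments only, not executed): decoding one frame at a time keeps
    # peak VAE memory roughly independent of the clip length. Assuming the standard SD 1.x
    # VAE (8x spatial downsampling, 4 latent channels):
    #   latents: (24, 4, 64, 64)  -> each slice latents[t:t+1] is (1, 4, 64, 64)
    #   decoded slice: (1, 3, 512, 512); concatenated video: (24, 3, 512, 512), float32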

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
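
    # Illustrative sketch (comments only, not executed): keyword detection relies on
    # `inspect.signature`, so a scheduler whose `step` lacks `eta` or `generator` simply
    # does not receive them. With a hypothetical step function:
    #   def step(model_output, timestep, sample, eta=0.0): ...
    #   "eta" in inspect.signature(step).parameters        # True
    #   "generator" in inspect.signature(step).parameters  # False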

    def check_inputs(
        self,
        prompt,
        images,
        masks,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        brushnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # Check `images` and `masks`
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.brushnet, torch._dynamo.eval_frame.OptimizedModule
        )
        is_brushnet = isinstance(self.brushnet, BrushNetModel) or (
            is_compiled and isinstance(self.brushnet._orig_mod, BrushNetModel)
        )
        if not is_brushnet:
            raise TypeError(f"`self.brushnet` must be a `BrushNetModel`, but is {type(self.brushnet)}.")

        self.check_image(images, masks, prompt, prompt_embeds)

        # Check `brushnet_conditioning_scale` (only a single BrushNet is supported)
        if not isinstance(brushnet_conditioning_scale, float):
            raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")

        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} must be smaller than control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
            if end > 1.0:
                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot pass both `ip_adapter_image` and `ip_adapter_image_embeds`."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )

    def check_image(self, images, masks, prompt, prompt_embeds):
        for image in images:
            image_is_pil = isinstance(image, PIL.Image.Image)
            image_is_tensor = isinstance(image, torch.Tensor)
            image_is_np = isinstance(image, np.ndarray)
            image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
            image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
            image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

            if (
                not image_is_pil
                and not image_is_tensor
                and not image_is_np
                and not image_is_pil_list
                and not image_is_tensor_list
                and not image_is_np_list
            ):
                raise TypeError(
                    f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
                )
        for mask in masks:
            mask_is_pil = isinstance(mask, PIL.Image.Image)
            mask_is_tensor = isinstance(mask, torch.Tensor)
            mask_is_np = isinstance(mask, np.ndarray)
            mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image)
            mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor)
            mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)

            if (
                not mask_is_pil
                and not mask_is_tensor
                and not mask_is_np
                and not mask_is_pil_list
                and not mask_is_tensor_list
                and not mask_is_np_list
            ):
                raise TypeError(
                    f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}"
                )

        if image_is_pil:
            image_batch_size = 1
        else:
            image_batch_size = len(image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

    def prepare_image(
        self,
        images,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        images_new = []
        for image in images:
            image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
            image_batch_size = image.shape[0]

            if image_batch_size == 1:
                repeat_by = batch_size
            else:
                # image batch size is the same as prompt batch size
                repeat_by = num_images_per_prompt

            image = image.repeat_interleave(repeat_by, dim=0)

            image = image.to(device=device, dtype=dtype)

            # if do_classifier_free_guidance and not guess_mode:
            #     image = torch.cat([image] * 2)
            images_new.append(image)

        return images_new

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    # (samples 5D video noise and folds the frame axis into the batch axis).
    def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
        # Noise is sampled in latent space as (batch, channels, frames, h, w).
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            # Sample 5D noise, then flatten frames into the batch axis: (b * t, c, h, w).
            noise = rearrange(randn_tensor(shape, generator=generator, device=device, dtype=dtype), "b c t h w -> (b t) c h w")
        else:
            # User-supplied latents are used as-is, so they are expected to be (b * t, c, h, w) already.
            noise = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = noise * self.scheduler.init_noise_sigma
        return latents, noise
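
    # Illustrative sketch (comments only, not executed): with batch_size=1,
    # num_channels_latents=4, num_frames=24, a 512x512 input and an assumed
    # vae_scale_factor of 8, the sampled noise has shape (1, 4, 24, 64, 64) and
    # `rearrange(..., "b c t h w -> (b t) c h w")` flattens it to (24, 4, 64, 64),
    # i.e. the 24 frames become the leading batch axis.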
    
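    @staticmethod
    def temp_blend(a, b, overlap):
        """Blend `b` into `a` in place along dim 0 (e.g. `(frames, c, h, w)` tensors).

        The first `overlap` frames are a linear crossfade from `a` to `b`; the remaining frames are
        copied from `b`. `overlap` must be at least 2, otherwise the ramp divides by zero.
        """
        factor = torch.arange(overlap).to(b.device).view(overlap, 1, 1, 1) / (overlap - 1)
        a[:overlap, ...] = (1 - factor) * a[:overlap, ...] + factor * b[:overlap, ...]
        a[overlap:, ...] = b[overlap:, ...]
        return a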
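
    # Illustrative sketch (comments only, not executed): `temp_blend` ramps linearly from
    # `a` to `b` over the first `overlap` frames and copies `b` afterwards.
    #   a = torch.zeros(4, 1, 1, 1); b = torch.ones(4, 1, 1, 1)
    #   temp_blend(a, b, overlap=3)[:, 0, 0, 0]   # tensor([0.0000, 0.5000, 1.0000, 1.0000])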

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                1D tensor of guidance scale values to embed (used to enrich the timestep embeddings).
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
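
    # Illustrative sketch (comments only, not executed): for w = torch.tensor([7.5]) and
    # embedding_dim=4, half_dim is 2 and the frequencies are exp([0, -log(10000)]) = [1, 1e-4].
    # The scaled input 7500 gives [sin(7500), sin(0.75), cos(7500), cos(0.75)], a (1, 4)
    # tensor. An odd embedding_dim (e.g. 5) yields the same 4 columns zero-padded to (1, 5).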

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
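
    # Illustrative sketch (comments only, not executed): standard classifier-free guidance
    # combines an unconditional and a text-conditioned prediction; the combination step
    # itself lies outside this excerpt, so the names below are assumptions:
    #   noise_uncond, noise_text = noise_pred.chunk(2)
    #   noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    # With guidance_scale == 1 this reduces to noise_text, which is why CFG is skipped then.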

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    # based on BrushNet: https://github.com/TencentARC/BrushNet/blob/main/src/diffusers/pipelines/brushnet/pipeline_brushnet.py
    @torch.no_grad()
    def __call__(
        self,
        num_frames: Optional[int] = 24,
        prompt: Union[str, List[str]] = None,
        images: PipelineImageInput = None,  # masked input frames
        masks: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        brushnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            images (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, or lists of these):
                The masked input frames. They are passed to the BrushNet branch as the condition that guides the
                `unet` during generation.
            masks (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, or lists of these):
                The per-frame inpainting masks, passed to the BrushNet branch together with `images`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
                if `do_classifier_free_guidance` is set to `True`.
                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                Deprecated (pass via `**kwargs`); use `callback_on_step_end` instead. A function that is called every
                `callback_steps` steps during inference with the arguments
                `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                Deprecated (pass via `**kwargs`); use `callback_on_step_end` instead. The frequency at which the
                `callback` function is called. If not specified, the callback is called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set
                the corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                The BrushNet encoder tries to recognize the content of the input image even if you remove all
                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                The fraction of total steps (between 0.0 and 1.0) at which the BrushNet starts applying.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                The fraction of total steps (between 0.0 and 1.0) at which the BrushNet stops applying.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            control_guidance_start, control_guidance_end = (
                [control_guidance_start],
                [control_guidance_end],
            )
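        # Illustrative sketch (comments only, not executed): scalars are broadcast to the
        # length of the list argument, e.g. control_guidance_start=0.0 with
        # control_guidance_end=[0.5, 1.0] becomes start=[0.0, 0.0], end=[0.5, 1.0];
        # two scalars 0.0 and 1.0 become [0.0] and [1.0].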

        # 1. Check inputs (currently disabled; uncomment to validate the arguments and raise on errors)
        # self.check_inputs(
        #     prompt,
        #     images,
        #     masks,
        #     callback_steps,
        #     negative_prompt,
        #     prompt_embeds,
        #     negative_prompt_embeds,
        #     ip_adapter_image,
        #     ip_adapter_image_embeds,
        #     brushnet_conditioning_scale,
        #     control_guidance_start,
        #     control_guidance_end,
        #     callback_on_step_end_tensor_inputs,
        # )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        global_pool_conditions = (
            brushnet.config.global_pool_conditions
            if isinstance(brushnet, BrushNetModel)
            else brushnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions
        video_length = len(images)
        if prompt is not None:
            # 3. Encode input prompt
            text_encoder_lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )
            if self.text_encoder.device != device:
                self.text_encoder.to(device)
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )
            self.text_encoder.to("cpu")
        else:
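            prompt_embeds = prompt_embeds.to(self.unet.device, self.unet.dtype)
            if negative_prompt_embeds is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(self.unet.device, self.unet.dtype)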
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare image
        if isinstance(brushnet, BrushNetModel):
            images = self.prepare_image(
                images=images,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=brushnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            original_masks = self.prepare_image(
                images=masks,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=brushnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            original_masks_new = []
            for original_mask in original_masks:
                original_mask=(original_mask.sum(1)[:,None,:,:] < 0).to(images[0].dtype) 
                original_masks_new.append(original_mask)
            original_masks = original_masks_new
            
            height, width = images[0].shape[-2:]
        else:
            raise ValueError(f"Expected `brushnet` to be a BrushNetModel, got {type(brushnet).__name__}.")

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)
        if self.vae.device != device:
            self.vae.to(device)
        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents, noise = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        
        # 6.1 prepare condition latents
        images = torch.cat(images)
        images = images.to(dtype=images[0].dtype)
        conditioning_latents = []
        vae_encode_chunk = 4  # number of frames passed to the VAE encoder at once
        for i in range(0, images.shape[0], vae_encode_chunk):
            conditioning_latents.append(self.vae.encode(images[i : i + vae_encode_chunk]).latent_dist.sample())
        conditioning_latents = torch.cat(conditioning_latents, dim=0)
        self.vae.to("cpu")
        conditioning_latents = conditioning_latents * self.vae.config.scaling_factor  # [f, c, h, w], c = 4

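        original_masks = torch.cat(original_masks)
        # Resize the binary masks to the latent resolution: [f, 1, h, w].
        masks = torch.nn.functional.interpolate(
            original_masks,
            size=(latents.shape[-2], latents.shape[-1]),
        )

        # BrushNet condition: the VAE latent channels plus one mask channel per frame.
        conditioning_latents = torch.cat([conditioning_latents, masks], dim=1)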

        # 6.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        # 7.2 Create tensor stating which brushnets to keep
        brushnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps)


        overlap = num_frames//4
        context_list, context_list_swap = get_frames_context_swap(video_length, overlap=overlap, num_frames_per_clip=num_frames)
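        # Snapshot the scheduler state; each window restores a snapshot before stepping.
        # Note that `[x] * n` repeats a reference to one deep-copied dict rather than making n copies.
        scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list)
        scheduler_status_swap = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list_swap)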
        count = torch.zeros_like(latents)
        value = torch.zeros_like(latents)
        
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_brushnet_compiled = is_compiled_module(self.brushnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                
                count.zero_()
                value.zero_()
                # Alternate window layouts: odd steps use `context_list_swap`, even steps use `context_list`.
                if i % 2 == 1:
                    context_list_choose = context_list_swap
                    scheduler_status_choose = scheduler_status_swap
                else:
                    context_list_choose = context_list
                    scheduler_status_choose = scheduler_status


                for j, context in enumerate(context_list_choose):
                    self.scheduler.__dict__.update(scheduler_status_choose[j])

                    latents_j = latents[context, :, :, :]

                    # Relevant thread:
                    # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                    if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1:
                        torch._inductor.cudagraph_mark_step_begin()
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents_j] * 2) if self.do_classifier_free_guidance else latents_j
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # brushnet(s) inference
                    if guess_mode and self.do_classifier_free_guidance:
                        # Infer BrushNet only for the conditional batch.
                        control_model_input = latents_j
                        control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                        brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                        brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
                    else:
                        control_model_input = latent_model_input
                        brushnet_prompt_embeds = prompt_embeds
                        if self.do_classifier_free_guidance:
                            neg_brushnet_prompt_embeds, brushnet_prompt_embeds = brushnet_prompt_embeds.chunk(2)
                            brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
                            neg_brushnet_prompt_embeds = rearrange(repeat(neg_brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
                            brushnet_prompt_embeds = torch.cat([neg_brushnet_prompt_embeds, brushnet_prompt_embeds])
                        else:
                            brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')

                    if isinstance(brushnet_keep[i], list):
                        cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])]
                    else:
                        brushnet_cond_scale = brushnet_conditioning_scale
                        if isinstance(brushnet_cond_scale, list):
                            brushnet_cond_scale = brushnet_cond_scale[0]
                        cond_scale = brushnet_cond_scale * brushnet_keep[i]


                    down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet(
                        control_model_input,
                        t,
                        encoder_hidden_states=brushnet_prompt_embeds,
                        brushnet_cond=torch.cat([conditioning_latents[context, :, :, :]]*2) if self.do_classifier_free_guidance else conditioning_latents[context, :, :, :],
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        return_dict=False,
                    )

                    if guess_mode and self.do_classifier_free_guidance:
                        # Inferred BrushNet only for the conditional batch.
                        # To apply the output of BrushNet to both the unconditional and conditional batches,
                        # add 0 to the unconditional batch to keep it unchanged.
                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
                        up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples]

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        down_block_add_samples=down_block_res_samples,
                        mid_block_add_sample=mid_block_res_sample,
                        up_block_add_samples=up_block_res_samples,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                        num_frames=num_frames,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_j = self.scheduler.step(noise_pred, t, latents_j, **extra_step_kwargs, return_dict=False)[0]

                    count[context, ...] += 1

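                    if j == 0:
                        value[context, ...] += latents_j
                    else:
                        # Local (within-window) indices of frames already written by an earlier window.
                        overlap_index_list = [idx for idx, c in enumerate(count[context, 0, 0, 0]) if c > 1]
                        overlap_cur = len(overlap_index_list)
                        # Linear cross-fade: the k-th overlapping frame takes weight (k + 1) / (overlap_cur + 1)
                        # from the current window and the remainder from the running result.
                        ratio_next = torch.linspace(0, 1, overlap_cur + 2)[1:-1]
                        ratio_pre = 1 - ratio_next
                        for i_overlap in overlap_index_list:
                            value[context[i_overlap], ...] = value[context[i_overlap], ...] * ratio_pre[i_overlap] + latents_j[i_overlap, ...] * ratio_next[i_overlap]
                        # Copy the non-overlapping remainder of this window, starting after the last blended frame.
                        tail_start = overlap_index_list[-1] + 1 if overlap_index_list else 0
                        value[context[tail_start:num_frames], ...] = latents_j[tail_start:num_frames, ...]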

                latents = value.clone()

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)


        # If we do sequential model offloading, let's offload unet and brushnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.brushnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            # Return the raw latents without VAE decoding.
            return DiffuEraserPipelineOutput(frames=latents, latents=latents)

        video_tensor = self.decode_latents(latents, weight_dtype=prompt_embeds.dtype)

        if output_type == "pt":
            video = video_tensor
        else:
            video = []
            for i in range(video_tensor.shape[0]):
                video.append(self.image_processor.postprocess(video_tensor[i:i+1], output_type=output_type)[0])

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            # No safety checker runs in this pipeline, so there is no NSFW flag to report.
            return (video, None)

        return DiffuEraserPipelineOutput(frames=video, latents=latents)
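
# ------------------------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): the sliding-window merge that
# `__call__` performs with `value` and `count`, written as standalone helpers so the
# cross-fade can be checked in isolation. `sliding_windows` is a hypothetical
# stand-in for `get_frames_context_swap`, whose exact window layout is not shown
# here; it assumes every later window overlaps the previous one at its start, which
# is also what the pipeline's `ratio_next[i_overlap]` indexing relies on. The
# pipeline additionally alternates between two window layouts on odd/even steps.
# This sketch copies the non-overlapping tail starting after the last overlapping
# frame, so every overlapping frame keeps its blended value.
# ------------------------------------------------------------------------------
from typing import List

import torch


def sliding_windows(video_length: int, num_frames: int, overlap: int) -> List[List[int]]:
    """Hypothetical helper: full-length windows of `num_frames` frames overlapping by at
    least `overlap` frames; the last window is shifted back to end at the final frame."""
    if not 0 <= overlap < num_frames:
        raise ValueError("overlap must satisfy 0 <= overlap < num_frames")
    if video_length <= num_frames:
        return [list(range(video_length))]
    stride = num_frames - overlap
    windows, start = [], 0
    while start + num_frames < video_length:
        windows.append(list(range(start, start + num_frames)))
        start += stride
    windows.append(list(range(video_length - num_frames, video_length)))
    return windows


def blend_windows(window_latents: List[torch.Tensor], windows: List[List[int]], video_length: int) -> torch.Tensor:
    """Merge per-window latents of shape [len(window), ...] into [video_length, ...].

    The k-th of n frames already written by an earlier window becomes
    (1 - r_k) * previous + r_k * current with r_k = (k + 1) / (n + 1)."""
    value = torch.zeros((video_length,) + tuple(window_latents[0].shape[1:]), dtype=window_latents[0].dtype)
    count = torch.zeros(video_length, dtype=torch.long)
    for j, (context, latents_j) in enumerate(zip(windows, window_latents)):
        count[context] += 1
        if j == 0:
            value[context] = latents_j
            continue
        overlap_idx = [k for k, c in enumerate(count[context].tolist()) if c > 1]
        ratio_next = torch.linspace(0, 1, len(overlap_idx) + 2)[1:-1]
        for k in overlap_idx:
            value[context[k]] = value[context[k]] * (1 - ratio_next[k]) + latents_j[k] * ratio_next[k]
        tail = overlap_idx[-1] + 1 if overlap_idx else 0
        value[context[tail:]] = latents_j[tail:]  # frames seen only by this window
    return value

# Example: two 4-frame windows over 6 frames overlapping on frames 2 and 3; merging an
# all-zero window with an all-one window ramps the seam as 0, 0, 1/3, 2/3, 1, 1.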

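# ------------------------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): how step 7.2 of `__call__` gates
# BrushNet over the denoising schedule. Step i of N keeps BrushNet (factor 1.0) only
# if the interval [i/N, (i+1)/N] lies inside [control_guidance_start,
# control_guidance_end]; otherwise the factor is 0.0, and since
# `cond_scale = brushnet_conditioning_scale * keep`, BrushNet's residuals are scaled
# to zero for that step. `brushnet_keep_schedule` is a hypothetical helper name used
# only for this example.
# ------------------------------------------------------------------------------
from typing import List


def brushnet_keep_schedule(num_steps: int, start: float = 0.0, end: float = 1.0) -> List[float]:
    # Same expression as the pipeline, for a single BrushNet.
    return [1.0 - float(i / num_steps < start or (i + 1) / num_steps > end) for i in range(num_steps)]

# Example: with 10 steps and control_guidance_end=0.5, BrushNet contributes only to the
# first 5 steps; a conditioning scale of 0.8 becomes [0.8] * 5 + [0.0] * 5.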

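# ------------------------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): how step 6.1 of `__call__`
# assembles BrushNet's conditioning input. Frames are encoded in chunks of 4
# (presumably to bound VAE memory), scaled by the VAE scaling factor, and
# concatenated with the binary mask resized to the latent grid, giving 4 + 1 = 5
# channels per frame. Assumptions: `encode_fn` stands in for
# `lambda x: vae.encode(x).latent_dist.sample()`; `toy_encode` (8x average pooling
# repeated to 4 channels) is a hypothetical placeholder for a real VAE; 0.18215 is
# the usual SD 1.x scaling factor; inputs are assumed normalized to [-1, 1], as the
# mask rule below expects.
# ------------------------------------------------------------------------------
from typing import Callable

import torch
import torch.nn.functional as F


def build_conditioning_latents(
    encode_fn: Callable[[torch.Tensor], torch.Tensor],
    images: torch.Tensor,       # [f, 3, H, W], normalized to [-1, 1]
    mask_images: torch.Tensor,  # [f, 3, H, W], normalized to [-1, 1]
    scaling_factor: float,
    chunk_size: int = 4,
) -> torch.Tensor:
    chunks = [encode_fn(images[i : i + chunk_size]) for i in range(0, images.shape[0], chunk_size)]
    latents = torch.cat(chunks, dim=0) * scaling_factor
    # Same rule as the pipeline: a pixel whose channels sum below zero (a dark pixel) maps to 1.
    masks = (mask_images.sum(1, keepdim=True) < 0).to(images.dtype)
    masks = F.interpolate(masks, size=tuple(latents.shape[-2:]))  # default mode is "nearest"
    return torch.cat([latents, masks], dim=1)


def toy_encode(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a VAE encoder: [n, 3, H, W] -> [n, 4, H/8, W/8].
    return F.avg_pool2d(x, kernel_size=8).mean(dim=1, keepdim=True).repeat(1, 4, 1, 1)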
================================================
FILE: libs/transformer_temporal.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import AlphaBlender


@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    The output of [`TransformerTemporalModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor


class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
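        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            Unused; the output projection always maps back to `in_channels`.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups used by the input `GroupNorm`.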
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
            activation functions.
        norm_elementwise_affine (`bool`, *optional*):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
        positional_embeddings (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input before passing it to the transformer blocks.
        num_positional_embeddings (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=num_positional_embeddings,
                )
                for d in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        timestep: Optional[torch.LongTensor] = None,
        num_frames: int = 1,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> TransformerTemporalModelOutput:
        """
        The [`TransformerTemporalModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size * num_frames, channel, height, width)`):
                Input hidden_states, with the frames of each clip stacked along the batch dimension.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            num_frames (`int`, *optional*, defaults to 1):
                The number of frames to be processed per batch. This is used to reshape the hidden states.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Returns:
            `torch.FloatTensor` of shape `(batch_size * num_frames, channel, height, width)`:
                The output of the temporal transformer blocks with the input `hidden_states` added back as a residual.
        """
        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        residual = hidden_states

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states[None, None, :]
            .reshape(batch_size, height, width, num_frames, channel)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        output = hidden_states + residual

        return output
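

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the module): the layout change performed by
# `TransformerTemporalModel.forward`. Frames are stacked along the batch axis as
# [batch * num_frames, C, H, W]; attention over time needs one token sequence per
# spatial location, i.e. [batch * H * W, num_frames, C]. The forward applies
# GroupNorm in the intermediate [batch, C, num_frames, H, W] layout; these helpers
# skip the norm and only check that the permutation round-trips exactly. The helper
# names are hypothetical.
# ------------------------------------------------------------------------------
import torch


def to_temporal_tokens(x: torch.Tensor, num_frames: int) -> torch.Tensor:
    batch_frames, c, h, w = x.shape
    b = batch_frames // num_frames
    # [b, f, c, h, w] -> [b, h, w, f, c] -> [b*h*w, f, c]
    return x.reshape(b, num_frames, c, h, w).permute(0, 3, 4, 1, 2).reshape(b * h * w, num_frames, c)


def from_temporal_tokens(tokens: torch.Tensor, h: int, w: int) -> torch.Tensor:
    bhw, num_frames, c = tokens.shape
    b = bhw // (h * w)
    # [b*h*w, f, c] -> [b, h, w, f, c] -> [b, f, c, h, w] -> [b*f, c, h, w]
    return tokens.reshape(b, h, w, num_frames, c).permute(0, 3, 4, 1, 2).reshape(b * num_frames, c, h, w)

# Example: for x of shape [2 * 3, 4, 5, 6], token sequence 0 holds pixel (0, 0) of clip 0
# across its 3 frames, and from_temporal_tokens(to_temporal_tokens(x, 3), 5, 6) equals x.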


class TransformerSpatioTemporalModel(nn.Module):
    """
    A Transformer model for video-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            Stored as `self.out_channels` (defaults to `in_channels`), but the output projection currently always maps
            back to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: int = 320,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim

        # 2. Define input layers
        self.in_channels = in_channels
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
                for d in range(num_layers)
            ]
        )

        time_mix_inner_dim = inner_dim
        self.temporal_transformer_blocks = nn.ModuleList(
            [
                TemporalBasicTransformerBlock(
                    inner_dim,
                    time_mix_inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )

        time_embed_dim = in_channels * 4
        self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
        self.time_proj = Timesteps(in_channels, True, 0)
        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        # TODO: should use out_channels for continuous projections
        self.proj_out = nn.Linear(inner_dim, in_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size * num_frames, channel, height, width)`):
                Input hidden_states.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size * num_frames, 1, cross_attention_dim)`):
                Conditional embeddings for cross attention layer. Required in practice: the temporal blocks attend to
                the embedding of each clip's first frame, broadcast over all spatial positions.
            image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`):
                A tensor indicating whether the input contains only images. 1 indicates that the input contains only
                images, 0 indicates that the input contains video frames. Required in practice: `num_frames` is read
                from its last dimension.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
                returned, otherwise a `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        batch_frames, _, height, width = hidden_states.shape
        num_frames = image_only_indicator.shape[-1]
        batch_size = batch_frames // num_frames

        time_context = encoder_hidden_states
        time_context_first_timestep = time_context[None, :].reshape(
            batch_size, num_frames, -1, time_context.shape[-1]
        )[:, 0]
        time_context = time_context_first_timestep[None, :].broadcast_to(
            height * width, batch_size, 1, time_context.shape[-1]
        )
        time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])

        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
        hidden_states = self.proj_in(hidden_states)

        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
        num_frames_emb = num_frames_emb.reshape(-1)
        t_emb = self.time_proj(num_frames_emb)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)

        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        # 2. Blocks
        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    None,
                    encoder_hidden_states,
                    None,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                )

            hidden_states_mix = hidden_states
            hidden_states_mix = hidden_states_mix + emb

            hidden_states_mix = temporal_block(
                hidden_states_mix,
                num_frames=num_frames,
                encoder_hidden_states=time_context,
            )
            hidden_states = self.time_mixer(
                x_spatial=hidden_states,
                x_temporal=hidden_states_mix,
                image_only_indicator=image_only_indicator,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
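
The frame-index embedding built near the top of `forward` can be illustrated without torch. The sketch below is a minimal pure-Python approximation. It assumes that `self.time_proj` is diffusers' `Timesteps(in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)`, because the constructor is not shown in this excerpt. The helper names `frame_indices` and `sinusoidal_embedding` are hypothetical. The sketch reproduces the `arange(num_frames).repeat(batch_size, 1).reshape(-1)` index layout and the weight-free sinusoidal projection, which is why that step only needs a dtype cast afterwards.

```python
import math
from typing import List


def frame_indices(batch_size: int, num_frames: int) -> List[int]:
    # Mirrors torch.arange(num_frames).repeat(batch_size, 1).reshape(-1):
    # the sequence 0..num_frames-1 repeated once per batch item.
    return [f for _ in range(batch_size) for f in range(num_frames)]


def sinusoidal_embedding(
    t: float,
    dim: int,
    flip_sin_to_cos: bool = True,
    downscale_freq_shift: float = 0.0,
    max_period: float = 10000.0,
) -> List[float]:
    # Weight-free sinusoidal projection in the style of diffusers'
    # get_timestep_embedding (assumed here); dim is assumed to be even.
    half = dim // 2
    freqs = [
        math.exp(-math.log(max_period) * i / (half - downscale_freq_shift))
        for i in range(half)
    ]
    sin_part = [math.sin(t * f) for f in freqs]
    cos_part = [math.cos(t * f) for f in freqs]
    # flip_sin_to_cos puts the cosine half first.
    return cos_part + sin_part if flip_sin_to_cos else sin_part + cos_part


idx = frame_indices(batch_size=2, num_frames=3)
print(idx)  # [0, 1, 2, 0, 1, 2]
emb = [sinusoidal_embedding(t, dim=8) for t in idx]
print(len(emb), len(emb[0]))  # 6 8
```

Every batch item receives the same indices, so frame `f` gets an identical embedding across the batch. In `forward`, `self.time_pos_embed` then maps this embedding to `inner_dim`. That module is assumed to be diffusers' `TimestepEmbedding`, a small MLP with weights. The `[:, None, :]` reshape lets the result broadcast over all `height * width` tokens when it is added to `hidden_states_mix`.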

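Each iteration of the block loop ends by passing both branches to `self.time_mixer`. Its constructor is not part of this excerpt. Assume it is diffusers' `AlphaBlender` with `merge_strategy="learned_with_images"` and a learned scalar `mix_factor`. Under that assumption, the blend reduces to `alpha * x_spatial + (1 - alpha) * x_temporal`. Here `alpha` is `sigmoid(mix_factor)` for video frames and `1` for frames that `image_only_indicator` flags as still images. The sketch below uses one scalar per frame. The helper names are hypothetical.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def blend_frames(
    x_spatial: Sequence[float],
    x_temporal: Sequence[float],
    image_only: Sequence[bool],
    mix_factor: float,
) -> List[float]:
    # One scalar feature per frame for readability; the real module
    # broadcasts alpha over tokens and channels.
    out = []
    for s, t, img in zip(x_spatial, x_temporal, image_only):
        # Frames marked as images keep the spatial branch only (alpha = 1);
        # video frames use the learned weight sigmoid(mix_factor).
        alpha = 1.0 if img else sigmoid(mix_factor)
        out.append(alpha * s + (1.0 - alpha) * t)
    return out


print(blend_frames([1.0, 1.0], [0.0, 0.0], [True, False], mix_factor=0.0))
# [1.0, 0.5]
```

Forcing `alpha = 1` for image-only frames lets the same weights process single images with no temporal mixing. Video frames interpolate between the two branches with a weight learned during training.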

================================================
FILE: libs/unet_2d_blocks.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import apply_freeu
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.resnet import (
    Downsample2D,
    FirDownsample2D,
    FirUpsample2D,
    KDownsample2D,
    KUpsample2D,
    ResnetBlock2D,
    ResnetBlockCondNorm2D,
    Upsample2D,
)
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
):
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif down_block_type == "AttnDownBlock2D":
        if add_downsample is False:
            downsample_type = None
        else:
            downsample_type = downsample_type or "conv"  # default to 'conv'
        return AttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
        )
    elif down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "KDownBlock2D":
        return KDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_c
SYMBOL INDEX (761 symbols across 51 files)

FILE: diffueraser_node.py
  class Propainter_Loader (line 29) | class Propainter_Loader(io.ComfyNode):
    method define_schema (line 31) | def define_schema(cls):
    method execute (line 47) | def execute(cls, propainter,flow,fix_raft,device) -> io.NodeOutput:
  class DiffuEraser_Loader (line 57) | class DiffuEraser_Loader(io.ComfyNode):
    method define_schema (line 59) | def define_schema(cls):
    method execute (line 73) | def execute(cls, vae,lora) -> io.NodeOutput:
  class DiffuEraser_PreData (line 83) | class DiffuEraser_PreData(io.ComfyNode):
    method define_schema (line 85) | def define_schema(cls):
    method execute (line 101) | def execute(cls, images,seg_repo,video_mask_image=None,video_mask=None...
  class Propainter_Sampler (line 137) | class Propainter_Sampler(io.ComfyNode):
    method define_schema (line 139) | def define_schema(cls):
    method execute (line 161) | def execute(cls, model,conditioning,fps,video_length,mask_dilation_ite...
  class DiffuEraser_Sampler (line 179) | class DiffuEraser_Sampler(io.ComfyNode):
    method define_schema (line 181) | def define_schema(cls):
    method execute (line 203) | def execute(cls, model,positive,conditioning,steps,seed,save_result_vi...
  function get_hello (line 236) | async def get_hello(request):
  class DiffuEraser_SM_Extension (line 239) | class DiffuEraser_SM_Extension(ComfyExtension):
    method get_node_list (line 241) | async def get_node_list(self) -> list[type[io.ComfyNode]]:
  function comfy_entrypoint (line 249) | async def comfy_entrypoint() -> DiffuEraser_SM_Extension:  # ComfyUI cal...

FILE: libs/brushnet_CA.py
  class BrushNetOutput (line 38) | class BrushNetOutput(BaseOutput):
  class BrushNetModel (line 62) | class BrushNetModel(ModelMixin, ConfigMixin):
    method __init__ (line 139) | def __init__(
    method from_unet (line 454) | def from_unet(
    method attn_processors (line 543) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 567) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
    method set_default_attn_processor (line 602) | def set_default_attn_processor(self):
    method set_attention_slice (line 618) | def set_attention_slice(self, slice_size: Union[str, int, List[int]]) ...
    method _set_gradient_checkpointing (line 683) | def _set_gradient_checkpointing(self, module, value: bool = False) -> ...
    method forward (line 687) | def forward(
  function zero_module (line 936) | def zero_module(module):

FILE: libs/diffueraser.py
  function extract_step_number (line 30) | def extract_step_number(ckpt_name):
  function import_model_class_from_model_name_or_path (line 53) | def import_model_class_from_model_name_or_path(pretrained_model_name_or_...
  function resize_frames (line 76) | def resize_frames(frames, size=None):
  function read_mask (line 89) | def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilat...
  function read_priori (line 132) | def read_priori(priori, fps, n_total_frames, img_size):
  function read_video (line 163) | def read_video(validation_image, video_length, nframes, max_img_size):
  class DiffuEraser (line 195) | class DiffuEraser:
    method __init__ (line 196) | def __init__(self, device, ):
    method load_model (line 199) | def load_model(self,repo, diffueraser_path, ckpt_path,original_config_...
    method to (line 281) | def to(self, device):
    method forward (line 285) | def forward(self, validation_image, validation_mask, prioris, output_p...
  function simple_flicker_smoothing (line 512) | def simple_flicker_smoothing(frames, alpha=0.1):
  function morphological_edge_blur (line 536) | def morphological_edge_blur(mask, sigma=3.0, edge_width=5):

FILE: libs/pipeline_diffueraser.py
  function retrieve_timesteps (line 43) | def retrieve_timesteps(
  function get_frames_context_swap (line 86) | def get_frames_context_swap(total_frames=192, overlap=4, num_frames_per_...
  class DiffuEraserPipelineOutput (line 111) | class DiffuEraserPipelineOutput(BaseOutput):
  class StableDiffusionDiffuEraserPipeline (line 115) | class StableDiffusionDiffuEraserPipeline(
    method __init__ (line 163) | def __init__(
    method _encode_prompt (line 210) | def _encode_prompt(
    method encode_prompt (line 243) | def encode_prompt(
    method encode_image (line 425) | def encode_image(self, image, device, num_images_per_prompt, output_hi...
    method decode_latents (line 449) | def decode_latents(self, latents, weight_dtype):
    method prepare_ip_adapter_image_embeds (line 461) | def prepare_ip_adapter_image_embeds(
    method run_safety_checker (line 513) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 528) | def decode_latents(self, latents, weight_dtype):
    method prepare_extra_step_kwargs (line 542) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 559) | def check_inputs(
    method check_image (line 674) | def check_image(self, images, masks, prompt, prompt_embeds):
    method prepare_image (line 731) | def prepare_image(
    method prepare_latents (line 765) | def prepare_latents(self, batch_size, num_channels_latents, num_frames...
    method temp_blend (line 792) | def temp_blend(a, b, overlap):
    method get_guidance_scale_embedding (line 799) | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=tor...
    method guidance_scale (line 828) | def guidance_scale(self):
    method clip_skip (line 832) | def clip_skip(self):
    method do_classifier_free_guidance (line 839) | def do_classifier_free_guidance(self):
    method cross_attention_kwargs (line 843) | def cross_attention_kwargs(self):
    method num_timesteps (line 847) | def num_timesteps(self):
    method __call__ (line 852) | def __call__(

FILE: libs/transformer_temporal.py
  class TransformerTemporalModelOutput (line 29) | class TransformerTemporalModelOutput(BaseOutput):
  class TransformerTemporalModel (line 41) | class TransformerTemporalModel(ModelMixin, ConfigMixin):
    method __init__ (line 71) | def __init__(
    method forward (line 121) | def forward(
  class TransformerSpatioTemporalModel (line 198) | class TransformerSpatioTemporalModel(nn.Module):
    method __init__ (line 213) | def __init__(
    method forward (line 273) | def forward(

FILE: libs/unet_2d_blocks.py
  function get_down_block (line 43) | def get_down_block(
  function get_mid_block (line 252) | def get_mid_block(
  function get_up_block (line 351) | def get_up_block(
  class AutoencoderTinyBlock (line 576) | class AutoencoderTinyBlock(nn.Module):
    method __init__ (line 592) | def __init__(self, in_channels: int, out_channels: int, act_fn: str):
    method forward (line 609) | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
  class UNetMidBlock2D (line 613) | class UNetMidBlock2D(nn.Module):
    method __init__ (line 644) | def __init__(
    method forward (line 758) | def forward(self, hidden_states: torch.FloatTensor, temb: Optional[tor...
  class UNetMidBlock2DCrossAttn (line 768) | class UNetMidBlock2DCrossAttn(nn.Module):
    method __init__ (line 769) | def __init__(
    method forward (line 862) | def forward(
  class UNetMidBlock2DSimpleCrossAttn (line 914) | class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    method __init__ (line 915) | def __init__(
    method forward (line 999) | def forward(
  class MidBlock2D (line 1038) | class MidBlock2D(nn.Module):
    method __init__ (line 1039) | def __init__(
    method forward (line 1094) | def forward(
  class AttnDownBlock2D (line 1126) | class AttnDownBlock2D(nn.Module):
    method __init__ (line 1127) | def __init__(
    method forward (line 1218) | def forward(
  class CrossAttnDownBlock2D (line 1249) | class CrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1250) | def __init__(
    method forward (line 1341) | def forward(
  class DownBlock2D (line 1417) | class DownBlock2D(nn.Module):
    method __init__ (line 1418) | def __init__(
    method forward (line 1469) | def forward(
  class DownEncoderBlock2D (line 1512) | class DownEncoderBlock2D(nn.Module):
    method __init__ (line 1513) | def __init__(
    method forward (line 1576) | def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0...
  class AttnDownEncoderBlock2D (line 1587) | class AttnDownEncoderBlock2D(nn.Module):
    method __init__ (line 1588) | def __init__(
    method forward (line 1674) | def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0...
  class AttnSkipDownBlock2D (line 1687) | class AttnSkipDownBlock2D(nn.Module):
    method __init__ (line 1688) | def __init__(
    method forward (line 1768) | def forward(
  class SkipDownBlock2D (line 1795) | class SkipDownBlock2D(nn.Module):
    method __init__ (line 1796) | def __init__(
    method forward (line 1855) | def forward(
  class ResnetDownsampleBlock2D (line 1880) | class ResnetDownsampleBlock2D(nn.Module):
    method __init__ (line 1881) | def __init__(
    method forward (line 1944) | def forward(
  class SimpleCrossAttnDownBlock2D (line 1980) | class SimpleCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1981) | def __init__(
    method forward (line 2075) | def forward(
  class KDownBlock2D (line 2140) | class KDownBlock2D(nn.Module):
    method __init__ (line 2141) | def __init__(
    method forward (line 2186) | def forward(
  class KCrossAttnDownBlock2D (line 2220) | class KCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 2221) | def __init__(
    method forward (line 2285) | def forward(
  class AttnUpBlock2D (line 2347) | class AttnUpBlock2D(nn.Module):
    method __init__ (line 2348) | def __init__(
    method forward (line 2439) | def forward(
  class CrossAttnUpBlock2D (line 2467) | class CrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2468) | def __init__(
    method forward (line 2558) | def forward(
  class UpBlock2D (line 2654) | class UpBlock2D(nn.Module):
    method __init__ (line 2655) | def __init__(
    method forward (line 2704) | def forward(
  class UpDecoderBlock2D (line 2782) | class UpDecoderBlock2D(nn.Module):
    method __init__ (line 2783) | def __init__(
    method forward (line 2844) | def forward(
  class AttnUpDecoderBlock2D (line 2857) | class AttnUpDecoderBlock2D(nn.Module):
    method __init__ (line 2858) | def __init__(
    method forward (line 2944) | def forward(
  class AttnSkipUpBlock2D (line 2959) | class AttnSkipUpBlock2D(nn.Module):
    method __init__ (line 2960) | def __init__(
    method forward (line 3053) | def forward(
  class SkipUpBlock2D (line 3089) | class SkipUpBlock2D(nn.Module):
    method __init__ (line 3090) | def __init__(
    method forward (line 3161) | def forward(
  class ResnetUpsampleBlock2D (line 3194) | class ResnetUpsampleBlock2D(nn.Module):
    method __init__ (line 3195) | def __init__(
    method forward (line 3263) | def forward(
  class SimpleCrossAttnUpBlock2D (line 3303) | class SimpleCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 3304) | def __init__(
    method forward (line 3402) | def forward(
  class KUpBlock2D (line 3469) | class KUpBlock2D(nn.Module):
    method __init__ (line 3470) | def __init__(
    method forward (line 3519) | def forward(
  class KCrossAttnUpBlock2D (line 3558) | class KCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 3559) | def __init__(
    method forward (line 3644) | def forward(
  class KAttentionBlock (line 3706) | class KAttentionBlock(nn.Module):
    method __init__ (line 3730) | def __init__(
    method _to_3d (line 3773) | def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight...
    method _to_4d (line 3776) | def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight...
    method forward (line 3779) | def forward(

FILE: libs/unet_2d_condition.py
  class UNet2DConditionOutput (line 57) | class UNet2DConditionOutput(BaseOutput):
  class UNet2DConditionModel (line 69) | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
    method __init__ (line 166) | def __init__(
    method _check_config (line 482) | def _check_config(
    method _set_time_proj (line 534) | def _set_time_proj(
    method _set_encoder_hid_proj (line 562) | def _set_encoder_hid_proj(
    method _set_class_embedding (line 602) | def _set_class_embedding(
    method _set_add_embedding (line 639) | def _set_add_embedding(
    method _set_pos_net_if_use_gligen (line 679) | def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attent...
    method attn_processors (line 693) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 716) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
    method set_default_attn_processor (line 750) | def set_default_attn_processor(self):
    method set_attention_slice (line 765) | def set_attention_slice(self, slice_size: Union[str, int, List[int]] =...
    method _set_gradient_checkpointing (line 830) | def _set_gradient_checkpointing(self, module, value=False):
    method enable_freeu (line 834) | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    method disable_freeu (line 858) | def disable_freeu(self):
    method fuse_qkv_projections (line 866) | def fuse_qkv_projections(self):
    method unfuse_qkv_projections (line 889) | def unfuse_qkv_projections(self):
    method unload_lora (line 902) | def unload_lora(self):
    method get_time_embed (line 913) | def get_time_embed(
    method get_class_embed (line 939) | def get_class_embed(self, sample: torch.Tensor, class_labels: Optional...
    method get_aug_embed (line 955) | def get_aug_embed(
    method process_encoder_hidden_states (line 1007) | def process_encoder_hidden_states(
    method forward (line 1039) | def forward(

FILE: libs/unet_3d_blocks.py
  function get_down_block (line 37) | def get_down_block(
  function get_up_block (line 165) | def get_up_block(
  class UNetMidBlock3DCrossAttn (line 304) | class UNetMidBlock3DCrossAttn(nn.Module):
    method __init__ (line 305) | def __init__(
    method forward (line 406) | def forward(
  class CrossAttnDownBlock3D (line 438) | class CrossAttnDownBlock3D(nn.Module):
    method __init__ (line 439) | def __init__(
    method forward (line 539) | def forward(
  class DownBlock3D (line 580) | class DownBlock3D(nn.Module):
    method __init__ (line 581) | def __init__(
    method forward (line 646) | def forward(
  class CrossAttnUpBlock3D (line 669) | class CrossAttnUpBlock3D(nn.Module):
    method __init__ (line 670) | def __init__(
    method forward (line 764) | def forward(
  class UpBlock3D (line 826) | class UpBlock3D(nn.Module):
    method __init__ (line 827) | def __init__(
    method forward (line 886) | def forward(
  class DownBlockMotion (line 929) | class DownBlockMotion(nn.Module):
    method __init__ (line 930) | def __init__(
    method forward (line 1003) | def forward(
  class CrossAttnDownBlockMotion (line 1066) | class CrossAttnDownBlockMotion(nn.Module):
    method __init__ (line 1067) | def __init__(
    method forward (line 1181) | def forward(
  class CrossAttnUpBlockMotion (line 1275) | class CrossAttnUpBlockMotion(nn.Module):
    method __init__ (line 1276) | def __init__(
    method forward (line 1383) | def forward(
  class UpBlockMotion (line 1485) | class UpBlockMotion(nn.Module):
    method __init__ (line 1486) | def __init__(
    method forward (line 1555) | def forward(
  class UNetMidBlockCrossAttnMotion (line 1639) | class UNetMidBlockCrossAttnMotion(nn.Module):
    method __init__ (line 1640) | def __init__(
    method forward (line 1747) | def forward(
  class MidBlockTemporalDecoder (line 1824) | class MidBlockTemporalDecoder(nn.Module):
    method __init__ (line 1825) | def __init__(
    method forward (line 1868) | def forward(
  class UpBlockTemporalDecoder (line 1887) | class UpBlockTemporalDecoder(nn.Module):
    method __init__ (line 1888) | def __init__(
    method forward (line 1919) | def forward(
  class UNetMidBlockSpatioTemporal (line 1937) | class UNetMidBlockSpatioTemporal(nn.Module):
    method __init__ (line 1938) | def __init__(
    method forward (line 1992) | def forward(
  class DownBlockSpatioTemporal (line 2047) | class DownBlockSpatioTemporal(nn.Module):
    method __init__ (line 2048) | def __init__(
    method forward (line 2088) | def forward(
  class CrossAttnDownBlockSpatioTemporal (line 2137) | class CrossAttnDownBlockSpatioTemporal(nn.Module):
    method __init__ (line 2138) | def __init__(
    method forward (line 2198) | def forward(
  class UpBlockSpatioTemporal (line 2259) | class UpBlockSpatioTemporal(nn.Module):
    method __init__ (line 2260) | def __init__(
    method forward (line 2297) | def forward(
  class CrossAttnUpBlockSpatioTemporal (line 2348) | class CrossAttnUpBlockSpatioTemporal(nn.Module):
    method __init__ (line 2349) | def __init__(
    method forward (line 2406) | def forward(

FILE: libs/unet_motion_model.py
  class MotionModules (line 50) | class MotionModules(nn.Module):
    method __init__ (line 51) | def __init__(
  class MotionAdapter (line 81) | class MotionAdapter(ModelMixin, ConfigMixin):
    method __init__ (line 83) | def __init__(
    method forward (line 165) | def forward(self, sample):
  class UNetMotionModel (line 169) | class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMix...
    method __init__ (line 181) | def __init__(
    method from_unet2d (line 389) | def from_unet2d(
    method freeze_unet2d_params (line 468) | def freeze_unet2d_params(self) -> None:
    method load_motion_modules (line 492) | def load_motion_modules(self, motion_adapter: Optional[MotionAdapter])...
    method save_motion_modules (line 502) | def save_motion_modules(
    method attn_processors (line 539) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 563) | def set_attn_processor(
    method enable_forward_chunking (line 600) | def enable_forward_chunking(self, chunk_size: Optional[int] = None, di...
    method disable_forward_chunking (line 630) | def disable_forward_chunking(self) -> None:
    method set_default_attn_processor (line 642) | def set_default_attn_processor(self) -> None:
    method _set_gradient_checkpointing (line 657) | def _set_gradient_checkpointing(self, module, value: bool = False) -> ...
    method enable_freeu (line 662) | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> ...
    method disable_freeu (line 687) | def disable_freeu(self) -> None:
    method forward (line 695) | def forward(

FILE: node_utils.py
  function image2masks (line 20) | def image2masks(repo,video_image):
  function resize_and_center_paste (line 53) | def resize_and_center_paste(image_list, target_size=(1024, 1024)):
  function center_paste_and_resize (line 93) | def center_paste_and_resize(image_list, target_size=(1024, 1024)):
  function tensor_to_pil (line 120) | def tensor_to_pil(tensor):
  function tensor2pil_list (line 125) | def tensor2pil_list(image,width,height):
  function tensor2pil_upscale (line 134) | def tensor2pil_upscale(img_tensor, width, height):
  function nomarl_upscale (line 141) | def nomarl_upscale(img, width, height):
  function tensor2cv (line 149) | def tensor2cv(tensor_image):
  function cvargb2tensor (line 162) | def cvargb2tensor(img):
  function cv2tensor (line 167) | def cv2tensor(img):
  function images_generator (line 173) | def images_generator(img_list: list,):
  function load_images (line 224) | def load_images(img_list: list,):
  function tensor2pil (line 232) | def tensor2pil(tensor):
  function pil2narry (line 237) | def pil2narry(img):
  function equalize_lists (line 241) | def equalize_lists(list1, list2):
  function file_exists (line 268) | def file_exists(directory, filename):
  function download_weights (line 274) | def download_weights(file_dir,repo_id,subfolder="",pt_name=""):

FILE: propainter/RAFT/corr.py
  class CorrBlock (line 12) | class CorrBlock:
    method __init__ (line 13) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 29) | def __call__(self, coords):
    method corr (line 53) | def corr(fmap1, fmap2):
  class CorrLayer (line 63) | class CorrLayer(torch.autograd.Function):
    method forward (line 65) | def forward(ctx, fmap1, fmap2, coords, r):
    method backward (line 75) | def backward(ctx, grad_corr):
  class AlternateCorrBlock (line 83) | class AlternateCorrBlock:
    method __init__ (line 84) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 94) | def __call__(self, coords):

FILE: propainter/RAFT/datasets.py
  class FlowDataset (line 18) | class FlowDataset(data.Dataset):
    method __init__ (line 19) | def __init__(self, aug_params=None, sparse=False):
    method __getitem__ (line 34) | def __getitem__(self, index):
    method __rmul__ (line 93) | def __rmul__(self, v):
    method __len__ (line 98) | def __len__(self):
  class MpiSintel (line 102) | class MpiSintel(FlowDataset):
    method __init__ (line 103) | def __init__(self, aug_params=None, split='training', root='datasets/S...
  class FlyingChairs (line 121) | class FlyingChairs(FlowDataset):
    method __init__ (line 122) | def __init__(self, aug_params=None, split='train', root='datasets/Flyi...
  class FlyingThings3D (line 137) | class FlyingThings3D(FlowDataset):
    method __init__ (line 138) | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', ds...
  class KITTI (line 161) | class KITTI(FlowDataset):
    method __init__ (line 162) | def __init__(self, aug_params=None, split='training', root='datasets/K...
  class HD1K (line 180) | class HD1K(FlowDataset):
    method __init__ (line 181) | def __init__(self, aug_params=None, root='datasets/HD1k'):
  function fetch_dataloader (line 199) | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):

FILE: propainter/RAFT/demo.py
  function load_image (line 18) | def load_image(imfile):
  function load_image_list (line 24) | def load_image_list(image_files):
  function viz (line 36) | def viz(img, flo):
  function demo (line 50) | def demo(args):
  function RAFT_infer (line 71) | def RAFT_infer(args):

FILE: propainter/RAFT/extractor.py
  class ResidualBlock (line 6) | class ResidualBlock(nn.Module):
    method __init__ (line 7) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 48) | def forward(self, x):
  class BottleneckBlock (line 60) | class BottleneckBlock(nn.Module):
    method __init__ (line 61) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 107) | def forward(self, x):
  class BasicEncoder (line 118) | class BasicEncoder(nn.Module):
    method __init__ (line 119) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 159) | def _make_layer(self, dim, stride=1):
    method forward (line 168) | def forward(self, x):
  class SmallEncoder (line 195) | class SmallEncoder(nn.Module):
    method __init__ (line 196) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 235) | def _make_layer(self, dim, stride=1):
    method forward (line 244) | def forward(self, x):

FILE: propainter/RAFT/raft.py
  class autocast (line 15) | class autocast:
    method __init__ (line 16) | def __init__(self, enabled):
    method __enter__ (line 18) | def __enter__(self):
    method __exit__ (line 20) | def __exit__(self, *args):
  class RAFT (line 24) | class RAFT(nn.Module):
    method __init__ (line 25) | def __init__(self, args):
    method freeze_bn (line 59) | def freeze_bn(self):
    method initialize_flow (line 64) | def initialize_flow(self, img):
    method upsample_flow (line 73) | def upsample_flow(self, flow, mask):
    method forward (line 87) | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=...

FILE: propainter/RAFT/update.py
  class FlowHead (line 6) | class FlowHead(nn.Module):
    method __init__ (line 7) | def __init__(self, input_dim=128, hidden_dim=256):
    method forward (line 13) | def forward(self, x):
  class ConvGRU (line 16) | class ConvGRU(nn.Module):
    method __init__ (line 17) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 23) | def forward(self, h, x):
  class SepConvGRU (line 33) | class SepConvGRU(nn.Module):
    method __init__ (line 34) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 45) | def forward(self, h, x):
  class SmallMotionEncoder (line 62) | class SmallMotionEncoder(nn.Module):
    method __init__ (line 63) | def __init__(self, args):
    method forward (line 71) | def forward(self, flow, corr):
  class BasicMotionEncoder (line 79) | class BasicMotionEncoder(nn.Module):
    method __init__ (line 80) | def __init__(self, args):
    method forward (line 89) | def forward(self, flow, corr):
  class SmallUpdateBlock (line 99) | class SmallUpdateBlock(nn.Module):
    method __init__ (line 100) | def __init__(self, args, hidden_dim=96):
    method forward (line 106) | def forward(self, net, inp, corr, flow):
  class BasicUpdateBlock (line 114) | class BasicUpdateBlock(nn.Module):
    method __init__ (line 115) | def __init__(self, args, hidden_dim=128, input_dim=128):
    method forward (line 127) | def forward(self, net, inp, corr, flow, upsample=True):

FILE: propainter/RAFT/utils/augmentor.py
  class FlowAugmentor (line 15) | class FlowAugmentor:
    method __init__ (line 16) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=T...
    method color_transform (line 36) | def color_transform(self, img1, img2):
    method eraser_transform (line 52) | def eraser_transform(self, img1, img2, bounds=[50, 100]):
    method spatial_transform (line 67) | def spatial_transform(self, img1, img2, flow):
    method __call__ (line 111) | def __call__(self, img1, img2, flow):
  class SparseFlowAugmentor (line 122) | class SparseFlowAugmentor:
    method __init__ (line 123) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=F...
    method color_transform (line 142) | def color_transform(self, img1, img2):
    method eraser_transform (line 148) | def eraser_transform(self, img1, img2):
    method resize_sparse_flow_map (line 161) | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
    method spatial_transform (line 195) | def spatial_transform(self, img1, img2, flow, valid):
    method __call__ (line 236) | def __call__(self, img1, img2, flow, valid):

FILE: propainter/RAFT/utils/flow_viz.py
  function make_colorwheel (line 20) | def make_colorwheel():
  function flow_uv_to_colors (line 70) | def flow_uv_to_colors(u, v, convert_to_bgr=False):
  function flow_to_image (line 109) | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):

FILE: propainter/RAFT/utils/flow_viz_pt.py
  function flow_to_image (line 6) | def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
  function _normalized_flow_to_image (line 39) | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Te...
  function _make_colorwheel (line 74) | def _make_colorwheel() -> torch.Tensor:

FILE: propainter/RAFT/utils/frame_utils.py
  function readFlow (line 12) | def readFlow(fn):
  function readPFM (line 33) | def readPFM(file):
  function writeFlow (line 70) | def writeFlow(filename,uv,v=None):
  function readFlowKITTI (line 102) | def readFlowKITTI(filename):
  function readDispKITTI (line 109) | def readDispKITTI(filename):
  function writeFlowKITTI (line 116) | def writeFlowKITTI(filename, uv):
  function read_gen (line 123) | def read_gen(file_name, pil=False):

FILE: propainter/RAFT/utils/utils.py
  class InputPadder (line 7) | class InputPadder:
    method __init__ (line 9) | def __init__(self, dims, mode='sintel'):
    method pad (line 18) | def pad(self, *inputs):
    method unpad (line 21) | def unpad(self,x):
  function forward_interpolate (line 26) | def forward_interpolate(flow):
  function bilinear_sampler (line 57) | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
  function coords_grid (line 74) | def coords_grid(batch, ht, wd):
  function upflow8 (line 80) | def upflow8(flow, mode='bilinear'):

FILE: propainter/core/dataset.py
  class TrainDataset (line 19) | class TrainDataset(torch.utils.data.Dataset):
    method __init__ (line 20) | def __init__(self, args: dict):
    method __len__ (line 58) | def __len__(self):
    method _sample_index (line 61) | def _sample_index(self, length, sample_length, num_ref_frame=3):
    method __getitem__ (line 70) | def __getitem__(self, index):
  class TestDataset (line 141) | class TestDataset(torch.utils.data.Dataset):
    method __init__ (line 142) | def __init__(self, args):
    method __len__ (line 170) | def __len__(self):
    method __getitem__ (line 173) | def __getitem__(self, index):

FILE: propainter/core/dist.py
  function get_world_size (line 5) | def get_world_size():
  function get_global_rank (line 17) | def get_global_rank():
  function get_local_rank (line 29) | def get_local_rank():
  function get_master_ip (line 41) | def get_master_ip():

FILE: propainter/core/loss.py
  class PerceptualLoss (line 6) | class PerceptualLoss(nn.Module):
    method __init__ (line 29) | def __init__(self,
    method forward (line 59) | def forward(self, x, gt):
    method _gram_mat (line 101) | def _gram_mat(self, x):
  class LPIPSLoss (line 116) | class LPIPSLoss(nn.Module):
    method __init__ (line 117) | def __init__(self,
    method forward (line 133) | def forward(self, pred, target):
  class AdversarialLoss (line 144) | class AdversarialLoss(nn.Module):
    method __init__ (line 149) | def __init__(self,
    method __call__ (line 168) | def __call__(self, outputs, is_real, is_disc=None):

FILE: propainter/core/lr_scheduler.py
  class MultiStepRestartLR (line 9) | class MultiStepRestartLR(_LRScheduler):
    method __init__ (line 20) | def __init__(self,
    method get_lr (line 35) | def get_lr(self):
  function get_position_from_periods (line 50) | def get_position_from_periods(iteration, cumulative_period):
  class CosineAnnealingRestartLR (line 68) | class CosineAnnealingRestartLR(_LRScheduler):
    method __init__ (line 84) | def __init__(self,
    method get_lr (line 100) | def get_lr(self):

FILE: propainter/core/metrics.py
  function calculate_epe (line 13) | def calculate_epe(flow1, flow2):
  function calculate_psnr (line 21) | def calculate_psnr(img1, img2):
  function calc_psnr_and_ssim (line 40) | def calc_psnr_and_ssim(img1, img2):
  function init_i3d_model (line 64) | def init_i3d_model(i3d_model_path):
  function calculate_i3d_activations (line 72) | def calculate_i3d_activations(video1, video2, i3d_model, device):
  function calculate_vfid (line 87) | def calculate_vfid(real_activations, fake_activations):
  function calculate_frechet_distance (line 101) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  function get_i3d_activations (line 155) | def get_i3d_activations(batched_video,
  class MaxPool3dSamePadding (line 197) | class MaxPool3dSamePadding(nn.MaxPool3d):
    method compute_pad (line 198) | def compute_pad(self, dim, s):
    method forward (line 204) | def forward(self, x):
  class Unit3D (line 223) | class Unit3D(nn.Module):
    method __init__ (line 224) | def __init__(self,
    method compute_pad (line 260) | def compute_pad(self, dim, s):
    method forward (line 266) | def forward(self, x):
  class InceptionModule (line 291) | class InceptionModule(nn.Module):
    method __init__ (line 292) | def __init__(self, in_channels, out_channels, name):
    method forward (line 328) | def forward(self, x):
  class InceptionI3d (line 336) | class InceptionI3d(nn.Module):
    method __init__ (line 373) | def __init__(self,
    method replace_logits (line 535) | def replace_logits(self, num_classes):
    method build (line 546) | def build(self):
    method forward (line 550) | def forward(self, x):
    method extract_features (line 562) | def extract_features(self, x, target_endpoint='Logits'):

FILE: propainter/core/prefetch_dataloader.py
  class PrefetchGenerator (line 7) | class PrefetchGenerator(threading.Thread):
    method __init__ (line 18) | def __init__(self, generator, num_prefetch_queue):
    method run (line 25) | def run(self):
    method __next__ (line 30) | def __next__(self):
    method __iter__ (line 36) | def __iter__(self):
  class PrefetchDataLoader (line 40) | class PrefetchDataLoader(DataLoader):
    method __init__ (line 55) | def __init__(self, num_prefetch_queue, **kwargs):
    method __iter__ (line 59) | def __iter__(self):
  class CPUPrefetcher (line 63) | class CPUPrefetcher():
    method __init__ (line 70) | def __init__(self, loader):
    method next (line 74) | def next(self):
    method reset (line 80) | def reset(self):
  class CUDAPrefetcher (line 84) | class CUDAPrefetcher():
    method __init__ (line 97) | def __init__(self, loader, opt):
    method preload (line 105) | def preload(self):
    method next (line 117) | def next(self):
    method reset (line 123) | def reset(self):

FILE: propainter/core/trainer.py
  class Trainer (line 26) | class Trainer:
    method __init__ (line 27) | def __init__(self, config):
    method setup_optimizers (line 129) | def setup_optimizers(self):
    method setup_schedulers (line 156) | def setup_schedulers(self):
    method update_learning_rate (line 187) | def update_learning_rate(self):
    method get_lr (line 193) | def get_lr(self):
    method add_summary (line 197) | def add_summary(self, writer, name, val):
    method load (line 207) | def load(self):
    method save (line 274) | def save(self, it):
    method train (line 321) | def train(self):
    method _train_epoch (line 350) | def _train_epoch(self, pbar):

FILE: propainter/core/trainer_flow_w_edge.py
  class Trainer (line 26) | class Trainer:
    method __init__ (line 27) | def __init__(self, config):
    method setup_optimizers (line 88) | def setup_optimizers(self):
    method setup_schedulers (line 109) | def setup_schedulers(self):
    method update_learning_rate (line 128) | def update_learning_rate(self):
    method get_lr (line 132) | def get_lr(self):
    method add_summary (line 136) | def add_summary(self, writer, name, val):
    method load (line 146) | def load(self):
    method save (line 183) | def save(self, it):
    method train (line 212) | def train(self):
    method get_edges (line 261) | def get_edges(self, flows):
    method _train_epoch (line 275) | def _train_epoch(self, pbar):

FILE: propainter/core/utils.py
  function read_dirnames_under_root (line 24) | def read_dirnames_under_root(root_dir):
  class TrainZipReader (line 33) | class TrainZipReader(object):
    method __init__ (line 36) | def __init__(self):
    method build_file_dict (line 40) | def build_file_dict(path):
    method imread (line 50) | def imread(path, idx):
  class TestZipReader (line 60) | class TestZipReader(object):
    method __init__ (line 63) | def __init__(self):
    method build_file_dict (line 67) | def build_file_dict(path):
    method imread (line 77) | def imread(path, idx):
  function to_tensors (line 94) | def to_tensors():
  class GroupRandomHorizontalFlowFlip (line 98) | class GroupRandomHorizontalFlowFlip(object):
    method __call__ (line 101) | def __call__(self, img_group, flowF_group, flowB_group):
  class GroupRandomHorizontalFlip (line 114) | class GroupRandomHorizontalFlip(object):
    method __call__ (line 117) | def __call__(self, img_group, is_flow=False):
  class Stack (line 130) | class Stack(object):
    method __init__ (line 131) | def __init__(self, roll=False):
    method __call__ (line 134) | def __call__(self, img_group):
  class ToTorchFormatTensor (line 151) | class ToTorchFormatTensor(object):
    method __init__ (line 154) | def __init__(self, div=True):
    method __call__ (line 157) | def __call__(self, pic):
  function create_random_shape_with_random_motion (line 178) | def create_random_shape_with_random_motion(video_length,
  function create_random_shape_with_random_motion_zoom_rotation (line 220) | def create_random_shape_with_random_motion_zoom_rotation(video_length, z...
  function get_random_shape (line 268) | def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
  function random_accelerate (line 309) | def random_accelerate(velocity, maxAcceleration, dist='uniform'):
  function get_random_velocity (line 324) | def get_random_velocity(max_speed=3, dist='uniform'):
  function random_move_control_points (line 336) | def random_move_control_points(X,

FILE: propainter/inference.py
  function resize_frames (line 35) | def resize_frames(frames, size=None):
  function read_frame_from_videos (line 50) | def read_frame_from_videos(frame_root, video_length):
  function binary_mask (line 72) | def binary_mask(mask, th=0.1):
  function read_mask (line 78) | def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates...
  function get_ref_index (line 133) | def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ...
  function file_exists (line 149) | def file_exists(directory, filename):
  class Propainter (line 155) | class Propainter:
    method __init__ (line 156) | def __init__(self, device):
    method load_propainter (line 159) | def load_propainter(self,fix_raft_path,flow_path,ProPainter_path):
    method to (line 193) | def to(self, device):
    method forward (line 199) | def forward(self, video, mask,load_videobypath=False, resize_ratio=1.0...

FILE: propainter/model/canny/canny_filter.py
  function rgb_to_grayscale (line 12) | def rgb_to_grayscale(image, rgb_weights = None):
  function canny (line 38) | def canny(
  class Canny (line 178) | class Canny(nn.Module):
    method __init__ (line 204) | def __init__(
    method __repr__ (line 242) | def __repr__(self) -> str:
    method forward (line 253) | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Te...

FILE: propainter/model/canny/filter.py
  function _compute_padding (line 9) | def _compute_padding(kernel_size: List[int]) -> List[int]:
  function filter2d (line 32) | def filter2d(
  function filter2d_separable (line 135) | def filter2d_separable(
  function filter3d (line 189) | def filter3d(

FILE: propainter/model/canny/gaussian.py
  function gaussian_blur2d (line 10) | def gaussian_blur2d(
  class GaussianBlur2d (line 56) | class GaussianBlur2d(nn.Module):
    method __init__ (line 86) | def __init__(
    method __repr__ (line 99) | def __repr__(self) -> str:
    method forward (line 115) | def forward(self, input: torch.Tensor) -> torch.Tensor:

FILE: propainter/model/canny/kernels.py
  function normalize_kernel2d (line 8) | def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor:
  function gaussian (line 16) | def gaussian(window_size: int, sigma: float) -> torch.Tensor:
  function gaussian_discrete_erf (line 28) | def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor:
  function _modified_bessel_0 (line 43) | def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:
  function _modified_bessel_1 (line 60) | def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:
  function _modified_bessel_i (line 77) | def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:
  function gaussian_discrete (line 106) | def gaussian_discrete(window_size, sigma) -> torch.Tensor:
  function laplacian_1d (line 127) | def laplacian_1d(window_size) -> torch.Tensor:
  function get_box_kernel2d (line 136) | def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor:
  function get_binary_kernel2d (line 145) | def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor:
  function get_sobel_kernel_3x3 (line 157) | def get_sobel_kernel_3x3() -> torch.Tensor:
  function get_sobel_kernel_5x5_2nd_order (line 162) | def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor:
  function _get_sobel_kernel_5x5_2nd_order_xy (line 175) | def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor:
  function get_diff_kernel_3x3 (line 188) | def get_diff_kernel_3x3() -> torch.Tensor:
  function get_diff_kernel3d (line 193) | def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> ...
  function get_diff_kernel3d_2nd_order (line 219) | def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch....
  function get_sobel_kernel2d (line 260) | def get_sobel_kernel2d() -> torch.Tensor:
  function get_diff_kernel2d (line 266) | def get_diff_kernel2d() -> torch.Tensor:
  function get_sobel_kernel2d_2nd_order (line 272) | def get_sobel_kernel2d_2nd_order() -> torch.Tensor:
  function get_diff_kernel2d_2nd_order (line 279) | def get_diff_kernel2d_2nd_order() -> torch.Tensor:
  function get_spatial_gradient_kernel2d (line 286) | def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor:
  function get_spatial_gradient_kernel3d (line 318) | def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.de...
  function get_gaussian_kernel1d (line 346) | def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bo...
  function get_gaussian_discrete_kernel1d (line 374) | def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force...
  function get_gaussian_erf_kernel1d (line 403) | def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even...
  function get_gaussian_kernel2d (line 432) | def get_gaussian_kernel2d(
  function get_laplacian_kernel1d (line 472) | def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor:
  function get_laplacian_kernel2d (line 496) | def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor:
  function get_pascal_kernel_2d (line 530) | def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.T...
  function get_pascal_kernel_1d (line 562) | def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch....
  function get_canny_nms_kernel (line 604) | def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) ...
  function get_hysteresis_kernel (line 623) | def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float)...
  function get_hanning_kernel1d (line 642) | def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), d...
  function get_hanning_kernel2d (line 672) | def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.devi...

FILE: propainter/model/canny/sobel.py
  function spatial_gradient (line 8) | def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: in...
  function spatial_gradient3d (line 58) | def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: i...
  function sobel (line 122) | def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-...
  class SpatialGradient (line 164) | class SpatialGradient(nn.Module):
    method __init__ (line 184) | def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bo...
    method __repr__ (line 190) | def __repr__(self) -> str:
    method forward (line 196) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class SpatialGradient3d (line 200) | class SpatialGradient3d(nn.Module):
    method __init__ (line 221) | def __init__(self, mode: str = 'diff', order: int = 1) -> None:
    method __repr__ (line 228) | def __repr__(self) -> str:
    method forward (line 231) | def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
  class Sobel (line 235) | class Sobel(nn.Module):
    method __init__ (line 254) | def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
    method __repr__ (line 259) | def __repr__(self) -> str:
    method forward (line 262) | def forward(self, input: torch.Tensor) -> torch.Tensor:

FILE: propainter/model/misc.py
  function constant_init (line 12) | def constant_init(module, val, bias=0):
  function get_root_logger (line 19) | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_f...
  function gpu_is_available (line 61) | def gpu_is_available():
  function get_device (line 67) | def get_device(gpu_id=None):
  function set_random_seed (line 81) | def set_random_seed(seed):
  function get_time_str (line 90) | def get_time_str():
  function scandir (line 94) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):

FILE: propainter/model/modules/base_module.py
  class BaseNetwork (line 7) | class BaseNetwork(nn.Module):
    method __init__ (line 8) | def __init__(self):
    method print_network (line 11) | def print_network(self):
    method init_weights (line 22) | def init_weights(self, init_type='normal', gain=0.02):
  class Vec2Feat (line 64) | class Vec2Feat(nn.Module):
    method __init__ (line 65) | def __init__(self, channel, hidden, kernel_size, stride, padding):
    method forward (line 79) | def forward(self, x, t, output_size):
  class FusionFeedForward (line 94) | class FusionFeedForward(nn.Module):
    method __init__ (line 95) | def __init__(self, dim, hidden_dim=1960, t2t_params=None):
    method forward (line 104) | def forward(self, x, output_size):

FILE: propainter/model/modules/deformconv.py
  class ModulatedDeformConv2d (line 7) | class ModulatedDeformConv2d(nn.Module):
    method __init__ (line 8) | def __init__(self,
    method init_weights (line 40) | def init_weights(self):
    method forward (line 53) | def forward(self, x, offset, mask):

FILE: propainter/model/modules/flow_comp_raft.py
  function initialize_RAFT (line 15) | def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
  class RAFT_bi (line 32) | class RAFT_bi(nn.Module):
    method __init__ (line 34) | def __init__(self, model_path='weights/raft-things.pth', device='cuda'):
    method forward (line 44) | def forward(self, gt_local_frames, iters=20):
  function smoothness_loss (line 64) | def smoothness_loss(flow, cmask):
  function smoothness_deltas (line 71) | def smoothness_deltas(flow):
  function second_order_loss (line 92) | def second_order_loss(flow, cmask):
  function charbonnier_loss (line 99) | def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, ...
  function second_order_deltas (line 118) | def second_order_deltas(flow):
  function create_mask (line 147) | def create_mask(tensor, paddings):
  function ternary_loss (line 168) | def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, s...
  class FlowLoss (line 178) | class FlowLoss(nn.Module):
    method __init__ (line 179) | def __init__(self):
    method forward (line 183) | def forward(self, pred_flows, gt_flows, masks, frames):
  function edgeLoss (line 212) | def edgeLoss(preds_edges, edges):
  class EdgeLoss (line 233) | class EdgeLoss(nn.Module):
    method __init__ (line 234) | def __init__(self):
    method forward (line 237) | def forward(self, pred_edges, gt_edges, masks):
  class FlowSimpleLoss (line 252) | class FlowSimpleLoss(nn.Module):
    method __init__ (line 253) | def __init__(self):
    method forward (line 257) | def forward(self, pred_flows, gt_flows):

FILE: propainter/model/modules/flow_loss_utils.py
  function flow_warp (line 6) | def flow_warp(x,
  function length_sq (line 62) | def length_sq(x):
  function fbConsistencyCheck (line 66) | def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
  function rgb2gray (line 83) | def rgb2gray(image):
  function ternary_transform (line 89) | def ternary_transform(image, max_distance=1):
  function hamming_distance (line 102) | def hamming_distance(t1, t2):
  function create_mask (line 109) | def create_mask(mask, paddings):
  function ternary_loss2 (line 124) | def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1):

FILE: propainter/model/modules/sparse_transformer.py
  class SoftSplit (line 7) | class SoftSplit(nn.Module):
    method __init__ (line 8) | def __init__(self, channel, hidden, kernel_size, stride, padding):
    method forward (line 19) | def forward(self, x, b, output_size):
  class SoftComp (line 34) | class SoftComp(nn.Module):
    method __init__ (line 35) | def __init__(self, channel, hidden, kernel_size, stride, padding):
    method forward (line 49) | def forward(self, x, t, output_size):
  class FusionFeedForward (line 64) | class FusionFeedForward(nn.Module):
    method __init__ (line 65) | def __init__(self, dim, hidden_dim=1960, t2t_params=None):
    method forward (line 74) | def forward(self, x, output_size):
  function window_partition (line 104) | def window_partition(x, window_size, n_head):
  class SparseWindowAttention (line 117) | class SparseWindowAttention(nn.Module):
    method __init__ (line 118) | def __init__(self, dim, n_head, window_size, pool_size=(4,4), qkv_bias...
    method forward (line 158) | def forward(self, x, mask=None, T_ind=None, attn_mask=None):
  class TemporalSparseTransformer (line 284) | class TemporalSparseTransformer(nn.Module):
    method __init__ (line 285) | def __init__(self, dim, n_head, window_size, pool_size,
    method forward (line 294) | def forward(self, x, fold_x_size, mask=None, T_ind=None):
  class TemporalSparseTransformerBlock (line 317) | class TemporalSparseTransformerBlock(nn.Module):
    method __init__ (line 318) | def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_pa...
    method forward (line 328) | def forward(self, x, fold_x_size, l_mask=None, t_dilation=2):

FILE: propainter/model/modules/spectral_norm.py
  class SpectralNorm (line 8) | class SpectralNorm(object):
    method __init__ (line 20) | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-...
    method reshape_weight_to_matrix (line 30) | def reshape_weight_to_matrix(self, weight):
    method compute_weight (line 40) | def compute_weight(self, module, do_power_iteration):
    method remove (line 98) | def remove(self, module):
    method __call__ (line 108) | def __call__(self, module, inputs):
    method _solve_v_and_rescale (line 113) | def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
    method apply (line 122) | def apply(module, name, n_power_iterations, dim, eps):
  class SpectralNormLoadStateDictPreHook (line 161) | class SpectralNormLoadStateDictPreHook(object):
    method __init__ (line 163) | def __init__(self, fn):
    method __call__ (line 174) | def __call__(self, state_dict, prefix, local_metadata, strict,
  class SpectralNormStateDictHook (line 192) | class SpectralNormStateDictHook(object):
    method __init__ (line 194) | def __init__(self, fn):
    method __call__ (line 197) | def __call__(self, module, state_dict, prefix, local_metadata):
  function spectral_norm (line 207) | def spectral_norm(module,
  function remove_spectral_norm (line 264) | def remove_spectral_norm(module, name='weight'):
  function use_spectral_norm (line 285) | def use_spectral_norm(module, use_sn=False):

FILE: propainter/model/propainter.py
  function length_sq (line 29) | def length_sq(x):
  function fbConsistencyCheck (line 32) | def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): #debug
  class DeformableAlignment (line 44) | class DeformableAlignment(ModulatedDeformConv2d):
    method __init__ (line 46) | def __init__(self, *args, **kwargs):
    method init_offset (line 63) | def init_offset(self):
    method forward (line 66) | def forward(self, x, cond_feat, flow):
  class BidirectionalPropagation (line 82) | class BidirectionalPropagation(nn.Module):
    method __init__ (line 83) | def __init__(self, channel, learnable=True):
    method binary_mask (line 108) | def binary_mask(self, mask, th=0.1):
    method forward (line 114) | def forward(self, x, flows_forward, flows_backward, mask, interpolatio...
  class Encoder (line 214) | class Encoder(nn.Module):
    method __init__ (line 215) | def __init__(self):
    method forward (line 239) | def forward(self, x):
  class deconv (line 256) | class deconv(nn.Module):
    method __init__ (line 257) | def __init__(self,
    method forward (line 269) | def forward(self, x):
  class InpaintGenerator (line 277) | class InpaintGenerator(BaseNetwork):
    method __init__ (line 278) | def __init__(self, init_weights=True, model_path=None):
    method img_propagation (line 336) | def img_propagation(self, masked_frames, completed_flows, masks, inter...
    method forward (line 340) | def forward(self, masked_frames, completed_flows, masks_in, masks_upda...
  class Discriminator (line 399) | class Discriminator(BaseNetwork):
    method __init__ (line 400) | def __init__(self,
    method forward (line 464) | def forward(self, xs):
  class Discriminator_2D (line 475) | class Discriminator_2D(BaseNetwork):
    method __init__ (line 476) | def __init__(self,
    method forward (line 540) | def forward(self, xs):
  function spectral_norm (line 550) | def spectral_norm(module, mode=True):
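
The names `length_sq` and `fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5)` point to the usual forward-backward occlusion test. The backward flow is warped by the forward flow, and a pixel counts as valid when |f_fw + f_bw_warped|^2 < alpha1 * (|f_fw|^2 + |f_bw_warped|^2) + alpha2. The NumPy sketch below is an illustration under assumptions. It works on a single flow shaped (2, H, W), with channel 0 as the x displacement, instead of a batch. It also replaces the repository's `flow_warp` helper (from `flow_loss_utils.py`) with nearest-neighbour sampling, so results near motion boundaries will differ from the original.

```python
import numpy as np


def length_sq(x):
    # Squared magnitude over the channel axis, kept for broadcasting.
    return np.sum(np.square(x), axis=0, keepdims=True)


def warp_nearest(field, flow):
    """Sample `field` (C, H, W) at (x + u, y + v) using nearest-neighbour lookup."""
    _, h, w = flow.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Round the displaced coordinates and clamp them to the image border.
    x_src = np.clip(np.rint(xs + flow[0]).astype(int), 0, w - 1)
    y_src = np.clip(np.rint(ys + flow[1]).astype(int), 0, h - 1)
    return field[:, y_src, x_src]


def fb_consistency_check(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
    """Return a (1, H, W) float mask that is 1 where forward and backward flows agree."""
    flow_bw_warped = warp_nearest(flow_bw, flow_fw)  # f_bw evaluated at x + f_fw(x)
    flow_diff_fw = flow_fw + flow_bw_warped           # close to 0 for consistent pixels
    mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped)
    occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
    return (length_sq(flow_diff_fw) < occ_thresh_fw).astype(np.float32)
```

The threshold grows with flow magnitude through `alpha1`, and `alpha2` adds a constant tolerance. Large motions therefore get more slack before a pixel is marked as occluded.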

FILE: propainter/model/recurrent_flow_completion.py
  class SecondOrderDeformableAlignment (line 14) | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
    method __init__ (line 16) | def __init__(self, *args, **kwargs):
    method init_offset (line 32) | def init_offset(self):
    method forward (line 35) | def forward(self, x, extra_feat):
  class BidirectionalPropagation (line 51) | class BidirectionalPropagation(nn.Module):
    method __init__ (line 52) | def __init__(self, channel):
    method forward (line 71) | def forward(self, x):
  class deconv (line 132) | class deconv(nn.Module):
    method __init__ (line 133) | def __init__(self,
    method forward (line 145) | def forward(self, x):
  class P3DBlock (line 153) | class P3DBlock(nn.Module):
    method __init__ (line 154) | def __init__(self, in_channels, out_channels, kernel_size, stride, pad...
    method forward (line 167) | def forward(self, feats):
  class EdgeDetection (line 177) | class EdgeDetection(nn.Module):
    method __init__ (line 178) | def __init__(self, in_ch=2, out_ch=1, mid_ch=16):
    method forward (line 198) | def forward(self, flow):
  class RecurrentFlowCompleteNet (line 208) | class RecurrentFlowCompleteNet(nn.Module):
    method __init__ (line 209) | def __init__(self, model_path=None):
    method forward (line 277) | def forward(self, masked_flows, masks):
    method forward_bidirect_flow (line 317) | def forward_bidirect_flow(self, masked_flows_bi, masks):
    method combine_flow (line 345) | def combine_flow(self, masked_flows_bi, pred_flows_bi, masks):
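
`RecurrentFlowCompleteNet.combine_flow(masked_flows_bi, pred_flows_bi, masks)` presumably keeps the observed flow outside the hole and uses the network's prediction inside it. The NumPy sketch below shows that blend under several assumptions. Flows are laid out as (B, T-1, 2, H, W) and masks as (B, T, 1, H, W), with 1 marking the hole. Forward flow t to t+1 uses the mask of frame t, and backward flow t+1 to t uses the mask of frame t+1.

```python
import numpy as np


def combine_flow(masked_flows_bi, pred_flows_bi, masks):
    """Keep known flow outside the mask and predicted flow inside it."""
    masked_fw, masked_bw = masked_flows_bi
    pred_fw, pred_bw = pred_flows_bi
    # Assumed pairing: forward flow t -> t+1 with the mask of frame t,
    # backward flow t+1 -> t with the mask of frame t+1.
    masks_fw = masks[:, :-1]
    masks_bw = masks[:, 1:]
    # The (..., 1, H, W) masks broadcast over the 2 flow channels.
    out_fw = pred_fw * masks_fw + masked_fw * (1 - masks_fw)
    out_bw = pred_bw * masks_bw + masked_bw * (1 - masks_bw)
    return out_fw, out_bw
```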

FILE: propainter/model/vgg_arch.py
  function insert_bn (line 34) | def insert_bn(names):
  class VGGFeatureExtractor (line 51) | class VGGFeatureExtractor(nn.Module):
    method __init__ (line 74) | def __init__(self,
    method forward (line 137) | def forward(self, x):
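
`insert_bn(names)` in `vgg_arch.py` appears to build the layer-name list for batch-normalized VGG variants, which `VGGFeatureExtractor` can then use to pick feature layers by name. The sketch below assumes the BasicSR convention this file seems to follow: a matching `bnX_Y` entry goes after every `convX_Y` entry.

```python
def insert_bn(names):
    """Return a copy of `names` with 'bnX_Y' inserted after each 'convX_Y'."""
    names_bn = []
    for name in names:
        names_bn.append(name)
        if "conv" in name:
            # 'conv3_2' -> position '3_2' -> 'bn3_2'
            position = name.replace("conv", "")
            names_bn.append("bn" + position)
    return names_bn


print(insert_bn(["conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1"]))
# ['conv1_1', 'bn1_1', 'relu1_1', 'conv1_2', 'bn1_2', 'relu1_2', 'pool1']
```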

FILE: propainter/utils/download_util.py
  function sizeof_fmt (line 8) | def sizeof_fmt(size, suffix='B'):
  function download_file_from_google_drive (line 25) | def download_file_from_google_drive(file_id, save_path):
  function get_confirm_token (line 55) | def get_confirm_token(response):
  function save_response_content (line 62) | def save_response_content(response, destination, file_size=None, chunk_s...
  function load_file_from_url (line 83) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
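
`sizeof_fmt(size, suffix='B')` in `download_util.py` formats byte counts for download progress. The minimal sketch below assumes the binary (1024-based) units common in BasicSR-derived code.

```python
def sizeof_fmt(size, suffix="B"):
    """Format a byte count with 1024-based unit prefixes, e.g. 1536 -> '1.5 KB'."""
    for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
        if abs(size) < 1024.0:
            return f"{size:3.1f} {unit}{suffix}"
        size /= 1024.0
    return f"{size:3.1f} Y{suffix}"


print(sizeof_fmt(2.5 * 1024 ** 2))  # 2.5 MB
```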

FILE: propainter/utils/file_client.py
  class BaseStorageBackend (line 4) | class BaseStorageBackend(metaclass=ABCMeta):
    method get (line 13) | def get(self, filepath):
    method get_text (line 17) | def get_text(self, filepath):
  class MemcachedBackend (line 21) | class MemcachedBackend(BaseStorageBackend):
    method __init__ (line 31) | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
    method get (line 46) | def get(self, filepath):
    method get_text (line 53) | def get_text(self, filepath):
  class HardDiskBackend (line 57) | class HardDiskBackend(BaseStorageBackend):
    method get (line 60) | def get(self, filepath):
    method get_text (line 66) | def get_text(self, filepath):
  class LmdbBackend (line 73) | class LmdbBackend(BaseStorageBackend):
    method __init__ (line 93) | def __init__(self, db_paths, client_keys='default', readonly=True, loc...
    method get (line 113) | def get(self, filepath, client_key):
    method get_text (line 127) | def get_text(self, filepath):
  class FileClient (line 131) | class FileClient(object):
    method __init__ (line 150) | def __init__(self, backend='disk', **kwargs):
    method get (line 157) | def get(self, filepath, client_key='default'):
    method get_text (line 165) | def get_text(self, filepath):
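
The `file_client.py` classes follow a pluggable storage-backend pattern. An abstract `BaseStorageBackend` declares `get` and `get_text`, concrete backends (disk, Memcached, LMDB) implement them, and `FileClient(backend='disk', **kwargs)` dispatches to the selected backend. The reduced sketch below includes only a disk backend. The `_backends` registry name and the error message are illustrative assumptions. The original `LmdbBackend.get` also takes a `client_key`, which is why `FileClient.get` accepts one.

```python
from abc import ABCMeta, abstractmethod


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract storage backend: subclasses return raw bytes or decoded text."""

    @abstractmethod
    def get(self, filepath):
        ...

    @abstractmethod
    def get_text(self, filepath):
        ...


class HardDiskBackend(BaseStorageBackend):
    """Read files from the local filesystem."""

    def get(self, filepath):
        with open(str(filepath), "rb") as f:
            return f.read()

    def get_text(self, filepath):
        with open(str(filepath), "r", encoding="utf-8") as f:
            return f.read()


class FileClient:
    """Dispatch reads to a backend chosen by name (only 'disk' in this sketch)."""

    _backends = {"disk": HardDiskBackend}  # assumed registry name

    def __init__(self, backend="disk", **kwargs):
        if backend not in self._backends:
            raise ValueError(f"Backend {backend} is not supported; "
                             f"choose from {list(self._backends)}")
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key="default"):
        # client_key only matters for multi-database backends such as LMDB;
        # the disk backend ignores it.
        return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)
```

With this design, adding a backend only requires a new subclass and one registry entry, and callers do not change.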

FILE: propainter/utils/flow_util.py
  function resize_flow (line 6) | def resize_flow(flow, newh, neww):
  function resize_flow_pytorch (line 13) | def resize_flow_pytorch(flow, newh, neww):
  function imwrite (line 21) | def imwrite(img, file_path, params=None, auto_mkdir=True):
  function flowread (line 28) | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
  function flowwrite (line 67) | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kw...
  function quantize_flow (line 102) | def quantize_flow(flow, max_val=0.02, norm=True):
  function dequantize_flow (line 128) | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
  function quantize (line 152) | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
  function dequantize (line 176) | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
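
`quantize_flow(flow, max_val=0.02, norm=True)` and `dequantize_flow` suggest an mmcv-style scheme for storing optical flow as 8-bit images. Each component is divided by the image width or height and clipped to [-max_val, max_val]. `quantize(arr, min_val, max_val, levels)` then maps it to integer levels, and `dequantize` maps each level back to the centre of its bin. The sketch below assumes that scheme. The choice of 255 levels and the bin-centre reconstruction are assumptions and were not taken from this repository.

```python
import numpy as np


def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Map values in [min_val, max_val] to integer bins 0 .. levels-1."""
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError("levels must be an int greater than 1")
    if min_val >= max_val:
        raise ValueError("min_val must be smaller than max_val")
    arr = np.clip(arr, min_val, max_val) - min_val
    bins = np.floor(levels * arr / (max_val - min_val)).astype(np.int64)
    # max_val itself would land in bin `levels`; fold it into the last bin.
    return np.minimum(bins, levels - 1).astype(dtype)


def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Map integer bins back to the centre of their value interval."""
    return (arr.astype(dtype) + 0.5) * (max_val - min_val) / levels + min_val


def quantize_flow(flow, max_val=0.02, norm=True):
    """Convert an (H, W, 2) flow into two uint8 maps (dx, dy)."""
    h, w, _ = flow.shape
    dx, dy = flow[..., 0], flow[..., 1]
    if norm:
        # Express displacements as a fraction of the image size.
        dx, dy = dx / w, dy / h
    return tuple(quantize(d, -max_val, max_val, 255, np.uint8) for d in (dx, dy))


def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Invert quantize_flow, up to half a quantization step per component."""
    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in (dx, dy)]
    if denorm:
        dx *= dx.shape[1]
        dy *= dy.shape[0]
    return np.dstack((dx, dy))
```

Using an odd number of levels (255) puts 0 exactly at a bin centre, so zero motion survives the round trip.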

FILE: propainter/utils/img_util.py
  function img2tensor (line 9) | def img2tensor(imgs, bgr2rgb=True, float32=True):
  function tensor2img (line 38) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
  function tensor2img_fast (line 97) | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
  function imfrombytes (line 114) | def imfrombytes(content, flag='color', float32=False):
  function imwrite (line 135) | def imwrite(img, file_path, params=None, auto_mkdir=True):
  function crop_border (line 154) | def crop_border(imgs, crop_border):
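
`crop_border(imgs, crop_border)` in `img_util.py` is the kind of helper typically used to exclude border pixels before computing image metrics. The sketch below assumes BasicSR-style semantics: HWC arrays, either a single image or a list, with `crop_border=0` treated as a no-op.

```python
import numpy as np


def crop_border(imgs, crop_border):
    """Drop `crop_border` pixels from each side of one HWC image or a list of them."""
    if crop_border == 0:
        # imgs[c:-c] with c == 0 would be empty, so handle it explicitly.
        return imgs
    if isinstance(imgs, list):
        return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
    return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
```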

FILE: run_diffueraser.py
  function load_diffueraser (line 11) | def load_diffueraser(sd_repo,pre_model_path, ckpt_path,original_config_f...
  function load_propainter (line 18) | def load_propainter(fix_raft_path,flow_path,ProPainter_path,device="cpu"):

Condensed preview: 77 files, each showing path, character count, and a content snippet. The full structured content is 2,946K chars.
[
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright (c) 2025 smthemex\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README.md",
    "chars": 3596,
    "preview": "# ComfyUI_DiffuEraser\n[DiffuEraser](https://github.com/lixiaowen-xw/DiffuEraser) is  a diffusion model for video Inpaint"
  },
  {
    "path": "__init__.py",
    "chars": 33,
    "preview": "\r\nfrom .diffueraser_node import *"
  },
  {
    "path": "diffueraser_node.py",
    "chars": 12438,
    "preview": "# !/usr/bin/env python\r\n# -*- coding: UTF-8 -*-\r\nimport os\r\nimport torch\r\nimport gc\r\nimport numpy as np\r\nfrom typing_ext"
  },
  {
    "path": "example_workflows/differaser.json",
    "chars": 15418,
    "preview": "{\n  \"id\": \"3da45669-6ef0-4ec2-a292-abe74e953ca2\",\n  \"revision\": 0,\n  \"last_node_id\": 24,\n  \"last_link_id\": 28,\n  \"nodes\""
  },
  {
    "path": "libs/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "libs/brushnet_CA.py",
    "chars": 47366,
    "preview": "from dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch im"
  },
  {
    "path": "libs/diffueraser.py",
    "chars": 23628,
    "preview": "import gc\nimport copy\nimport cv2\nimport os\nimport numpy as np\nimport torch\nimport torchvision\nimport re\nimport random\nfr"
  },
  {
    "path": "libs/pipeline_diffueraser.py",
    "chars": 69145,
    "preview": "import inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nimport numpy as np\nimport PIL.Image\n"
  },
  {
    "path": "libs/transformer_temporal.py",
    "chars": 16758,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "libs/unet_2d_blocks.py",
    "chars": 148985,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "libs/unet_2d_condition.py",
    "chars": 70141,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "libs/unet_3d_blocks.py",
    "chars": 92950,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "libs/unet_motion_model.py",
    "chars": 45146,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "libs/v1-inference.yaml",
    "chars": 1943,
    "preview": "model:\r\n  base_learning_rate: 1.0e-04\r\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\r\n  params:\r\n    linear_start:"
  },
  {
    "path": "node_utils.py",
    "chars": 9789,
    "preview": "# !/usr/bin/env python\r\n# -*- coding: UTF-8 -*-\r\nimport os\r\nimport torch\r\nfrom PIL import Image\r\nimport numpy as np\r\nimp"
  },
  {
    "path": "propainter/RAFT/__init__.py",
    "chars": 55,
    "preview": "# from .demo import RAFT_infer\nfrom .raft import RAFT\n\n"
  },
  {
    "path": "propainter/RAFT/corr.py",
    "chars": 3640,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt"
  },
  {
    "path": "propainter/RAFT/datasets.py",
    "chars": 9247,
    "preview": "# Data loading based on https://github.com/NVIDIA/flownet2-pytorch\n\nimport numpy as np\nimport torch\nimport torch.utils.d"
  },
  {
    "path": "propainter/RAFT/demo.py",
    "chars": 1856,
    "preview": "import sys\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom "
  },
  {
    "path": "propainter/RAFT/extractor.py",
    "chars": 8847,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(se"
  },
  {
    "path": "propainter/RAFT/raft.py",
    "chars": 4870,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .update import BasicUpdateBl"
  },
  {
    "path": "propainter/RAFT/update.py",
    "chars": 5227,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, i"
  },
  {
    "path": "propainter/RAFT/utils/__init__.py",
    "chars": 71,
    "preview": "from .flow_viz import flow_to_image\nfrom .frame_utils import writeFlow\n"
  },
  {
    "path": "propainter/RAFT/utils/augmentor.py",
    "chars": 9108,
    "preview": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL"
  },
  {
    "path": "propainter/RAFT/utils/flow_viz.py",
    "chars": 4318,
    "preview": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright "
  },
  {
    "path": "propainter/RAFT/utils/flow_viz_pt.py",
    "chars": 3856,
    "preview": "# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization\nimport torch\ntorch.pi = tor"
  },
  {
    "path": "propainter/RAFT/utils/frame_utils.py",
    "chars": 4024,
    "preview": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUse"
  },
  {
    "path": "propainter/RAFT/utils/utils.py",
    "chars": 2451,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \""
  },
  {
    "path": "propainter/core/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "propainter/core/dataset.py",
    "chars": 9226,
    "preview": "import os\nimport json\nimport random\n\nimport cv2\nfrom PIL import Image\nimport numpy as np\n\nimport torch\nimport torchvisio"
  },
  {
    "path": "propainter/core/dist.py",
    "chars": 1459,
    "preview": "import os\nimport torch\n\n\ndef get_world_size():\n    \"\"\"Find OMPI world size without calling mpi functions\n    :rtype: int"
  },
  {
    "path": "propainter/core/loss.py",
    "chars": 6697,
    "preview": "import torch\nimport torch.nn as nn\nimport lpips\nfrom ..model.vgg_arch import VGGFeatureExtractor\n\nclass PerceptualLoss(n"
  },
  {
    "path": "propainter/core/lr_scheduler.py",
    "chars": 4386,
    "preview": "\"\"\"\n    LR scheduler from BasicSR https://github.com/xinntao/BasicSR\n\"\"\"\nimport math\nfrom collections import Counter\nfro"
  },
  {
    "path": "propainter/core/metrics.py",
    "chars": 20682,
    "preview": "import numpy as np\n# from skimage import measure\nfrom skimage.metrics import structural_similarity as compare_ssim\nfrom "
  },
  {
    "path": "propainter/core/prefetch_dataloader.py",
    "chars": 3131,
    "preview": "import queue as Queue\nimport threading\nimport torch\nfrom torch.utils.data import DataLoader\n\n\nclass PrefetchGenerator(th"
  },
  {
    "path": "propainter/core/trainer.py",
    "chars": 23545,
    "preview": "import os\nimport glob\nimport logging\nimport importlib\nfrom tqdm import tqdm\n\nimport torch\nimport torch.nn as nn\nimport t"
  },
  {
    "path": "propainter/core/trainer_flow_w_edge.py",
    "chars": 15625,
    "preview": "import os\nimport glob\nimport logging\nimport importlib\nfrom tqdm import tqdm\n\nimport torch\nimport torch.nn as nn\nimport t"
  },
  {
    "path": "propainter/core/utils.py",
    "chars": 14236,
    "preview": "import os\nimport io\nimport cv2\nimport random\nimport numpy as np\nfrom PIL import Image, ImageOps\nimport zipfile\nimport ma"
  },
  {
    "path": "propainter/inference.py",
    "chars": 28925,
    "preview": "# -*- coding: utf-8 -*-\nimport os\nimport cv2\nimport numpy as np\nimport scipy.ndimage\nfrom PIL import Image\nfrom tqdm imp"
  },
  {
    "path": "propainter/model/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "propainter/model/canny/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "propainter/model/canny/canny_filter.py",
    "chars": 9404,
    "preview": "import math\nfrom typing import Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .gaussian"
  },
  {
    "path": "propainter/model/canny/filter.py",
    "chars": 11036,
    "preview": "from typing import List\n\nimport torch\nimport torch.nn.functional as F\n\nfrom .kernels import normalize_kernel2d\n\n\ndef _co"
  },
  {
    "path": "propainter/model/canny/gaussian.py",
    "chars": 3815,
    "preview": "from typing import Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom .filter import filter2d, filter2d_separable\nfrom .ker"
  },
  {
    "path": "propainter/model/canny/kernels.py",
    "chars": 25239,
    "preview": "import math\nfrom math import sqrt\nfrom typing import List, Optional, Tuple\n\nimport torch\n\n\ndef normalize_kernel2d(input:"
  },
  {
    "path": "propainter/model/canny/sobel.py",
    "chars": 9397,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .kernels import get_spatial_gradient_kernel2d, "
  },
  {
    "path": "propainter/model/misc.py",
    "chars": 4652,
    "preview": "import os\nimport re\nimport random\nimport time\nimport torch\nimport torch.nn as nn\nimport logging\nimport numpy as np\nfrom "
  },
  {
    "path": "propainter/model/modules/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "propainter/model/modules/base_module.py",
    "chars": 5583,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom functools import reduce\n\nclass BaseNetwork(nn.M"
  },
  {
    "path": "propainter/model/modules/deformconv.py",
    "chars": 1695,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init as init\nfrom torch.nn.modules.utils import _pair, _single\ni"
  },
  {
    "path": "propainter/model/modules/flow_comp_raft.py",
    "chars": 10205,
    "preview": "import argparse\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nfrom ...RAFT import RAFT\nfrom .flow"
  },
  {
    "path": "propainter/model/modules/flow_loss_utils.py",
    "chars": 5797,
    "preview": "import torch\r\nimport numpy as np\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\ndef flow_warp(x,\r\n          "
  },
  {
    "path": "propainter/model/modules/sparse_transformer.py",
    "chars": 15866,
    "preview": "import math\nfrom functools import reduce\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass SoftS"
  },
  {
    "path": "propainter/model/modules/spectral_norm.py",
    "chars": 12357,
    "preview": "\"\"\"\nSpectral Normalization from https://arxiv.org/abs/1802.05957\n\"\"\"\nimport torch\nfrom torch.nn.functional import normal"
  },
  {
    "path": "propainter/model/propainter.py",
    "chars": 23207,
    "preview": "''' Towards An End-to-End Framework for Video Inpainting\n'''\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functio"
  },
  {
    "path": "propainter/model/recurrent_flow_completion.py",
    "chars": 12964,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\n\n\nfrom .modules.deformconv import "
  },
  {
    "path": "propainter/model/vgg_arch.py",
    "chars": 6065,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom torch import nn as nn\nfrom torchvision.models import vgg"
  },
  {
    "path": "propainter/utils/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "propainter/utils/download_util.py",
    "chars": 3746,
    "preview": "import math\nimport os\nimport requests\nfrom torch.hub import download_url_to_file, get_dir\nfrom tqdm import tqdm\nfrom url"
  },
  {
    "path": "propainter/utils/file_client.py",
    "chars": 5913,
    "preview": "from abc import ABCMeta, abstractmethod\n\n\nclass BaseStorageBackend(metaclass=ABCMeta):\n    \"\"\"Abstract class of storage "
  },
  {
    "path": "propainter/utils/flow_util.py",
    "chars": 7030,
    "preview": "import cv2\nimport numpy as np\nimport os\nimport torch.nn.functional as F\n\ndef resize_flow(flow, newh, neww):\n    oldh, ol"
  },
  {
    "path": "propainter/utils/img_util.py",
    "chars": 6134,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport torch\nfrom torchvision.utils import make_grid\n\n\ndef img2tenso"
  },
  {
    "path": "pyproject.toml",
    "chars": 710,
    "preview": "[project]\r\nname = \"comfyui_diffueraser\"\r\ndescription = \"DiffuEraser is a diffusion model for video Inpainting, you can u"
  },
  {
    "path": "requirements.txt",
    "chars": 220,
    "preview": "torch\ntorchvision\ntorchaudio\ndiffusers\naccelerate\nopencv-python\nimageio\n#matplotlib\ntransformers\neinops\n#datasets\n#numpy"
  },
  {
    "path": "run_diffueraser.py",
    "chars": 644,
    "preview": "import torch\nimport os \nimport time\nimport random\nfrom .libs.diffueraser import DiffuEraser\nfrom .propainter.inference i"
  },
  {
    "path": "sd15_repo/feature_extractor/preprocessor_config.json",
    "chars": 520,
    "preview": "{\n  \"crop_size\": {\n    \"height\": 224,\n    \"width\": 224\n  },\n  \"do_center_crop\": true,\n  \"do_convert_rgb\": true,\n  \"do_no"
  },
  {
    "path": "sd15_repo/model_index.json",
    "chars": 642,
    "preview": "{\n  \"_class_name\": \"StableDiffusionPipeline\",\n  \"_diffusers_version\": \"0.21.0.dev0\",\n  \"_name_or_path\": \"lykon-models/dr"
  },
  {
    "path": "sd15_repo/safety_checker/config.json",
    "chars": 796,
    "preview": "{\n  \"_name_or_path\": \"/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f4818324"
  },
  {
    "path": "sd15_repo/scheduler/scheduler_config.json",
    "chars": 614,
    "preview": "{\n  \"_class_name\": \"DEISMultistepScheduler\",\n  \"_diffusers_version\": \"0.21.0.dev0\",\n  \"algorithm_type\": \"deis\",\n  \"beta_"
  },
  {
    "path": "sd15_repo/text_encoder/config.json",
    "chars": 724,
    "preview": "{\n  \"_name_or_path\": \"/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f4818324"
  },
  {
    "path": "sd15_repo/tokenizer/merges.txt",
    "chars": 564203,
    "preview": "#version: 0.2\r\ni n\r\nt h\r\na n\r\nr e\r\na r\r\ne r\r\nth e</w>\r\nin g</w>\r\no u\r\no n\r\ns t\r\no r\r\ne n\r\no n</w>\r\na l\r\na t\r\ne r</w>\r\ni "
  },
  {
    "path": "sd15_repo/tokenizer/special_tokens_map.json",
    "chars": 496,
    "preview": "{\r\n  \"bos_token\": {\r\n    \"content\": \"<|startoftext|>\",\r\n    \"lstrip\": false,\r\n    \"normalized\": true,\r\n    \"rstrip\": fal"
  },
  {
    "path": "sd15_repo/tokenizer/tokenizer_config.json",
    "chars": 737,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"bos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<|startoftext|>\",\n    \"lstri"
  },
  {
    "path": "sd15_repo/tokenizer/vocab.json",
    "chars": 1099737,
    "preview": "{\r\n  \"!\": 0,\r\n  \"!!\": 1443,\r\n  \"!!!\": 11194,\r\n  \"!!!!\": 4003,\r\n  \"!!!!!!!!\": 11281,\r\n  \"!!!!!!!!!!!!!!!!\": 30146,\r\n  \"!!"
  },
  {
    "path": "sd15_repo/unet/config.json",
    "chars": 1868,
    "preview": "{\n  \"_class_name\": \"UNet2DConditionModel\",\n  \"_diffusers_version\": \"0.21.0.dev0\",\n  \"_name_or_path\": \"/home/patrick/.cac"
  },
  {
    "path": "sd15_repo/vae/config.json",
    "chars": 756,
    "preview": "{\n  \"_class_name\": \"AutoencoderKL\",\n  \"_diffusers_version\": \"0.21.0.dev0\",\n  \"_name_or_path\": \"/home/patrick/.cache/hugg"
  }
]

About this extraction

This page contains the full source code of the smthemex/ComfyUI_DiffuEraser GitHub repository, extracted and formatted as plain text. The extraction includes 77 files (2.5 MB), approximately 652.8k tokens, and a symbol index with 761 extracted functions, classes, methods, constants, and types.