Full Code of KwaiVGI/SynCamMaster for AI

Repository: KwaiVGI/SynCamMaster
Branch: main
Commit: 43ea4d961189
Files: 222
Total size: 32.7 MB

Directory structure:
SynCamMaster/

├── .gitignore
├── README.md
├── diffsynth/
│   ├── __init__.py
│   ├── configs/
│   │   ├── __init__.py
│   │   └── model_config.py
│   ├── controlnets/
│   │   ├── __init__.py
│   │   ├── controlnet_unit.py
│   │   └── processors.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── simple_text_image.py
│   │   └── video.py
│   ├── extensions/
│   │   ├── ESRGAN/
│   │   │   └── __init__.py
│   │   ├── FastBlend/
│   │   │   ├── __init__.py
│   │   │   ├── api.py
│   │   │   ├── cupy_kernels.py
│   │   │   ├── data.py
│   │   │   ├── patch_match.py
│   │   │   └── runners/
│   │   │       ├── __init__.py
│   │   │       ├── accurate.py
│   │   │       ├── balanced.py
│   │   │       ├── fast.py
│   │   │       └── interpolation.py
│   │   ├── ImageQualityMetric/
│   │   │   ├── BLIP/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blip.py
│   │   │   │   ├── blip_pretrain.py
│   │   │   │   ├── med.py
│   │   │   │   └── vit.py
│   │   │   ├── __init__.py
│   │   │   ├── aesthetic.py
│   │   │   ├── clip.py
│   │   │   ├── config.py
│   │   │   ├── hps.py
│   │   │   ├── imagereward.py
│   │   │   ├── mps.py
│   │   │   ├── open_clip/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── coca_model.py
│   │   │   │   ├── constants.py
│   │   │   │   ├── factory.py
│   │   │   │   ├── generation_utils.py
│   │   │   │   ├── hf_configs.py
│   │   │   │   ├── hf_model.py
│   │   │   │   ├── loss.py
│   │   │   │   ├── model.py
│   │   │   │   ├── model_configs/
│   │   │   │   │   └── ViT-H-14.json
│   │   │   │   ├── modified_resnet.py
│   │   │   │   ├── openai.py
│   │   │   │   ├── pretrained.py
│   │   │   │   ├── push_to_hf_hub.py
│   │   │   │   ├── timm_model.py
│   │   │   │   ├── tokenizer.py
│   │   │   │   ├── transform.py
│   │   │   │   ├── transformer.py
│   │   │   │   ├── utils.py
│   │   │   │   └── version.py
│   │   │   ├── pickscore.py
│   │   │   └── trainer/
│   │   │       ├── __init__.py
│   │   │       └── models/
│   │   │           ├── __init__.py
│   │   │           ├── base_model.py
│   │   │           ├── clip_model.py
│   │   │           └── cross_modeling.py
│   │   ├── RIFE/
│   │   │   └── __init__.py
│   │   └── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── cog_dit.py
│   │   ├── cog_vae.py
│   │   ├── downloader.py
│   │   ├── flux_controlnet.py
│   │   ├── flux_dit.py
│   │   ├── flux_ipadapter.py
│   │   ├── flux_text_encoder.py
│   │   ├── flux_vae.py
│   │   ├── hunyuan_dit.py
│   │   ├── hunyuan_dit_text_encoder.py
│   │   ├── hunyuan_video_dit.py
│   │   ├── hunyuan_video_text_encoder.py
│   │   ├── hunyuan_video_vae_decoder.py
│   │   ├── hunyuan_video_vae_encoder.py
│   │   ├── kolors_text_encoder.py
│   │   ├── lora.py
│   │   ├── model_manager.py
│   │   ├── omnigen.py
│   │   ├── sd3_dit.py
│   │   ├── sd3_text_encoder.py
│   │   ├── sd3_vae_decoder.py
│   │   ├── sd3_vae_encoder.py
│   │   ├── sd_controlnet.py
│   │   ├── sd_ipadapter.py
│   │   ├── sd_motion.py
│   │   ├── sd_text_encoder.py
│   │   ├── sd_unet.py
│   │   ├── sd_vae_decoder.py
│   │   ├── sd_vae_encoder.py
│   │   ├── sdxl_controlnet.py
│   │   ├── sdxl_ipadapter.py
│   │   ├── sdxl_motion.py
│   │   ├── sdxl_text_encoder.py
│   │   ├── sdxl_unet.py
│   │   ├── sdxl_vae_decoder.py
│   │   ├── sdxl_vae_encoder.py
│   │   ├── stepvideo_dit.py
│   │   ├── stepvideo_text_encoder.py
│   │   ├── stepvideo_vae.py
│   │   ├── svd_image_encoder.py
│   │   ├── svd_unet.py
│   │   ├── svd_vae_decoder.py
│   │   ├── svd_vae_encoder.py
│   │   ├── tiler.py
│   │   ├── utils.py
│   │   ├── wan_video_dit.py
│   │   ├── wan_video_image_encoder.py
│   │   ├── wan_video_text_encoder.py
│   │   └── wan_video_vae.py
│   ├── pipelines/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── cog_video.py
│   │   ├── dancer.py
│   │   ├── flux_image.py
│   │   ├── hunyuan_image.py
│   │   ├── hunyuan_video.py
│   │   ├── omnigen_image.py
│   │   ├── pipeline_runner.py
│   │   ├── sd3_image.py
│   │   ├── sd_image.py
│   │   ├── sd_video.py
│   │   ├── sdxl_image.py
│   │   ├── sdxl_video.py
│   │   ├── step_video.py
│   │   ├── svd_video.py
│   │   ├── wan_video.py
│   │   └── wan_video_syncammaster.py
│   ├── processors/
│   │   ├── FastBlend.py
│   │   ├── PILEditor.py
│   │   ├── RIFE.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── sequencial_processor.py
│   ├── prompters/
│   │   ├── __init__.py
│   │   ├── base_prompter.py
│   │   ├── cog_prompter.py
│   │   ├── flux_prompter.py
│   │   ├── hunyuan_dit_prompter.py
│   │   ├── hunyuan_video_prompter.py
│   │   ├── kolors_prompter.py
│   │   ├── omnigen_prompter.py
│   │   ├── omost.py
│   │   ├── prompt_refiners.py
│   │   ├── sd3_prompter.py
│   │   ├── sd_prompter.py
│   │   ├── sdxl_prompter.py
│   │   ├── stepvideo_prompter.py
│   │   └── wan_prompter.py
│   ├── schedulers/
│   │   ├── __init__.py
│   │   ├── continuous_ode.py
│   │   ├── ddim.py
│   │   └── flow_match.py
│   ├── tokenizer_configs/
│   │   ├── __init__.py
│   │   ├── cog/
│   │   │   └── tokenizer/
│   │   │       ├── added_tokens.json
│   │   │       ├── special_tokens_map.json
│   │   │       ├── spiece.model
│   │   │       └── tokenizer_config.json
│   │   ├── flux/
│   │   │   ├── tokenizer_1/
│   │   │   │   ├── merges.txt
│   │   │   │   ├── special_tokens_map.json
│   │   │   │   ├── tokenizer_config.json
│   │   │   │   └── vocab.json
│   │   │   └── tokenizer_2/
│   │   │       ├── special_tokens_map.json
│   │   │       ├── spiece.model
│   │   │       ├── tokenizer.json
│   │   │       └── tokenizer_config.json
│   │   ├── hunyuan_dit/
│   │   │   ├── tokenizer/
│   │   │   │   ├── special_tokens_map.json
│   │   │   │   ├── tokenizer_config.json
│   │   │   │   ├── vocab.txt
│   │   │   │   └── vocab_org.txt
│   │   │   └── tokenizer_t5/
│   │   │       ├── config.json
│   │   │       ├── special_tokens_map.json
│   │   │       ├── spiece.model
│   │   │       └── tokenizer_config.json
│   │   ├── hunyuan_video/
│   │   │   ├── tokenizer_1/
│   │   │   │   ├── merges.txt
│   │   │   │   ├── special_tokens_map.json
│   │   │   │   ├── tokenizer_config.json
│   │   │   │   └── vocab.json
│   │   │   └── tokenizer_2/
│   │   │       ├── preprocessor_config.json
│   │   │       ├── special_tokens_map.json
│   │   │       ├── tokenizer.json
│   │   │       └── tokenizer_config.json
│   │   ├── kolors/
│   │   │   └── tokenizer/
│   │   │       ├── tokenizer.model
│   │   │       ├── tokenizer_config.json
│   │   │       └── vocab.txt
│   │   ├── stable_diffusion/
│   │   │   └── tokenizer/
│   │   │       ├── merges.txt
│   │   │       ├── special_tokens_map.json
│   │   │       ├── tokenizer_config.json
│   │   │       └── vocab.json
│   │   ├── stable_diffusion_3/
│   │   │   ├── tokenizer_1/
│   │   │   │   ├── merges.txt
│   │   │   │   ├── special_tokens_map.json
│   │   │   │   ├── tokenizer_config.json
│   │   │   │   └── vocab.json
│   │   │   ├── tokenizer_2/
│   │   │   │   ├── merges.txt
│   │   │   │   ├── special_tokens_map.json
│   │   │   │   ├── tokenizer_config.json
│   │   │   │   └── vocab.json
│   │   │   └── tokenizer_3/
│   │   │       ├── special_tokens_map.json
│   │   │       ├── spiece.model
│   │   │       ├── tokenizer.json
│   │   │       └── tokenizer_config.json
│   │   └── stable_diffusion_xl/
│   │       └── tokenizer_2/
│   │           ├── merges.txt
│   │           ├── special_tokens_map.json
│   │           ├── tokenizer_config.json
│   │           └── vocab.json
│   ├── trainers/
│   │   ├── __init__.py
│   │   └── text_to_image.py
│   └── vram_management/
│       ├── __init__.py
│       └── layers.py
├── download_wan2.1.py
├── example_test_data/
│   ├── cameras/
│   │   └── camera_extrinsics.json
│   └── metadata.csv
├── generate_sample_list.py
├── inference_syncammaster.py
├── models/
│   └── SynCamMaster/
│       └── checkpoints/
│           └── Put SynCamMaster ckpt file here.txt
├── requirements.txt
├── setup.py
├── train_syncammaster.py
└── vis_cam.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*__pycache__
*.ckpt
Wan-AI

================================================
FILE: README.md
================================================
# SynCamMaster: Synchronizing Multi-Camera Video Generation from Diverse Viewpoints

<div align="center">
<div align="center" style="margin-top: 0px; margin-bottom: 0px;">
<img src=https://github.com/user-attachments/assets/b33c5b67-3881-4fa3-b853-f932eebc9c50 width="50%"/>
</div>

### [<a href="https://arxiv.org/abs/2412.07760" target="_blank">arXiv</a>] [<a href="https://jianhongbai.github.io/SynCamMaster/" target="_blank">Project Page</a>] [<a href="https://huggingface.co/datasets/KwaiVGI/SynCamVideo-Dataset/" target="_blank">Dataset</a>]

_**[Jianhong Bai<sup>1*</sup>](https://jianhongbai.github.io/), [Menghan Xia<sup>2†</sup>](https://menghanxia.github.io/), [Xintao Wang<sup>2</sup>](https://xinntao.github.io/), [Ziyang Yuan<sup>3</sup>](https://scholar.google.ru/citations?user=fWxWEzsAAAAJ&hl=en), [Xiao Fu<sup>4</sup>](https://fuxiao0719.github.io/), <br>[Zuozhu Liu<sup>1</sup>](https://person.zju.edu.cn/en/lzz), [Haoji Hu<sup>1</sup>](https://person.zju.edu.cn/en/huhaoji), [Pengfei Wan<sup>2</sup>](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en), [Di Zhang<sup>2</sup>](https://openreview.net/profile?id=~Di_ZHANG3)**_
<br>
(*Work done during an internship at KwaiVGI, Kuaishou Technology †corresponding author)

<sup>1</sup>Zhejiang University, <sup>2</sup>Kuaishou Technology, <sup>3</sup>Tsinghua University, <sup>4</sup>CUHK.

**ICLR 2025**

</div>

**Important Note:** This open-source repository is intended as a reference implementation. Because the underlying T2V model differs in performance, the open-source version may not achieve the same results as the model in our paper.

## 🔥 Updates
- __[2025.04.15]__: Please feel free to explore our subsequent work, [ReCamMaster](https://github.com/KwaiVGI/ReCamMaster).
- __[2025.04.15]__: Update a new version of the [SynCamVideo Dataset](https://huggingface.co/datasets/KwaiVGI/SynCamVideo-Dataset).
- __[2025.04.15]__: Release the [training and inference code](https://github.com/KwaiVGI/SynCamMaster?tab=readme-ov-file#%EF%B8%8F-code-syncammaster--wan21-inference--training), [model checkpoint](https://huggingface.co/KwaiVGI/SynCamMaster-Wan2.1/blob/main/step20000.ckpt).
- __[2024.12.10]__: Release the [project page](https://jianhongbai.github.io/SynCamMaster/) and the [SynCamVideo Dataset](https://huggingface.co/datasets/KwaiVGI/SynCamVideo-Dataset/).
  
## 📖 Introduction

**TL;DR:** We propose SynCamMaster, an efficient method to lift pre-trained text-to-video models for open-domain multi-camera video generation from diverse viewpoints. We also release a multi-camera synchronized video [dataset](https://huggingface.co/datasets/KwaiVGI/SynCamVideo-Dataset) rendered with Unreal Engine 5. <br>

https://github.com/user-attachments/assets/1ecfaea8-5d87-4bb5-94fc-062f84bd67a1

## ⚙️ Code: SynCamMaster + Wan2.1 (Inference & Training)
The model utilized in our paper is an internally developed T2V model, not [Wan2.1](https://github.com/Wan-Video/Wan2.1). Due to company policy restrictions, we are unable to open-source the model used in the paper. Consequently, we migrated SynCamMaster to Wan2.1 to validate the effectiveness of our method. Due to differences in the underlying T2V model, you may not achieve the same results as demonstrated in the demo.
### Inference
Step 1: Set up the environment

[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) requires Rust and Cargo to compile extensions. You can install them using the following command:
```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
. "$HOME/.cargo/env"
```

Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):
```shell
git clone https://github.com/KwaiVGI/SynCamMaster.git
cd SynCamMaster
pip install -e .
```

Step 2: Download the pretrained checkpoints
1. Download the pre-trained Wan2.1 models

```shell
cd SynCamMaster
python download_wan2.1.py
```
2. Download the pre-trained SynCamMaster checkpoint

Please download the checkpoint from [Hugging Face](https://huggingface.co/KwaiVGI/SynCamMaster-Wan2.1/blob/main/step20000.ckpt) and place it in `models/SynCamMaster/checkpoints`.

Step 3: Test the example videos
```shell
python inference_syncammaster.py --cam_type "az"
```

We provide several preset camera types. Additionally, you can generate new camera poses for testing.
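If you want to script new camera poses instead of using a preset, the sketch below builds world-to-camera extrinsics for an azimuth sweep around the subject. Both helpers are hypothetical illustrations rather than part of this repository; check the convention used in `example_test_data/cameras/camera_extrinsics.json` before feeding generated poses to the pipeline.

```python
import numpy as np

def look_at_extrinsic(cam_pos, target=np.zeros(3), up=np.array([0.0, 0.0, 1.0])):
    """Build a 4x4 world-to-camera matrix looking from cam_pos toward target."""
    forward = target - cam_pos
    forward = forward / np.linalg.norm(forward)
    right = np.cross(forward, up)
    right = right / np.linalg.norm(right)
    true_up = np.cross(right, forward)
    w2c = np.eye(4)
    w2c[:3, :3] = np.stack([right, true_up, forward])  # rows: camera axes in world coords
    w2c[:3, 3] = -w2c[:3, :3] @ np.asarray(cam_pos, dtype=float)
    return w2c

def azimuth_sweep(num_cams=10, radius=3.0, elevation_deg=15.0):
    """Place cameras on a circle around the subject at a fixed elevation."""
    el = np.deg2rad(elevation_deg)
    poses = []
    for i in range(num_cams):
        az = 2.0 * np.pi * i / num_cams
        pos = radius * np.array([np.cos(el) * np.cos(az),
                                 np.cos(el) * np.sin(az),
                                 np.sin(el)])
        poses.append(look_at_extrinsic(pos))
    return poses
```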

### Training

Step 1: Set up the environment

```shell
pip install lightning pandas websockets
```

Step 2: Prepare the training dataset

1. Download the [SynCamVideo dataset](https://huggingface.co/datasets/KwaiVGI/SynCamVideo-Dataset).

2. Extract VAE features

```shell
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_syncammaster.py \
  --task data_process \
  --dataset_path path/to/the/SynCamVideo/Dataset \
  --output_path ./models \
  --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \
  --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \
  --tiled \
  --num_frames 81 \
  --height 480 \
  --width 832 \
  --dataloader_num_workers 2
```

3. Generate captions for each video

You can use a video captioning tool such as [LLaVA](https://github.com/haotian-liu/LLaVA) to generate a caption for each video and store the captions in `metadata.csv`.

4. Calculate the available sample list

```shell
python generate_sample_list.py
```
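The caption step above can be sketched with the standard library. The `file_name`/`text` column layout here is an assumption for illustration only; check `example_test_data/metadata.csv` for the schema that `train_syncammaster.py` actually expects.

```python
import csv

def write_metadata(rows, path="metadata.csv"):
    """Store one caption per video. The column names here are assumed."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["file_name", "text"])
        writer.writeheader()
        writer.writerows(rows)

# One row per video in the dataset; paths follow the SynCamVideo layout.
rows = [
    {"file_name": "train/f24_aperture5/scene1/videos/cam01.mp4",
     "text": "A person waving in a cafe."},
]
```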

Step 3: Training
```shell
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_syncammaster.py \
  --task train \
  --output_path ./models/train \
  --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \
  --steps_per_epoch 8000 \
  --max_epochs 100 \
  --learning_rate 1e-4 \
  --accumulate_grad_batches 1 \
  --use_gradient_checkpointing \
  --dataloader_num_workers 4
```
We did not search for the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by tuning hyper-parameters such as the learning rate and by increasing the batch size.

Step 4: Test the model

```shell
python inference_syncammaster.py --cam_type "az" --ckpt_path path/to/the/checkpoint
```

## 📷 Dataset: SynCamVideo Dataset
### 1. Dataset Introduction

**TL;DR:** The SynCamVideo Dataset is a multi-camera synchronized video dataset rendered with Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera poses. The SynCamVideo Dataset can be valuable in fields such as camera-controlled video generation, synchronized video production, and 3D/4D reconstruction. The cameras in the SynCamVideo Dataset are stationary; if you require footage from moving cameras rather than stationary ones, please explore our [MultiCamVideo](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset) Dataset.

https://github.com/user-attachments/assets/b49fc632-d1df-49fd-93d2-8513fbdb9377

The dataset consists of 3.4K different dynamic scenes, each captured by 10 cameras, resulting in a total of 34K videos. Each dynamic scene is composed of four elements: {3D environment, character, animation, camera}. Specifically, we use an animation to drive a character and position the animated character within a 3D environment; time-synchronized cameras are then set up to render the multi-camera video data.
<p align="center">
  <img src="https://github.com/user-attachments/assets/107c9607-e99b-4493-b715-3e194fcb3933" alt="Example Image" width="70%">
</p>

**3D Environment:** We collect 37 high-quality 3D environment assets from [Fab](https://www.fab.com). To minimize the domain gap between rendered data and real-world videos, we primarily select visually realistic 3D scenes, supplemented by a few stylized or surreal ones. To ensure data diversity, the selected scenes cover a variety of indoor and outdoor settings, such as city streets, shopping malls, cafes, office rooms, and the countryside.

**Character:** We collect 66 different human 3D models as characters from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com).

**Animation:** We collect 93 different animations from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com), including common actions such as waving, dancing, and cheering. We use these animations to drive the collected characters and create diverse datasets through various combinations.

**Camera:** To enhance the diversity of the dataset, each camera is randomly sampled on a hemispherical surface centered around the character.
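That sampling can be sketched as below; the radius and elevation range are placeholders, not the values used to render SynCamVideo. Azimuth is drawn uniformly, and elevation is drawn through `asin` so positions are uniform over the spherical cap rather than bunched at the pole.

```python
import math, random

def sample_hemisphere(radius=3.0, min_elev_deg=0.0, max_elev_deg=85.0, rng=random):
    """Sample one camera position on a hemisphere centered on the character."""
    az = rng.uniform(0.0, 2.0 * math.pi)
    lo = math.sin(math.radians(min_elev_deg))
    hi = math.sin(math.radians(max_elev_deg))
    el = math.asin(rng.uniform(lo, hi))  # uniform over the spherical cap
    return (radius * math.cos(el) * math.cos(az),
            radius * math.cos(el) * math.sin(az),
            radius * math.sin(el))
```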

### 2. Statistics and Configurations

Dataset Statistics:

| Number of Dynamic Scenes | Camera per Scene | Total Videos |
|:------------------------:|:----------------:|:------------:|
| 3400                   | 10               | 34,000      |

Video Configurations:

| Resolution  | Frame Number | FPS                      |
|:-----------:|:------------:|:------------------------:|
| 1280x1280   | 81           | 15                       |

Note: You can use 'center crop' to adjust the video's aspect ratio to fit your video generation model, such as 16:9, 9:16, 4:3, or 3:4.
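The center crop mentioned above can be computed with integer math; the helper below is a generic sketch (pass the returned box to, e.g., `PIL.Image.crop`). For the 1280x1280 videos, a 16:9 crop keeps the full width and the middle 720 rows.

```python
def center_crop_box(src_w, src_h, aspect_w, aspect_h):
    """Largest centered (left, top, right, bottom) box with the given aspect ratio."""
    if src_w * aspect_h > src_h * aspect_w:       # source is wider than the target
        new_w, new_h = src_h * aspect_w // aspect_h, src_h
    else:                                         # source is taller (or equal)
        new_w, new_h = src_w, src_w * aspect_h // aspect_w
    left, top = (src_w - new_w) // 2, (src_h - new_h) // 2
    return (left, top, left + new_w, top + new_h)

print(center_crop_box(1280, 1280, 16, 9))  # (0, 280, 1280, 1000)
```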

Camera Configurations:

| Focal Length            | Aperture           | Sensor Height | Sensor Width |
|:-----------------------:|:------------------:|:-------------:|:------------:|
| 24mm  | 5.0     | 23.76mm       | 23.76mm      |
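For reference, the horizontal field of view implied by this configuration follows from the pinhole relation `fov = 2 * atan(sensor_width / (2 * focal_length))`, roughly 52.7 degrees here:

```python
import math

def horizontal_fov_deg(focal_mm, sensor_mm):
    """Horizontal field of view of a pinhole camera, in degrees."""
    return math.degrees(2.0 * math.atan(sensor_mm / (2.0 * focal_mm)))

fov = horizontal_fov_deg(24.0, 23.76)  # ~52.7 degrees for the SynCamVideo cameras
```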



### 3. File Structure
```
SynCamVideo-Dataset
├── train
│   └── f24_aperture5
│       ├── scene1    # one dynamic scene
│       │   ├── videos
│       │   │   ├── cam01.mp4    # synchronized 81-frame videos at 1280x1280 resolution
│       │   │   ├── cam02.mp4
│       │   │   ├── ...
│       │   │   └── cam10.mp4
│       │   └── cameras
│       │       └── camera_extrinsics.json    # 81-frame camera extrinsics of the 10 cameras 
│       ├── ...
│       └── scene3400
└── val
    └── basic
        ├── videos
        │   ├── cam01.mp4    # example videos corresponding to the validation cameras
        │   ├── cam02.mp4
        │   ├── ...
        │   └── cam10.mp4
        └── cameras
            └── camera_extrinsics.json    # 10 cameras for validation
```
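A minimal walker over the layout above, mapping each training scene to its synchronized clips (the helper name is ours, not part of the repository; the per-scene `camera_extrinsics.json` sits alongside each `videos/` directory):

```python
from pathlib import Path

def list_scene_videos(root):
    """Map each scene under train/f24_aperture5 to its sorted camera clips."""
    scenes = {}
    for scene_dir in sorted(Path(root, "train", "f24_aperture5").iterdir()):
        scenes[scene_dir.name] = sorted((scene_dir / "videos").glob("cam*.mp4"))
    return scenes
```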

### 4. Useful scripts
- Data Extraction
```bash
tar -xzvf SynCamVideo-Dataset.tar.gz
```
- Camera Visualization
```bash
python vis_cam.py
```

The visualization script is modified from [CameraCtrl](https://github.com/hehao13/CameraCtrl/blob/main/tools/visualize_trajectory.py), thanks to their inspiring work.

<p align="center">
  <img src="https://github.com/user-attachments/assets/2a4e4063-9868-4b7f-8626-1e6d6f611e3f" alt="Example Image" width="40%">
</p>

## 🤗 Awesome Related Works
Feel free to explore these outstanding related works, including but not limited to:

[GCD](https://gcd.cs.columbia.edu/): synthesize large-angle novel viewpoints of 4D dynamic scenes from a monocular video.

[CVD](https://collaborativevideodiffusion.github.io): multi-view video generation with multiple camera trajectories.

[SV4D](https://sv4d.github.io): multi-view consistent dynamic 3D content generation.

Additionally, check out our "MasterFamily" projects:

[ReCamMaster](https://jianhongbai.github.io/ReCamMaster/): re-capture in-the-wild videos with novel camera trajectories.

[3DTrajMaster](http://fuxiao0719.github.io/projects/3dtrajmaster): control multiple entity motions in 3D space (6DoF) for text-to-video generation.

[StyleMaster](https://zixuan-ye.github.io/stylemaster/): enable artistic video generation and translation with reference style image.


## Acknowledgments
We thank Jinwen Cao, Yisong Guo, Haowen Ji, Jichao Wang, and Yi Wang from Kuaishou Technology for their invaluable help in constructing the SynCamVideo-Dataset. We thank [Guanjun Wu](https://guanjunwu.github.io/) and Jiangnan Ye for their help on running 4DGS.

## 🌟 Citation

Please leave us a star 🌟 and cite our paper if you find our work helpful.
```
@article{bai2024syncammaster,
  title={SynCamMaster: Synchronizing Multi-Camera Video Generation from Diverse Viewpoints},
  author={Bai, Jianhong and Xia, Menghan and Wang, Xintao and Yuan, Ziyang and Fu, Xiao and Liu, Zuozhu and Hu, Haoji and Wan, Pengfei and Zhang, Di},
  journal={arXiv preprint arXiv:2412.07760},
  year={2024}
}
```


================================================
FILE: diffsynth/__init__.py
================================================
from .data import *
from .models import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *


================================================
FILE: diffsynth/configs/__init__.py
================================================


================================================
FILE: diffsynth/configs/model_config.py
================================================
from typing_extensions import Literal, TypeAlias

from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder

from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder

from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder

from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion

from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel

from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder

from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder

from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT

from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter

from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT

from ..models.omnigen import OmniGenTransformer

from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder

from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet

from ..models.hunyuan_video_dit import HunyuanVideoDiT

from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel

from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE


model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
    (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
    (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
    (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
    (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
    (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
    (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
    (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
    (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
    (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
    (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
    (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
    (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
    (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
    (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
    (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
    (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
    (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
    (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
    (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
    (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
    (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
    (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
    (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
    (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
    (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
    (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
    (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
    (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
    (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
    (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
    (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
    (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
    (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
    (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
    (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
    (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
    (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
    (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
    (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
    (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
    (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
    (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
    (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
    (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
    (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
    (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
    ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
    # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
    ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
    ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
    ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
    ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]
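# Illustrative example (not part of the original config): a checkpoint whose
# config.json declares "architectures": ["ChatGLMModel"] would be routed to
# the diffsynth.models.kolors_text_encoder module and registered under the
# model name "kolors_text_encoder". When redirected_architecture is not None
# (e.g. "FluxTextEncoder2" for "T5EncoderModel"), that class name is used
# instead of the architecture declared in config.json.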
patch_model_loader_configs = [
    # These configs are used to automatically detect the model type.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
    ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
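# Illustrative example (not part of the original config): a state dict whose
# keys-and-shapes hash equals "9a4ab6869ac9b7d6e31f9854e397c867" would be
# loaded as an SVDUNet, with the extra_kwargs dict expanded into the
# constructor call (here, add_positional_conv=128).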

preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
        ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],

    # Qwen Prompt
    "QwenPrompt": [
        ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
    ],
    # Beautiful Prompt
    "BeautifulPrompt": [
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
    ],
    # Omost prompt
    "OmostPrompt":[
        ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),  
        ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
    ],
    # Translator
    "opus-mt-zh-en": [
        ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
    ],
    # IP-Adapter
    "IP-Adapter-SD": [
        ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    "SDXL-vae-fp16-fix": [
        ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    # FLUX
    "FLUX.1-dev": [
        ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    # RIFE
    "RIFE": [
        ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # CogVideo
    "CogVideoX-5B": [
        ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
    ],
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
}

preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-CogVideoX-LoRA-129f-v1": [
        ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
        ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "Annotators:Depth": [
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
    ],
    "Annotators:Softedge": [
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
    ],
    "Annotators:Lineart": [
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
    ],
    "Annotators:Normal": [
        ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
    ],
    "Annotators:Openpose": [
        ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
        ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Qwen Prompt
    "QwenPrompt": {
        "file_list": [
            ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ],
        "load_path": [
            "models/QwenPrompt/qwen2-1.5b-instruct",
        ],
    },
    # Beautiful Prompt
    "BeautifulPrompt": {
        "file_list": [
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ],
        "load_path": [
            "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
        ],
    },
    # Omost prompt
    "OmostPrompt": {
        "file_list": [
            ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),  
            ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ],
        "load_path": [
            "models/OmostPrompt/omost-llama-3-8b-4bits",
        ],
    },
    # Translator
    "opus-mt-zh-en": {
        "file_list": [
            ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
        ],
        "load_path": [
            "models/translator/opus-mt-zh-en",
        ],
    },
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": {
        "file_list": [
            ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
            ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
        ],
        "load_path": [
            "models/kolors/Kolors/text_encoder",
            "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
            "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # FLUX
    "FLUX.1-dev": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
        ],
    },
    "FLUX.1-schnell": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
        ],
    },
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
        ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Depth": [
        ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
        ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Upscaler": [
        ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    # ESRGAN
    "ESRGAN_x4": [
        ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
    ],
    # RIFE
    "RIFE": [
        ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # Omnigen
    "OmniGen-v1": {
        "file_list": [
            ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
            ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
        ],
        "load_path": [
            "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
            "models/OmniGen/OmniGen-v1/model.safetensors",
        ]
    },
    # CogVideo
    "CogVideoX-5B": {
        "file_list": [
            ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
        ],
        "load_path": [
            "models/CogVideo/CogVideoX-5b/text_encoder",
            "models/CogVideo/CogVideoX-5b/transformer",
            "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-medium": [
        ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-large-turbo": [
        ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "HunyuanVideo": {
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
        ],
        "load_path": [
            "models/HunyuanVideo/text_encoder/model.safetensors",
            "models/HunyuanVideo/text_encoder_2",
            "models/HunyuanVideo/vae/pytorch_model.pt",
            "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideoI2V": {
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
            ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
        ],
        "load_path": [
            "models/HunyuanVideoI2V/text_encoder/model.safetensors",
            "models/HunyuanVideoI2V/text_encoder_2",
            "models/HunyuanVideoI2V/vae/pytorch_model.pt",
            "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
        ],
    },
    "HunyuanVideo-fp8": {
        "file_list": [
            ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
            ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
            ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
            ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
        ],
        "load_path": [
            "models/HunyuanVideo/text_encoder/model.safetensors",
            "models/HunyuanVideo/text_encoder_2",
            "models/HunyuanVideo/vae/pytorch_model.pt",
            "models/HunyuanVideo/transformers/model.fp8.safetensors"
        ],
    },
}
Preset_model_id: TypeAlias = Literal[
    "HunyuanDiT",
    "stable-video-diffusion-img2vid-xt",
    "ExVideo-SVD-128f-v1",
    "ExVideo-CogVideoX-LoRA-129f-v1",
    "StableDiffusion_v15",
    "DreamShaper_8",
    "AingDiffusion_v12",
    "Flat2DAnimerge_v45Sharp",
    "TextualInversion_VeryBadImageNegative_v1.3",
    "StableDiffusionXL_v1",
    "BluePencilXL_v200",
    "StableDiffusionXL_Turbo",
    "ControlNet_v11f1p_sd15_depth",
    "ControlNet_v11p_sd15_softedge",
    "ControlNet_v11f1e_sd15_tile",
    "ControlNet_v11p_sd15_lineart",
    "AnimateDiff_v2",
    "AnimateDiff_xl_beta",
    "RIFE",
    "BeautifulPrompt",
    "opus-mt-zh-en",
    "IP-Adapter-SD",
    "IP-Adapter-SDXL",
    "StableDiffusion3",
    "StableDiffusion3_without_T5",
    "Kolors",
    "SDXL-vae-fp16-fix",
    "ControlNet_union_sdxl_promax",
    "FLUX.1-dev",
    "FLUX.1-schnell",
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
    "jasperai/Flux.1-dev-Controlnet-Depth",
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
    "InstantX/FLUX.1-dev-IP-Adapter",
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
    "QwenPrompt",
    "OmostPrompt",
    "ESRGAN_x4",
    "OmniGen-v1",
    "CogVideoX-5B",
    "Annotators:Depth",
    "Annotators:Softedge",
    "Annotators:Lineart",
    "Annotators:Normal",
    "Annotators:Openpose",
    "StableDiffusion3.5-large",
    "StableDiffusion3.5-medium",
    "StableDiffusion3.5-large-turbo",
    "HunyuanVideo",
    "HunyuanVideo-fp8",
    "HunyuanVideoI2V",
]


================================================
FILE: diffsynth/controlnets/__init__.py
================================================
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
from .processors import Annotator


================================================
FILE: diffsynth/controlnets/controlnet_unit.py
================================================
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale
        self.skip_processor = skip_processor


class ControlNetUnit:
    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    def __init__(self, controlnet_units=()):
        self.processors = [unit.processor for unit in controlnet_units]
        self.models = [unit.model for unit in controlnet_units]
        self.scales = [unit.scale for unit in controlnet_units]

    def cpu(self):
        for model in self.models:
            model.cpu()

    def to(self, device):
        for model in self.models:
            model.to(device)
        for processor in self.processors:
            processor.to(device)
    
    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image
    
    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        res_stack = None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                processor_id=processor.processor_id
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack


class FluxMultiControlNetManager(MultiControlNetManager):
    def __init__(self, controlnet_units=()):
        super().__init__(controlnet_units=controlnet_units)

    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        return processed_image

    def __call__(self, conditionings, **kwargs):
        res_stack, single_res_stack = None, None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
            res_stack_ = [res * scale for res in res_stack_]
            single_res_stack_ = [res * scale for res in single_res_stack_]
            if res_stack is None:
                res_stack = res_stack_
                single_res_stack = single_res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack
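Both managers share the same accumulation pattern: each ControlNet produces a stack of residuals, which are scaled by the unit's strength and then summed position-wise across units. A dependency-free sketch of that loop (`accumulate_residuals` is an illustrative name, not part of the repo; plain floats stand in for tensors, and the arithmetic is identical for torch tensors):

```python
def accumulate_residuals(res_stacks, scales):
    # Scale each unit's residual list, then sum the lists element-wise,
    # mirroring the accumulation loop in MultiControlNetManager.__call__.
    total = None
    for stack, scale in zip(res_stacks, scales):
        scaled = [res * scale for res in stack]
        if total is None:
            total = scaled
        else:
            total = [a + b for a, b in zip(total, scaled)]
    return total
```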


================================================
FILE: diffsynth/controlnets/processors.py
================================================
from typing_extensions import Literal, TypeAlias


Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]

class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
        if not skip_processor:
            if processor_id == "canny":
                from controlnet_aux.processor import CannyDetector
                self.processor = CannyDetector()
            elif processor_id == "depth":
                from controlnet_aux.processor import MidasDetector
                self.processor = MidasDetector.from_pretrained(model_path).to(device)
            elif processor_id == "softedge":
                from controlnet_aux.processor import HEDdetector
                self.processor = HEDdetector.from_pretrained(model_path).to(device)
            elif processor_id == "lineart":
                from controlnet_aux.processor import LineartDetector
                self.processor = LineartDetector.from_pretrained(model_path).to(device)
            elif processor_id == "lineart_anime":
                from controlnet_aux.processor import LineartAnimeDetector
                self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
            elif processor_id == "openpose":
                from controlnet_aux.processor import OpenposeDetector
                self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
            elif processor_id == "normal":
                from controlnet_aux.processor import NormalBaeDetector
                self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
            elif processor_id in ("tile", "none", "inpaint"):
                self.processor = None
            else:
                raise ValueError(f"Unsupported processor_id: {processor_id}")
        else:
            self.processor = None

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution
    
    def to(self, device):
        if hasattr(self.processor, "model") and hasattr(self.processor.model, "to"):
            self.processor.model.to(device)

    def __call__(self, image, mask=None):
        width, height = image.size
        if self.processor_id == "openpose":
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        image = image.resize((width, height))
        return image



================================================
FILE: diffsynth/data/__init__.py
================================================
from .video import VideoData, save_video, save_frames


================================================
FILE: diffsynth/data/simple_text_image.py
================================================
import torch, os, torchvision
from torchvision import transforms
import pandas as pd
from PIL import Image



class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.height = height
        self.width = width
        self.image_processor = transforms.Compose(
            [
                transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
                transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )


    def __getitem__(self, index):
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path) # For fixed seed.
        text = self.text[data_id]
        image = Image.open(self.path[data_id]).convert("RGB")
        target_height, target_width = self.height, self.width
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        shape = [round(height * scale), round(width * scale)]
        image = torchvision.transforms.functional.resize(image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
        image = self.image_processor(image)
        return {"text": text, "image": image}


    def __len__(self):
        return self.steps_per_epoch
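`TextImageDataset` resizes with `scale = max(target_width / width, target_height / height)` so that both dimensions cover the target before cropping. The shape arithmetic in isolation (`resize_shape` is an illustrative helper, not in the repo):

```python
def resize_shape(width, height, target_width, target_height):
    # Scale so that BOTH dimensions cover the target; a subsequent center
    # crop of (target_height, target_width) then never exceeds the image.
    # Mirrors the resize step in TextImageDataset.__getitem__.
    scale = max(target_width / width, target_height / height)
    return round(height * scale), round(width * scale)
```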


================================================
FILE: diffsynth/data/video.py
================================================
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)
    
    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result
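`split_file_name` builds a natural-sort key character by character so that `frame2.png` sorts before `frame10.png`. The same ordering can be sketched with a regex split (my own rephrasing, not the repo's implementation — it groups digit runs rather than emitting single characters):

```python
import re

def natural_key(file_name):
    # Split "frame10.png" into ("frame", 10, ".png") so that numeric runs
    # compare as integers instead of character by character.
    parts = re.split(r"(\d+)", file_name)
    return tuple(int(p) if p.isdigit() else p for p in parts if p != "")

ordered = sorted(["frame10.png", "frame2.png", "frame1.png"], key=natural_key)
```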


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
    
    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass


def crop_and_resize(image, height, width):
    image = np.array(image)
    image_height, image_width, _ = image.shape
    if image_height / image_width < height / width:
        # Source is wider than the target ratio: center-crop the width.
        cropped_width = int(image_height / height * width)
        left = (image_width - cropped_width) // 2
        image = image[:, left: left+cropped_width]
        image = Image.fromarray(image).resize((width, height))
    else:
        # Source is taller than the target ratio: center-crop the height.
        cropped_height = int(image_width / width * height)
        top = (image_height - cropped_height) // 2
        image = image[top: top+cropped_height, :]
        image = Image.fromarray(image).resize((width, height))
    return image
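The crop choice above depends only on the two aspect ratios; isolated as a pure function it is easier to sanity-check (`crop_box` is an illustrative name, not part of the repo):

```python
def crop_box(image_height, image_width, height, width):
    # Return (top, left, crop_height, crop_width): the largest centered
    # region of the source whose aspect ratio matches height:width,
    # mirroring the branch logic in crop_and_resize.
    if image_height / image_width < height / width:
        # Source is wider than the target ratio: crop the width.
        crop_width = int(image_height / height * width)
        return 0, (image_width - crop_width) // 2, image_height, crop_width
    # Source is taller than the target ratio: crop the height.
    crop_height = int(image_width / width * height)
    return (image_height - crop_height) // 2, 0, crop_height, image_width
```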


class VideoData:
    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))


def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
    writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()

def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))


================================================
FILE: diffsynth/extensions/ESRGAN/__init__.py
================================================
import torch
from einops import repeat
from PIL import Image
import numpy as np


class ResidualDenseBlock(torch.nn.Module):

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
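Each conv in `ResidualDenseBlock` consumes the block input concatenated with all earlier growth maps, so input channels grow by `num_grow_ch` per layer. A quick check of the channel counts (the helper name is mine, not the repo's):

```python
def dense_in_channels(num_feat=64, num_grow_ch=32):
    # conv_k (k = 1..5) takes num_feat + (k - 1) * num_grow_ch input
    # channels, matching conv1..conv5 in ResidualDenseBlock.
    return [num_feat + k * num_grow_ch for k in range(5)]
```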


class RRDB(torch.nn.Module):

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample (2x nearest-neighbor duplication via einops repeat, then conv)
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up1(feat))
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up2(feat))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
    
    @staticmethod
    def state_dict_converter():
        return RRDBNetStateDictConverter()
    

class RRDBNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict, {"upcast_to_float32": True}
    
    def from_civitai(self, state_dict):
        return state_dict, {"upcast_to_float32": True}


class ESRGAN(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def from_model_manager(model_manager):
        return ESRGAN(model_manager.fetch_model("esrgan"))

    def process_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
        return image
    
    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images
    
    def decode_images(self, images):
        images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images
    
    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
        if not isinstance(images, list):
            images = [images]
            is_single_image = True
        else:
            is_single_image = False

        # Preprocess
        input_tensor = self.process_images(images)

        # Interpolate
        output_tensor = []
        for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(
                device=self.model.conv_first.weight.device,
                dtype=self.model.conv_first.weight.dtype)
            batch_output_tensor = self.model(batch_input_tensor)
            output_tensor.append(batch_output_tensor.cpu())
        
        # Output
        output_tensor = torch.concat(output_tensor, dim=0)

        # To images
        output_images = self.decode_images(output_tensor)
        if is_single_image:
            output_images = output_images[0]
        return output_images
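`upscale` slices the stacked input tensor into fixed-size batches, with the last batch possibly shorter. The index arithmetic in isolation (`batch_ranges` is an illustrative name, not part of the repo):

```python
def batch_ranges(n, batch_size):
    # Yield (start, end) slice bounds covering 0..n in chunks, as the
    # batching loop in ESRGAN.upscale computes batch_id and batch_id_.
    for start in range(0, n, batch_size):
        yield start, min(start + batch_size, n)
```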


================================================
FILE: diffsynth/extensions/FastBlend/__init__.py
================================================
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp


class FastBlendSmoother:
    def __init__(self):
        self.batch_size = 8
        self.window_size = 64
        self.ebsynth_config = {
            "minimum_patch_size": 5,
            "threads_per_block": 8,
            "num_iter": 5,
            "gpu_id": 0,
            "guide_weight": 10.0,
            "initialize": "identity",
            "tracking_window_size": 0,
        }

    @staticmethod
    def from_model_manager(model_manager):
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother()

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
        frames_guide = [np.array(frame) for frame in frames_guide]
        frames_style = [np.array(frame) for frame in frames_style]
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1  # the style frame is counted in both window sums; subtract one copy
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
        return frames
    
    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        frames = self.run(
            original_frames, rendered_frames,
            self.batch_size, self.window_size, self.ebsynth_config
        )
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return frames

================================================
FILE: diffsynth/extensions/FastBlend/api.py
================================================
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr


def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
    frames_guide = VideoData(video_guide, video_guide_folder)
    frames_style = VideoData(video_style, video_style_folder)
    message = ""
    if len(frames_guide) < len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
        frames_style.set_length(len(frames_guide))
    elif len(frames_guide) > len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
        frames_guide.set_length(len(frames_style))
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, message
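`check_input_for_blending` trims both sequences to the shorter one and resizes style frames to the guide's shape. The length reconciliation in isolation (`reconcile_lengths` is a hypothetical helper, not in the repo):

```python
def reconcile_lengths(n_guide, n_style):
    # Keep only the first min(n_guide, n_style) frames of each sequence,
    # as check_input_for_blending does via VideoData.set_length.
    n = min(n_guide, n_style)
    message = "" if n_guide == n_style else (
        f"The number of frames mismatches. Only the first {n} frames will be used."
    )
    return n, message
```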


def smooth_video(
    video_guide,
    video_guide_folder,
    video_style,
    video_style_folder,
    mode,
    window_size,
    batch_size,
    tracking_window_size,
    output_path,
    fps,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        if video_style is None:
            output_path = os.path.join(video_style_folder, "output")
        else:
            output_path = os.path.join(os.path.split(video_style)[0], "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    frames_path = os.path.join(output_path, "frames")
    video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(frames_path, exist_ok=True)
    # process
    if mode == "Fast" or mode == "Balanced":
        tracking_window_size = 0
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size,
    }
    if mode == "Fast":
        FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Balanced":
        BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Accurate":
        AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    # output
    try:
        fps = int(fps)
    except (TypeError, ValueError):
        # fps textbox was empty or not a number: fall back to the style video's fps
        fps = get_video_fps(video_style) if video_style is not None else 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
    print("Success!")
    print("Your frames are here:", frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


class KeyFrameMatcher:
    def __init__(self):
        pass

    def extract_number_from_filename(self, file_name):
        result = []
        number = -1
        for i in file_name:
            if ord(i)>=ord("0") and ord(i)<=ord("9"):
                if number == -1:
                    number = 0
                number = number*10 + ord(i) - ord("0")
            else:
                if number != -1:
                    result.append(number)
                    number = -1
        if number != -1:
            result.append(number)
        result = tuple(result)
        return result

    def extract_number_from_filenames(self, file_names):
        numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
        min_length = min(len(i) for i in numbers)
        for i in range(min_length-1, -1, -1):
            if len(set(number[i] for number in numbers))==len(file_names):
                return [number[i] for number in numbers]
        return list(range(len(file_names)))

    def match_using_filename(self, file_names_a, file_names_b):
        file_names_b_set = set(file_names_b)
        matched_file_name = []
        for file_name in file_names_a:
            if file_name not in file_names_b_set:
                matched_file_name.append(None)
            else:
                matched_file_name.append(file_name)
        return matched_file_name

    def match_using_numbers(self, file_names_a, file_names_b):
        numbers_a = self.extract_number_from_filenames(file_names_a)
        numbers_b = self.extract_number_from_filenames(file_names_b)
        numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
        matched_file_name = []
        for number in numbers_a:
            if number in numbers_b_dict:
                matched_file_name.append(numbers_b_dict[number])
            else:
                matched_file_name.append(None)
        return matched_file_name

    def match_filenames(self, file_names_a, file_names_b):
        matched_file_name = self.match_using_filename(file_names_a, file_names_b)
        if sum([i is not None for i in matched_file_name]) > 0:
            return matched_file_name
        matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
        return matched_file_name
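`KeyFrameMatcher` first tries exact filename matches and otherwise pairs files by the digit group that is unique across each set. A compact, `re`-based sketch of that number-matching fallback (the helper names `extract_numbers` and `match_by_numbers` are illustrative, not part of this repository):

```python
import re

def extract_numbers(file_name):
    """All digit runs in a filename, as a tuple of ints (mirrors
    KeyFrameMatcher.extract_number_from_filename)."""
    return tuple(int(s) for s in re.findall(r"\d+", file_name))

def match_by_numbers(names_a, names_b):
    """Match each name in names_a to the name in names_b carrying the same
    frame number (mirrors KeyFrameMatcher.match_using_numbers)."""
    def key(names):
        # pick the right-most digit group that is unique across all names
        nums = [extract_numbers(n) for n in names]
        m = min(len(t) for t in nums)
        for i in range(m - 1, -1, -1):
            if len({t[i] for t in nums}) == len(names):
                return [t[i] for t in nums]
        return list(range(len(names)))
    lookup = dict(zip(key(names_b), names_b))
    return [lookup.get(k) for k in key(names_a)]

# "00001.png" and "render_1.png" share frame number 1; frame 3 has no keyframe.
print(match_by_numbers(["00001.png", "00002.png", "00003.png"],
                       ["render_2.png", "render_1.png"]))
# → ['render_1.png', 'render_2.png', None]
```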


def detect_frames(frames_path, keyframes_path):
    if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
        return "Please input the directory of guide video and rendered frames"
    elif not os.path.exists(frames_path):
        return "Please input the directory of guide video"
    elif not os.path.exists(keyframes_path):
        return "Please input the directory of rendered frames"
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    if len(frames)==0:
        return f"No images detected in {frames_path}"
    if len(keyframes)==0:
        return f"No images detected in {keyframes_path}"
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    max_filename_length = max([len(i) for i in frames])
    if sum([i is not None for i in matched_keyframes])==0:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            message += "--> No matched keyframes\n"
    else:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            if matched_keyframe is None:
                message += "--> [to be rendered]\n"
            else:
                message += f"--> {matched_keyframe}\n"
    return message


def check_input_for_interpolating(frames_path, keyframes_path):
    # search for images
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    # match frames
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    file_list = [file_name for file_name in matched_keyframes if file_name is not None]
    index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
    frames_guide = VideoData(None, frames_path)
    frames_style = VideoData(None, keyframes_path, file_list=file_list)
    # match shape
    message = ""
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The frame shapes do not match. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, index_style, message


def interpolate_video(
    frames_path,
    keyframes_path,
    output_path,
    fps,
    batch_size,
    tracking_window_size,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        output_path = os.path.join(keyframes_path, "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    output_frames_path = os.path.join(output_path, "frames")
    output_video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(output_frames_path, exist_ok=True)
    # process
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size
    }
    if len(index_style)==1:
        InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    else:
        InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    try:
        fps = int(fps)
    except (TypeError, ValueError):
        # fps textbox was empty or not a number: fall back to 30 fps
        fps = 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
    print("Success!")
    print("Your frames are here:", output_frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as ui_component:
        with gr.Tab("Blend"):
            gr.Markdown("""
# Blend

Given a guide video and a style video, this algorithm smooths the style video according to the motion features of the guide video, making it temporally coherent. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos; please use short clips (e.g., a few seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolutions.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Guide video"):
                        video_guide = gr.Video(label="Guide video")
                    with gr.Tab("Guide video (images format)"):
                        video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                with gr.Column():
                    with gr.Tab("Style video"):
                        video_style = gr.Video(label="Style video")
                    with gr.Tab("Style video (images format)"):
                        video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
            btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|

* Sliding window size: our algorithm blends frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window makes the video more fluent but can introduce blur.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested at 512*512 resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
            btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

        return [(ui_component, "FastBlend", "FastBlend_ui")]


================================================
FILE: diffsynth/extensions/FastBlend/cupy_kernels.py
================================================
import cupy as cp

remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
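For reference, a pure-NumPy sketch (the name `remap_reference` is illustrative, not from this repository) of what the `remap` kernel computes for one batch entry, ignoring the padded memory layout: each target pixel averages the style colours voted for by the NNF entries of its patch neighbours.

```python
import numpy as np

def remap_reference(source_style, nnf, patch_size=3):
    """NumPy sketch of the `remap` CUDA kernel (single image, no padding).

    nnf[x, y] = (x', y') maps guide pixel (x, y) to source pixel (x', y').
    Like the kernel, this assumes each pixel collects at least one valid
    vote, which holds whenever the NNF is clamped to the image bounds.
    """
    h, w, c = source_style.shape
    r = (patch_size - 1) // 2
    target = np.zeros_like(source_style, dtype=np.float64)
    for x in range(h):
        for y in range(w):
            acc, num = np.zeros(c), 0
            # clip the patch window to the image, as min_px/max_px do
            for px in range(max(-x, -r), min(h - 1 - x, r) + 1):
                for py in range(max(-y, -r), min(w - 1 - y, r) + 1):
                    # undo the neighbour's offset to land on this pixel
                    x_ = nnf[x + px, y + py, 0] - px
                    y_ = nnf[x + px, y + py, 1] - py
                    if 0 <= x_ < h and 0 <= y_ < w:
                        acc += source_style[x_, y_]
                        num += 1
            target[x, y] = acc / num
    return target
```

With an identity NNF every neighbour votes for the pixel's own colour, so the output equals the input, which makes a convenient sanity check for the kernel's semantics.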


================================================
FILE: diffsynth/extensions/FastBlend/data.py
================================================
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)
    
    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list
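`split_file_name` gives `search_for_images` a natural sort: digit runs compare as integers, so `img10.png` sorts after `img2.png`, where plain lexicographic order would not. A compact `re`-based sketch of the same key (the name `natural_key` is illustrative, and it assumes filenames with matching digit positions, as `split_file_name` does):

```python
import re

def natural_key(file_name):
    # split into alternating text and digit runs, turning digits into ints
    parts = re.split(r"(\d+)", file_name)
    return tuple(int(p) if p.isdigit() else p for p in parts if p != "")

names = ["img10.png", "img2.png", "img1.png"]
print(sorted(names, key=natural_key))  # → ['img1.png', 'img2.png', 'img10.png']
```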


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
    
    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
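`VideoData.__getitem__` loads frames lazily and resizes them only when a target shape was set. A minimal standalone sketch of that resize step (the name `fit_frame` is illustrative, not part of this repository); note the easy-to-miss detail that PIL's `resize` takes `(width, height)` while NumPy arrays are `(height, width, channel)`:

```python
import numpy as np
from PIL import Image

def fit_frame(frame, height=None, width=None):
    """Resize a HxWxC uint8 frame to (height, width) if a shape was set,
    mirroring VideoData.__getitem__."""
    if height is not None and width is not None:
        h, w, _ = frame.shape
        if (h, w) != (height, width):
            # PIL expects size as (width, height), not (height, width)
            frame = np.array(Image.fromarray(frame).resize((width, height)))
    return frame
```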


================================================
FILE: diffsynth/extensions/FastBlend/patch_match.py
================================================
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2


class PatchMatcher:
    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)
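The constructor above sets up a coarse-to-fine schedule (largest patch first, shrinking by 2 per iteration) and the CUDA launch geometry via ceiling division. A standalone sketch of that arithmetic (the function name `launch_geometry` is illustrative, not part of this repository):

```python
def launch_geometry(height, width, minimum_patch_size, num_iter, threads_per_block=8):
    # Coarse-to-fine patch schedule: start large, end at the minimum size.
    patch_size_list = [minimum_patch_size + i * 2 for i in range(num_iter)][::-1]
    # Halo wide enough for the largest patch radius.
    pad_size = patch_size_list[0] // 2
    # 2D grid of thread blocks covering the image: ceil(dim / threads_per_block).
    grid = ((height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block)
    return patch_size_list, pad_size, grid

print(launch_geometry(512, 512, 5, 5))  # → ([13, 11, 9, 7, 5], 6, (64, 64))
```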

    def pad_image(self, image):
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]

    def apply_nnf_to_image(self, nnf, source):
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf
        
    def shift_nnf(self, nnf, d):
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf
    
    def track_step(self, nnf, d):
        if self.use_pairwise_patch_error:
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # binomial coefficient C(n, m); kept for reference, not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # Bezier-weighted temporal smoothing of the NNF; kept for reference, not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err
        
    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style


class PyramidPatchMatcher:
    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.repeat(cp.arange(height), width).reshape(height, width),
                cp.tile(cp.arange(width), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        # upscale
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        nnf[:, 1::2, :, 0] += 1
        nnf[:, :, 1::2, 1] += 1
        # resize when this level's resolution is not exactly double the previous level's
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        if height != nnf.shape[1] or width != nnf.shape[2]:
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
            nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        with cp.cuda.Device(self.gpu_id):
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
        return nnf.get(), target_style.get()
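# `update_nnf` doubles the field when moving up the pyramid: spatial resolution
# and stored coordinates are both scaled by 2, then odd rows/columns get a +1
# offset so the four children of a coarse pixel point at four distinct source
# pixels. A NumPy sketch (assumption: NumPy for CuPy); upscaling an identity
# field must yield the identity field at the finer resolution, which makes a
# convenient sanity check:

```python
import numpy as np

def upscale_nnf(nnf):
    # Double spatial resolution and stored coordinates, then offset odd
    # rows/columns so children point at distinct source pixels.
    nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
    nnf[:, 1::2, :, 0] += 1
    nnf[:, :, 1::2, 1] += 1
    return nnf

def identity_nnf(height, width):
    # Each pixel maps to its own (row, col); batch axis added in front.
    rows = np.repeat(np.arange(height), width).reshape(height, width)
    cols = np.tile(np.arange(width), height).reshape(height, width)
    return np.stack([rows, cols], axis=2)[None]

coarse = identity_nnf(2, 2)
fine = upscale_nnf(coarse)
```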


================================================
FILE: diffsynth/extensions/FastBlend/runners/__init__.py
================================================
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner


================================================
FILE: diffsynth/extensions/FastBlend/runners/accurate.py
================================================
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class AccurateModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        # run
        n = len(frames_style)
        for target in tqdm(range(n), desc=desc):
            l, r = max(target - window_size, 0), min(target + window_size + 1, n)
            remapped_frames = []
            for i in range(l, r, batch_size):
                j = min(i + batch_size, r)
                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
                target_guide = np.stack([frames_guide[target]] * (j - i))
                source_style = np.stack([frames_style[source] for source in range(i, j)])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_frames.append(target_style)
            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
            frame = frame.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
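# The accurate runner averages remaps over a window of radius `window_size`
# around each target frame, processed in batches. The index arithmetic can be
# isolated as a small pure function (`window_batches` is a hypothetical helper
# mirroring the loop above):

```python
def window_batches(target, n, window_size, batch_size):
    # Half-open (i, j) batch ranges covering [target - window_size, target + window_size],
    # clipped to the valid frame range [0, n).
    l, r = max(target - window_size, 0), min(target + window_size + 1, n)
    for i in range(l, r, batch_size):
        yield i, min(i + batch_size, r)

batches = list(window_batches(target=5, n=20, window_size=3, batch_size=4))
```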

================================================
FILE: diffsynth/extensions/FastBlend/runners/balanced.py
================================================
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class BalancedModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # tasks
        n = len(frames_style)
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if source >= 0 and source < n and source != target:
                    tasks.append((source, target))
        # run
        frames = [(None, 1) for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                frame = frame * (weight / (weight + 1)) + result / (weight + 1)
                frames[target] = (frame, weight + 1)
                if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
                    frame = frame.clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
                    frames[target] = (None, 1)
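# The (frame, weight) bookkeeping in BalancedModeRunner is an incremental mean:
# after all updates the stored frame equals the plain average of the style frame
# and every remapped result. A scalar sketch of the same update rule:

```python
def incremental_mean(values):
    # Same rule as the runner: new = old * w / (w + 1) + result / (w + 1).
    frame, weight = values[0], 1
    for result in values[1:]:
        frame = frame * (weight / (weight + 1)) + result / (weight + 1)
        weight += 1
    return frame, weight

mean, weight = incremental_mean([10.0, 20.0, 30.0])
```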


================================================
FILE: diffsynth/extensions/FastBlend/runners/fast.py
================================================
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class TableManager:
    def __init__(self):
        pass

    def task_list(self, n):
        tasks = []
        max_level = 1
        while (1<<max_level)<=n:
            max_level += 1
        for i in range(n):
            j = i
            for level in range(max_level):
                if i&(1<<level):
                    continue
                j |= 1<<level
                if j>=n:
                    break
                meta_data = {
                    "source": i,
                    "target": j,
                    "level": level + 1
                }
                tasks.append(meta_data)
        tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
        return tasks
    
    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        n = len(frames_guide)
        tasks = self.task_list(n)
        remapping_table = [[(frames_style[i], 1)] for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, result in zip(tasks_batch, target_style):
                target, level = task["target"], task["level"]
                if len(remapping_table[target])==level:
                    remapping_table[target].append((result, 1))
                else:
                    frame, weight = remapping_table[target][level]
                    remapping_table[target][level] = (
                        frame * (weight / (weight + 1)) + result / (weight + 1),
                        weight + 1
                    )
        return remapping_table

    def remapping_table_to_blending_table(self, table):
        for i in range(len(table)):
            for j in range(1, len(table[i])):
                frame_1, weight_1 = table[i][j-1]
                frame_2, weight_2 = table[i][j]
                frame = (frame_1 + frame_2) / 2
                weight = weight_1 + weight_2
                table[i][j] = (frame, weight)
        return table

    def tree_query(self, leftbound, rightbound):
        node_list = []
        node_index = rightbound
        while node_index>=leftbound:
            node_level = 0
            while (1 << node_level) & node_index and node_index - (1 << (node_level + 1)) + 1 >= leftbound:
                node_level += 1
            node_list.append((node_index, node_level))
            node_index -= 1<<node_level
        return node_list

    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        n = len(blending_table)
        tasks = []
        frames_result = []
        for target in range(n):
            node_list = self.tree_query(max(target-window_size, 0), target)
            for source, level in node_list:
                if source!=target:
                    meta_data = {
                        "source": source,
                        "target": target,
                        "level": level
                    }
                    tasks.append(meta_data)
                else:
                    frames_result.append(blending_table[target][level])
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, frame_2 in zip(tasks_batch, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                frame_1, weight_1 = frames_result[target]
                weight_2 = blending_table[source][level][1]
                weight = weight_1 + weight_2
                frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
                frames_result[target] = (frame, weight)
        return frames_result
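# TableManager organizes remaps along a Fenwick-tree-like structure: `task_list`
# creates O(n log n) source-to-target remaps by setting successive zero bits, and
# `tree_query` decomposes any window into a few power-of-two nodes of the
# blending table. Standalone copies of both index computations (tuples instead
# of the dicts used above, behavior otherwise identical):

```python
def task_list(n):
    # For each frame i, set successive zero bits to reach targets j; each
    # (i, j) pair is one remapping task at a given tree level.
    tasks = []
    max_level = 1
    while (1 << max_level) <= n:
        max_level += 1
    for i in range(n):
        j = i
        for level in range(max_level):
            if i & (1 << level):
                continue
            j |= 1 << level
            if j >= n:
                break
            tasks.append((i, j, level + 1))
    tasks.sort(key=lambda t: t[2])  # stable sort groups tasks by level
    return tasks

def tree_query(leftbound, rightbound):
    # Decompose [leftbound, rightbound] into maximal nodes (index, level),
    # where node (index, level) covers [index - 2**level + 1, index].
    node_list = []
    node_index = rightbound
    while node_index >= leftbound:
        node_level = 0
        while (1 << node_level) & node_index and node_index - (1 << (node_level + 1)) + 1 >= leftbound:
            node_level += 1
        node_list.append((node_index, node_level))
        node_index -= 1 << node_level
    return node_list

tasks = task_list(4)
nodes = tree_query(2, 7)
```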


class FastModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            # the style frame is counted once in each of table_l and table_r, so remove one copy
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
        if save_path is not None:
            for target, frame in enumerate(frames):
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))


================================================
FILE: diffsynth/extensions/FastBlend/runners/interpolation.py
================================================
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class InterpolationModeRunner:
    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        index_dict = {}
        for i, index in enumerate(index_style):
            index_dict[index] = i
        return index_dict

    def get_weight(self, l, m, r):
        weight_l, weight_r = abs(m - r), abs(m - l)
        if weight_l + weight_r == 0:
            weight_l, weight_r = 0.5, 0.5
        else:
            weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
        return weight_l, weight_r

    def get_task_group(self, index_style, n):
        task_group = []
        index_style = sorted(index_style)
        # first frame
        if index_style[0]>0:
            tasks = []
            for m in range(index_style[0]):
                tasks.append((index_style[0], m, index_style[0]))
            task_group.append(tasks)
        # middle frames
        for l, r in zip(index_style[:-1], index_style[1:]):
            tasks = []
            for m in range(l, r):
                tasks.append((l, m, r))
            task_group.append(tasks)
        # last frame
        tasks = []
        for m in range(index_style[-1], n):
            tasks.append((index_style[-1], m, index_style[-1]))
        task_group.append(tasks)
        return task_group

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        # task
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        # run
        for tasks in task_group:
            index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # l -> m
                    source_guide.append(frames_guide[l])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[l]])
                    # r -> m
                    source_guide.append(frames_guide[r])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[r]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))


class InterpolationModeSingleFrameRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        # check input
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than tracking_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # run
        frame_id, n = 0, len(frames_guide)
        for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
            if i + batch_size > n:
                l, r = max(n - batch_size, 0), n
            else:
                l, r = i, i + batch_size
            source_guide = np.stack([frame_guide] * (r-l))
            target_guide = np.stack([frames_guide[i] for i in range(l, r)])
            source_style = np.stack([frame_style] * (r-l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for frame_index, frame in zip(range(l, r), target_style):
                if frame_index == frame_id:
                    frame = frame.clip(0, 255).astype("uint8")
                    Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
                if r < n and r-frame_id <= tracking_window_size:
                    break
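# The interpolation runner renders each in-between frame from its two
# surrounding keyframes, weighted by distance: the closer keyframe dominates,
# and a frame that coincides with a keyframe (l == m == r) splits evenly. The
# weight rule in isolation:

```python
def get_weight(l, m, r):
    # Closer keyframe gets the larger weight; equal split when l == m == r.
    weight_l, weight_r = abs(m - r), abs(m - l)
    if weight_l + weight_r == 0:
        return 0.5, 0.5
    total = weight_l + weight_r
    return weight_l / total, weight_r / total

weights = get_weight(l=0, m=1, r=4)
```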


================================================
FILE: diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
================================================
from .blip_pretrain import *


================================================
FILE: diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
================================================
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed


def default_bert():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
    model_path = os.path.join(project_root, 'models', 'QualityMetric')
    return os.path.join(model_path, "bert-base-uncased")


def init_tokenizer(bert_model_path):
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})       
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]  
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
        
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=drop_path_rate  # `0 or drop_path_rate` is just the argument
                                          )
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1  # `0.1 or drop_path_rate` always evaluates to 0.1
                                          )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def load_checkpoint(model,url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu') 
    elif os.path.isfile(url_or_filename):        
        checkpoint = torch.load(url_or_filename, map_location='cpu') 
    else:
        raise RuntimeError('checkpoint url or path is invalid')
        
    state_dict = checkpoint['model']
    
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)    
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape!=model.state_dict()[key].shape:
                print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
                del state_dict[key]
    
    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)  
    return model,msg
    


================================================
FILE: diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
================================================
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer

class BLIP_Pretrain(nn.Module):
    def __init__(self,                 
                 med_config = "med_config.json",  
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,                    
                 embed_dim = 256,     
                 queue_size = 57600,
                 momentum = 0.995,
                 bert_model_path = ""
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """               
        super().__init__()
        
        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
        
        self.tokenizer = init_tokenizer(bert_model_path)   
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size
        
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)



================================================
FILE: diffsynth/extensions/ImageQualityMetric/BLIP/med.py
================================================
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''

import math
from typing import Tuple

import torch
import torch.utils.checkpoint
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        
        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
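# The forward pass above slices `position_ids` by `past_key_values_length`, so
# that during incremental decoding a newly generated token gets the next
# absolute position rather than starting over at 0. A dependency-free sketch of
# that indexing (sizes are toy values, not the real config):

```python
# Sketch of the position-id slicing in BertEmbeddings.forward, mirroring
# self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
max_position_embeddings = 16  # assumed toy value
position_ids = list(range(max_position_embeddings))

def slice_position_ids(seq_length, past_key_values_length):
    # return the positions assigned to the `seq_length` new tokens
    return position_ids[past_key_values_length : seq_length + past_key_values_length]

print(slice_position_ids(4, 0))  # full 4-token prompt -> [0, 1, 2, 3]
print(slice_position_ids(1, 4))  # one new token after a 4-token cache -> [4]
```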


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False   
            
    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients
        
    def get_attn_gradients(self):
        return self.attn_gradients
    
    def save_attention_map(self, attention_map):
        self.attention_map = attention_map
        
    def get_attention_map(self):
        return self.attention_map
    
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        
        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)         

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
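# When `position_embedding_type` is "relative_key" or "relative_key_query",
# the module above looks up one embedding per pairwise distance `l - r`,
# shifted by `max_position_embeddings - 1` into a non-negative index range. A
# pure-Python sketch of that index construction (toy sizes, not the real
# config):

```python
max_position_embeddings = 4  # assumed toy value
seq_length = 3

# distance[l][r] = l - r, as in position_ids_l - position_ids_r
distance = [[l - r for r in range(seq_length)] for l in range(seq_length)]

# shift into [0, 2*max_position_embeddings - 2] so it can index an
# nn.Embedding(2 * max_position_embeddings - 1, head_size) table
indices = [[d + max_position_embeddings - 1 for d in row] for row in distance]

print(indices)  # [[3, 2, 1], [4, 3, 2], [5, 4, 3]]
```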


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
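# `prune_heads` above shrinks the query/key/value projections with
# `prune_linear_layer` and then updates the bookkeeping attributes. The size
# arithmetic is simple but easy to get wrong; a sketch with assumed
# BERT-base-like toy numbers:

```python
num_attention_heads = 12   # assumed BERT-base value
attention_head_size = 64   # hidden_size // num_attention_heads
heads_to_prune = {0, 5}

# after prune_linear_layer removes those heads' rows/columns:
num_attention_heads -= len(heads_to_prune)
all_head_size = num_attention_heads * attention_head_size

print(num_attention_heads, all_head_size)  # 10 640
```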


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)      
        self.layer_num = layer_num          
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
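# `apply_chunking_to_forward` above runs `feed_forward_chunk` on slices of the
# sequence dimension to cap activation memory, then concatenates the results.
# A list-based stand-in for that chunking contract (not the real transformers
# helper, just an illustration of its behavior):

```python
def apply_chunking(forward_fn, chunk_size, seq):
    # chunk_size == 0 means "no chunking", matching the transformers helper
    if chunk_size == 0:
        return forward_fn(seq)
    out = []
    for start in range(0, len(seq), chunk_size):
        out.extend(forward_fn(seq[start:start + chunk_size]))
    return out

def double(xs):
    return [2 * x for x in xs]

print(apply_chunking(double, 0, [1, 2, 3]))  # [2, 4, 6]
print(apply_chunking(double, 2, [1, 2, 3]))  # [2, 4, 6], same result with less peak memory
```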


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
               
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # close over past_key_value, output_attentions and mode:
                        # torch.utils.checkpoint.checkpoint only forwards the
                        # positional tensor arguments to this function
                        return module(*inputs, past_key_value, output_attentions, mode=mode)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
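# When `return_dict=False`, the encoder above returns only the non-None pieces
# in a fixed order (hidden states, cache, all hidden states, attentions, cross
# attentions). A toy sketch of that filtering with placeholder values:

```python
hidden_states = "last_hidden"
next_decoder_cache = None          # use_cache was False
all_hidden_states = ("h0", "h1")
all_self_attentions = None         # output_attentions was False

# mirrors the generator expression in BertEncoder.forward
outputs = tuple(
    v
    for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions]
    if v is not None
)
print(outputs)  # ('last_hidden', ('h0', 'h1'))
```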


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
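# The pooler above reduces `[batch, seq_len, hidden]` to `[batch, hidden]` by
# selecting each sequence's first ([CLS]) token before the dense + tanh
# projection. A list-based sketch of the selection step (toy values):

```python
# batch of 2 sequences, 3 tokens each, hidden size 2 (toy values)
hidden_states = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 1.1], [1.2, 1.3], [1.4, 1.5]],
]

# mirrors hidden_states[:, 0]: keep only each sequence's first token
first_token_tensor = [seq[0] for seq in hidden_states]
print(first_token_tensor)  # [[0.1, 0.2], [1.0, 1.1]]
```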


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()
 

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    
    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
   
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        dim=-1,
                    )                     

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
    
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:    
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape 
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
            
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, 
                                                                                 device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            
            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:    
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        
        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds
            
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
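The additive masking used by `get_extended_attention_mask` / `invert_attention_mask` above can be illustrated standalone. This is a minimal framework-free sketch (the helper name and `-10000.0` fill value here are illustrative assumptions): keep-positions (1) map to an additive 0.0, padding positions (0) map to a large negative value that vanishes after softmax.

```python
def invert_attention_mask(mask, neg=-10000.0):
    # 1 -> 0.0 (token is attended to); 0 -> large negative (masked out after softmax)
    return [0.0 if m else neg for m in mask]

print(invert_attention_mask([1, 1, 0]))  # [0.0, 0.0, -10000.0]
```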



class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,            
        is_decoder=True,
        reduction='mean',
        mode='multimodal', 
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )
        
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)
        
        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()  

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) 
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction=='none':
                lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)               

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids, 
            "attention_mask": attention_mask, 
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
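The next-token shift in `BertLMHeadModel.forward` (logits `[:, :-1]` scored against labels `[:, 1:]`) can be sketched without any framework. The helper name below is hypothetical; it just shows which positions are paired for the loss.

```python
def shift_for_next_token(token_ids):
    """Given [t0, t1, ..., tn], score positions t0..t(n-1) against labels t1..tn."""
    inputs = token_ids[:-1]   # positions whose logits are kept (logits[:, :-1])
    targets = token_ids[1:]   # labels shifted left by one   (labels[:, 1:])
    return inputs, targets

inputs, targets = shift_for_next_token([101, 7592, 2088, 102])
# inputs  -> [101, 7592, 2088]
# targets -> [7592, 2088, 102]
```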


================================================
FILE: diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
================================================
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
 * Based on timm code base
 * https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv

# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None
        
    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients
        
    def get_attn_gradients(self):
        return self.attn_gradients
    
    def save_attention_map(self, attention_map):
        self.attention_map = attention_map
        
    def get_attention_map(self):
        return self.attention_map
    
    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
                
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)        

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # if use_grad_checkpointing:
        #     self.attn = checkpoint_wrapper(self.attn)
        #     self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 
                 use_grad_checkpointing=False, ckpt_layer=0):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, register_blk == i)
        x = self.norm(x)

        return x
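The sequence length that `VisionTransformer.forward` operates on follows directly from the constructor defaults above (`img_size=224`, `patch_size=16`): a grid of non-overlapping patches plus one class token. A standalone arithmetic sketch (the helper name is illustrative):

```python
def vit_seq_len(img_size=224, patch_size=16):
    num_patches = (img_size // patch_size) ** 2  # PatchEmbed grid, e.g. 14 x 14
    return num_patches + 1                       # +1 for the prepended [CLS] token

print(vit_seq_len())  # 197 = 14*14 + 1
```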
SYMBOL INDEX (2289 symbols across 137 files)

FILE: diffsynth/controlnets/controlnet_unit.py
  class ControlNetConfigUnit (line 6) | class ControlNetConfigUnit:
    method __init__ (line 7) | def __init__(self, processor_id: Processor_id, model_path, scale=1.0, ...
  class ControlNetUnit (line 14) | class ControlNetUnit:
    method __init__ (line 15) | def __init__(self, processor, model, scale=1.0):
  class MultiControlNetManager (line 21) | class MultiControlNetManager:
    method __init__ (line 22) | def __init__(self, controlnet_units=[]):
    method cpu (line 27) | def cpu(self):
    method to (line 31) | def to(self, device):
    method process_image (line 37) | def process_image(self, image, processor_id=None):
    method __call__ (line 48) | def __call__(
  class FluxMultiControlNetManager (line 68) | class FluxMultiControlNetManager(MultiControlNetManager):
    method __init__ (line 69) | def __init__(self, controlnet_units=[]):
    method process_image (line 72) | def process_image(self, image, processor_id=None):
    method __call__ (line 79) | def __call__(self, conditionings, **kwargs):

FILE: diffsynth/controlnets/processors.py
  class Annotator (line 8) | class Annotator:
    method __init__ (line 9) | def __init__(self, processor_id: Processor_id, model_path="models/Anno...
    method to (line 42) | def to(self,device):
    method __call__ (line 47) | def __call__(self, image, mask=None):

FILE: diffsynth/data/simple_text_image.py
  class TextImageDataset (line 8) | class TextImageDataset(torch.utils.data.Dataset):
    method __init__ (line 9) | def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, w...
    method __getitem__ (line 26) | def __getitem__(self, index):
    method __len__ (line 40) | def __len__(self):

FILE: diffsynth/data/video.py
  class LowMemoryVideo (line 7) | class LowMemoryVideo:
    method __init__ (line 8) | def __init__(self, file_name):
    method __len__ (line 11) | def __len__(self):
    method __getitem__ (line 14) | def __getitem__(self, item):
    method __del__ (line 17) | def __del__(self):
  function split_file_name (line 21) | def split_file_name(file_name):
  function search_for_images (line 40) | def search_for_images(folder):
  class LowMemoryImageFolder (line 48) | class LowMemoryImageFolder:
    method __init__ (line 49) | def __init__(self, folder, file_list=None):
    method __len__ (line 55) | def __len__(self):
    method __getitem__ (line 58) | def __getitem__(self, item):
    method __del__ (line 61) | def __del__(self):
  function crop_and_resize (line 65) | def crop_and_resize(image, height, width):
  class VideoData (line 81) | class VideoData:
    method __init__ (line 82) | def __init__(self, video_file=None, image_folder=None, height=None, wi...
    method raw_data (line 94) | def raw_data(self):
    method set_length (line 100) | def set_length(self, length):
    method set_shape (line 103) | def set_shape(self, height, width):
    method __len__ (line 107) | def __len__(self):
    method shape (line 113) | def shape(self):
    method __getitem__ (line 120) | def __getitem__(self, item):
    method __del__ (line 128) | def __del__(self):
    method save_images (line 131) | def save_images(self, folder):
  function save_video (line 138) | def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
  function save_frames (line 145) | def save_frames(frames, save_path):

FILE: diffsynth/extensions/ESRGAN/__init__.py
  class ResidualDenseBlock (line 7) | class ResidualDenseBlock(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, num_feat=64, num_grow_ch=32):
    method forward (line 18) | def forward(self, x):
  class RRDB (line 27) | class RRDB(torch.nn.Module):
    method __init__ (line 29) | def __init__(self, num_feat, num_grow_ch=32):
    method forward (line 35) | def forward(self, x):
  class RRDBNet (line 42) | class RRDBNet(torch.nn.Module):
    method __init__ (line 44) | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=2...
    method forward (line 56) | def forward(self, x):
    method state_dict_converter (line 70) | def state_dict_converter():
  class RRDBNetStateDictConverter (line 74) | class RRDBNetStateDictConverter:
    method __init__ (line 75) | def __init__(self):
    method from_diffusers (line 78) | def from_diffusers(self, state_dict):
    method from_civitai (line 81) | def from_civitai(self, state_dict):
  class ESRGAN (line 85) | class ESRGAN(torch.nn.Module):
    method __init__ (line 86) | def __init__(self, model):
    method from_model_manager (line 91) | def from_model_manager(model_manager):
    method process_image (line 94) | def process_image(self, image):
    method process_images (line 98) | def process_images(self, images):
    method decode_images (line 103) | def decode_images(self, images):
    method upscale (line 109) | def upscale(self, images, batch_size=4, progress_bar=lambda x:x):

FILE: diffsynth/extensions/FastBlend/__init__.py
  class FastBlendSmoother (line 7) | class FastBlendSmoother:
    method __init__ (line 8) | def __init__(self):
    method from_model_manager (line 22) | def from_model_manager(model_manager):
    method run (line 26) | def run(self, frames_guide, frames_style, batch_size, window_size, ebs...
    method __call__ (line 54) | def __call__(self, rendered_frames, original_frames=None, **kwargs):

FILE: diffsynth/extensions/FastBlend/api.py
  function check_input_for_blending (line 7) | def check_input_for_blending(video_guide, video_guide_folder, video_styl...
  function smooth_video (line 25) | def smooth_video(
  class KeyFrameMatcher (line 92) | class KeyFrameMatcher:
    method __init__ (line 93) | def __init__(self):
    method extract_number_from_filename (line 96) | def extract_number_from_filename(self, file_name):
    method extract_number_from_filenames (line 113) | def extract_number_from_filenames(self, file_names):
    method match_using_filename (line 121) | def match_using_filename(self, file_names_a, file_names_b):
    method match_using_numbers (line 131) | def match_using_numbers(self, file_names_a, file_names_b):
    method match_filenames (line 143) | def match_filenames(self, file_names_a, file_names_b):
  function detect_frames (line 151) | def detect_frames(frames_path, keyframes_path):
  function check_input_for_interpolating (line 182) | def check_input_for_interpolating(frames_path, keyframes_path):
  function interpolate_video (line 202) | def interpolate_video(
  function on_ui_tabs (line 257) | def on_ui_tabs():

FILE: diffsynth/extensions/FastBlend/data.py
  function read_video (line 6) | def read_video(file_name):
  function get_video_fps (line 16) | def get_video_fps(file_name):
  function save_video (line 23) | def save_video(frames_path, video_path, num_frames, fps):
  class LowMemoryVideo (line 32) | class LowMemoryVideo:
    method __init__ (line 33) | def __init__(self, file_name):
    method __len__ (line 36) | def __len__(self):
    method __getitem__ (line 39) | def __getitem__(self, item):
    method __del__ (line 42) | def __del__(self):
  function split_file_name (line 46) | def split_file_name(file_name):
  function search_for_images (line 65) | def search_for_images(folder):
  function read_images (line 73) | def read_images(folder):
  class LowMemoryImageFolder (line 79) | class LowMemoryImageFolder:
    method __init__ (line 80) | def __init__(self, folder, file_list=None):
    method __len__ (line 86) | def __len__(self):
    method __getitem__ (line 89) | def __getitem__(self, item):
    method __del__ (line 92) | def __del__(self):
  class VideoData (line 96) | class VideoData:
    method __init__ (line 97) | def __init__(self, video_file, image_folder, **kwargs):
    method raw_data (line 110) | def raw_data(self):
    method set_length (line 116) | def set_length(self, length):
    method set_shape (line 119) | def set_shape(self, height, width):
    method __len__ (line 123) | def __len__(self):
    method shape (line 129) | def shape(self):
    method __getitem__ (line 136) | def __getitem__(self, item):
    method __del__ (line 145) | def __del__(self):

FILE: diffsynth/extensions/FastBlend/patch_match.py
  class PatchMatcher (line 7) | class PatchMatcher:
    method __init__ (line 8) | def __init__(
    method pad_image (line 37) | def pad_image(self, image):
    method unpad_image (line 40) | def unpad_image(self, image):
    method apply_nnf_to_image (line 43) | def apply_nnf_to_image(self, nnf, source):
    method get_patch_error (line 53) | def get_patch_error(self, source, nnf, target):
    method get_pairwise_patch_error (line 63) | def get_pairwise_patch_error(self, source, nnf):
    method get_error (line 76) | def get_error(self, source_guide, target_guide, source_style, target_s...
    method clamp_bound (line 89) | def clamp_bound(self, nnf):
    method random_step (line 94) | def random_step(self, nnf, r):
    method neighboor_step (line 100) | def neighboor_step(self, nnf, d):
    method shift_nnf (line 116) | def shift_nnf(self, nnf, d):
    method track_step (line 125) | def track_step(self, nnf, d):
    method C (line 134) | def C(self, n, m):
    method bezier_step (line 145) | def bezier_step(self, nnf, r):
    method update (line 158) | def update(self, source_guide, target_guide, source_style, target_styl...
    method propagation (line 165) | def propagation(self, source_guide, target_guide, source_style, target...
    method random_search (line 171) | def random_search(self, source_guide, target_guide, source_style, targ...
    method track (line 177) | def track(self, source_guide, target_guide, source_style, target_style...
    method iteration (line 185) | def iteration(self, source_guide, target_guide, source_style, target_s...
    method estimate_nnf (line 191) | def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
  class PyramidPatchMatcher (line 205) | class PyramidPatchMatcher:
    method __init__ (line 206) | def __init__(
    method resample_image (line 234) | def resample_image(self, images, level):
    method initialize_nnf (line 244) | def initialize_nnf(self, batch_size):
    method update_nnf (line 262) | def update_nnf(self, nnf, level):
    method apply_nnf_to_image (line 276) | def apply_nnf_to_image(self, nnf, image):
    method estimate_nnf (line 282) | def estimate_nnf(self, source_guide, target_guide, source_style):

FILE: diffsynth/extensions/FastBlend/runners/accurate.py
  class AccurateModeRunner (line 8) | class AccurateModeRunner:
    method __init__ (line 9) | def __init__(self):
    method run (line 12) | def run(self, frames_guide, frames_style, batch_size, window_size, ebs...

FILE: diffsynth/extensions/FastBlend/runners/balanced.py
  class BalancedModeRunner (line 8) | class BalancedModeRunner:
    method __init__ (line 9) | def __init__(self):
    method run (line 12) | def run(self, frames_guide, frames_style, batch_size, window_size, ebs...

FILE: diffsynth/extensions/FastBlend/runners/fast.py
  class TableManager (line 8) | class TableManager:
    method __init__ (line 9) | def __init__(self):
    method task_list (line 12) | def task_list(self, n):
    method build_remapping_table (line 34) | def build_remapping_table(self, frames_guide, frames_style, patch_matc...
    method remapping_table_to_blending_table (line 56) | def remapping_table_to_blending_table(self, table):
    method tree_query (line 66) | def tree_query(self, leftbound, rightbound):
    method process_window_sum (line 77) | def process_window_sum(self, frames_guide, blending_table, patch_match...
  class FastModeRunner (line 109) | class FastModeRunner:
    method __init__ (line 110) | def __init__(self):
    method run (line 113) | def run(self, frames_guide, frames_style, batch_size, window_size, ebs...

FILE: diffsynth/extensions/FastBlend/runners/interpolation.py
  class InterpolationModeRunner (line 8) | class InterpolationModeRunner:
    method __init__ (line 9) | def __init__(self):
    method get_index_dict (line 12) | def get_index_dict(self, index_style):
    method get_weight (line 18) | def get_weight(self, l, m, r):
    method get_task_group (line 26) | def get_task_group(self, index_style, n):
    method run (line 48) | def run(self, frames_guide, frames_style, index_style, batch_size, ebs...
  class InterpolationModeSingleFrameRunner (line 87) | class InterpolationModeSingleFrameRunner:
    method __init__ (line 88) | def __init__(self):
    method run (line 91) | def run(self, frames_guide, frames_style, index_style, batch_size, ebs...

FILE: diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
  function default_bert (line 16) | def default_bert():
  function init_tokenizer (line 23) | def init_tokenizer(bert_model_path):
  function create_vit (line 31) | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer...
  function is_url (line 49) | def is_url(url_or_filename):
  function load_checkpoint (line 53) | def load_checkpoint(model,url_or_filename):

FILE: diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
  class BLIP_Pretrain (line 13) | class BLIP_Pretrain(nn.Module):
    method __init__ (line 14) | def __init__(self,

FILE: diffsynth/extensions/ImageQualityMetric/BLIP/med.py
  class BertEmbeddings (line 44) | class BertEmbeddings(nn.Module):
    method __init__ (line 47) | def __init__(self, config):
    method forward (line 63) | def forward(
  class BertSelfAttention (line 89) | class BertSelfAttention(nn.Module):
    method __init__ (line 90) | def __init__(self, config, is_cross_attention):
    method save_attn_gradients (line 118) | def save_attn_gradients(self, attn_gradients):
    method get_attn_gradients (line 121) | def get_attn_gradients(self):
    method save_attention_map (line 124) | def save_attention_map(self, attention_map):
    method get_attention_map (line 127) | def get_attention_map(self):
    method transpose_for_scores (line 130) | def transpose_for_scores(self, x):
    method forward (line 135) | def forward(
  class BertSelfOutput (line 220) | class BertSelfOutput(nn.Module):
    method __init__ (line 221) | def __init__(self, config):
    method forward (line 227) | def forward(self, hidden_states, input_tensor):
  class BertAttention (line 234) | class BertAttention(nn.Module):
    method __init__ (line 235) | def __init__(self, config, is_cross_attention=False):
    method prune_heads (line 241) | def prune_heads(self, heads):
    method forward (line 259) | def forward(
  class BertIntermediate (line 283) | class BertIntermediate(nn.Module):
    method __init__ (line 284) | def __init__(self, config):
    method forward (line 292) | def forward(self, hidden_states):
  class BertOutput (line 298) | class BertOutput(nn.Module):
    method __init__ (line 299) | def __init__(self, config):
    method forward (line 305) | def forward(self, hidden_states, input_tensor):
  class BertLayer (line 312) | class BertLayer(nn.Module):
    method __init__ (line 313) | def __init__(self, config, layer_num):
    method forward (line 325) | def forward(
    method feed_forward_chunk (line 372) | def feed_forward_chunk(self, attention_output):
  class BertEncoder (line 378) | class BertEncoder(nn.Module):
    method __init__ (line 379) | def __init__(self, config):
    method forward (line 385) | def forward(
  class BertPooler (line 478) | class BertPooler(nn.Module):
    method __init__ (line 479) | def __init__(self, config):
    method forward (line 484) | def forward(self, hidden_states):
  class BertPredictionHeadTransform (line 493) | class BertPredictionHeadTransform(nn.Module):
    method __init__ (line 494) | def __init__(self, config):
    method forward (line 503) | def forward(self, hidden_states):
  class BertLMPredictionHead (line 510) | class BertLMPredictionHead(nn.Module):
    method __init__ (line 511) | def __init__(self, config):
    method forward (line 524) | def forward(self, hidden_states):
  class BertOnlyMLMHead (line 530) | class BertOnlyMLMHead(nn.Module):
    method __init__ (line 531) | def __init__(self, config):
    method forward (line 535) | def forward(self, sequence_output):
  class BertPreTrainedModel (line 540) | class BertPreTrainedModel(PreTrainedModel):
    method _init_weights (line 550) | def _init_weights(self, module):
  class BertModel (line 563) | class BertModel(BertPreTrainedModel):
    method __init__ (line 573) | def __init__(self, config, add_pooling_layer=True):
    method get_input_embeddings (line 586) | def get_input_embeddings(self):
    method set_input_embeddings (line 589) | def set_input_embeddings(self, value):
    method _prune_heads (line 592) | def _prune_heads(self, heads_to_prune):
    method get_extended_attention_mask (line 601) | def get_extended_attention_mask(self, attention_mask: Tensor, input_sh...
    method forward (line 662) | def forward(
  class BertLMHeadModel (line 803) | class BertLMHeadModel(BertPreTrainedModel):
    method __init__ (line 808) | def __init__(self, config):
    method get_output_embeddings (line 816) | def get_output_embeddings(self):
    method set_output_embeddings (line 819) | def set_output_embeddings(self, new_embeddings):
    method forward (line 822) | def forward(
    method prepare_inputs_for_generation (line 924) | def prepare_inputs_for_generation(self, input_ids, past=None, attentio...
    method _reorder_cache (line 943) | def _reorder_cache(self, past, beam_idx):

FILE: diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
  class Mlp (line 19) | class Mlp(nn.Module):
    method __init__ (line 22) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 31) | def forward(self, x):
  class Attention (line 40) | class Attention(nn.Module):
    method __init__ (line 41) | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, at...
    method save_attn_gradients (line 54) | def save_attn_gradients(self, attn_gradients):
    method get_attn_gradients (line 57) | def get_attn_gradients(self):
    method save_attention_map (line 60) | def save_attention_map(self, attention_map):
    method get_attention_map (line 63) | def get_attention_map(self):
    method forward (line 66) | def forward(self, x, register_hook=False):
  class Block (line 85) | class Block(nn.Module):
    method __init__ (line 87) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
    method forward (line 103) | def forward(self, x, register_hook=False):
  class VisionTransformer (line 109) | class VisionTransformer(nn.Module):
    method __init__ (line 114) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
    method _init_weights (line 163) | def _init_weights(self, m):
    method no_weight_decay (line 173) | def no_weight_decay(self):
    method forward (line 176) | def forward(self, x, register_blk=-1):
    method load_pretrained (line 193) | def load_pretrained(self, checkpoint_path, prefix=''):
  function _load_weights (line 198) | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix...
  function interpolate_pos_embed (line 277) | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):

FILE: diffsynth/extensions/ImageQualityMetric/__init__.py
  function download_preference_model (line 136) | def download_preference_model(model_name: preference_model_id, cache_dir...
  function load_preference_model (line 144) | def load_preference_model(model_name: preference_model_id, device = "cud...

FILE: diffsynth/extensions/ImageQualityMetric/aesthetic.py
  class MLP (line 10) | class MLP(torch.nn.Module):
    method __init__ (line 11) | def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "av...
    method forward (line 31) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method training_step (line 34) | def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
    method validation_step (line 41) | def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
    method configure_optimizers (line 48) | def configure_optimizers(self) -> torch.optim.Optimizer:
  class AestheticScore (line 52) | class AestheticScore(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, device: torch.device, path: str = MODEL_PATHS):
    method _calculate_score (line 76) | def _calculate_score(self, image: torch.Tensor) -> float:
    method score (line 96) | def score(self, images: Union[str, List[str], Image.Image, List[Image....

FILE: diffsynth/extensions/ImageQualityMetric/clip.py
  class CLIPScore (line 7) | class CLIPScore(torch.nn.Module):
    method __init__ (line 8) | def __init__(self, device: torch.device, path: str = MODEL_PATHS):
    method _calculate_score (line 44) | def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
    method score (line 67) | def score(self, images: Union[str, List[str], Image.Image, List[Image....

FILE: diffsynth/extensions/ImageQualityMetric/config.py
  function get_model_path (line 8) | def get_model_path(model_name):

FILE: diffsynth/extensions/ImageQualityMetric/hps.py
  class HPScore_v2 (line 9) | class HPScore_v2(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, device: torch.device, path: str = MODEL_PATHS, mode...
    method _calculate_score (line 62) | def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
    method score (line 85) | def score(self, images: Union[str, List[str], Image.Image, List[Image....

FILE: diffsynth/extensions/ImageQualityMetric/imagereward.py
  function _convert_image_to_rgb (line 12) | def _convert_image_to_rgb(image):
  function _transform (line 15) | def _transform(n_px):
  class MLP (line 24) | class MLP(torch.nn.Module):
    method __init__ (line 25) | def __init__(self, input_size):
    method forward (line 51) | def forward(self, input):
  class ImageReward (line 54) | class ImageReward(torch.nn.Module):
    method __init__ (line 55) | def __init__(self, med_config, device='cpu', bert_model_path=""):
    method score_grad (line 66) | def score_grad(self, prompt_ids, prompt_attention_mask, image):
    method score (line 91) | def score(self, images: Union[str, List[str], Image.Image, List[Image....
    method _calculate_score (line 125) | def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch....
    method inference_rank (line 150) | def inference_rank(self, prompt: str, generations_list: List[Union[str...
  class ImageRewardScore (line 190) | class ImageRewardScore(torch.nn.Module):
    method __init__ (line 191) | def __init__(self, device: Union[str, torch.device], path: str = MODEL...
    method score (line 202) | def score(self, images: Union[str, List[str], Image.Image, List[Image....

FILE: diffsynth/extensions/ImageQualityMetric/mps.py
  class MPScore (line 27) | class MPScore(torch.nn.Module):
    method __init__ (line 28) | def __init__(self, device: Union[str, torch.device], path: str = MODEL...
    method _calculate_score (line 45) | def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
    method score (line 99) | def score(self, images: Union[str, List[str], Image.Image, List[Image....

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
  class MultimodalCfg (line 45) | class MultimodalCfg(CLIPTextCfg):
  function _build_text_decoder_tower (line 53) | def _build_text_decoder_tower(
  class CoCa (line 79) | class CoCa(nn.Module):
    method __init__ (line 80) | def __init__(
    method set_grad_checkpointing (line 126) | def set_grad_checkpointing(self, enable=True):
    method _encode_image (line 131) | def _encode_image(self, images, normalize=True):
    method _encode_text (line 136) | def _encode_text(self, text, normalize=True, embed_cls=True):
    method encode_image (line 142) | def encode_image(self, images, normalize=True):
    method encode_text (line 146) | def encode_text(self, text, normalize=True, embed_cls=True):
    method forward (line 150) | def forward(self, image, text, embed_cls=True, image_latent=None, imag...
    method generate (line 167) | def generate(
    method _generate_beamsearch (line 290) | def _generate_beamsearch(
  function prepare_inputs_for_generation (line 439) | def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **...

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
  function _natural_key (line 29) | def _natural_key(string_):
  function _rescan_model_configs (line 33) | def _rescan_model_configs():
  function list_models (line 57) | def list_models():
  function add_model_config (line 62) | def add_model_config(path):
  function get_model_config (line 70) | def get_model_config(model_name):
  function get_tokenizer (line 77) | def get_tokenizer(model_name, open_clip_bpe_path=None):
  function load_state_dict (line 87) | def load_state_dict(checkpoint_path: str, map_location='cpu'):
  function load_checkpoint (line 98) | def load_checkpoint(model, checkpoint_path, strict=True):
  function create_model (line 108) | def create_model(
  function create_loss (line 244) | def create_loss(args):
  class MLP (line 274) | class MLP(torch.nn.Module):
    method __init__ (line 275) | def __init__(self, input_size):
    method forward (line 289) | def forward(self, x):
  function create_model_and_transforms (line 309) | def create_model_and_transforms(
  function create_model_from_pretrained (line 394) | def create_model_from_pretrained(

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
  class BaseModelOutput (line 21) | class BaseModelOutput:
  class PretrainedConfig (line 25) | class PretrainedConfig:
  function _camel2snake (line 32) | def _camel2snake(s):
  function register_pooler (line 40) | def register_pooler(cls):
  class MeanPooler (line 47) | class MeanPooler(nn.Module):
    method forward (line 50) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class MaxPooler (line 56) | class MaxPooler(nn.Module):
    method forward (line 59) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class ClsPooler (line 65) | class ClsPooler(nn.Module):
    method __init__ (line 68) | def __init__(self, use_pooler_output=True):
    method forward (line 73) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class HFTextEncoder (line 83) | class HFTextEncoder(nn.Module):
    method __init__ (line 87) | def __init__(
    method forward (line 137) | def forward(self, x: TensorType):
    method lock (line 154) | def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
    method set_grad_checkpointing (line 172) | def set_grad_checkpointing(self, enable=True):
    method init_parameters (line 175) | def init_parameters(self):

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
  function gather_features (line 20) | def gather_features(
  class ClipLoss (line 67) | class ClipLoss(nn.Module):
    method __init__ (line 69) | def __init__(
    method get_ground_truth (line 90) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
    method get_logits (line 103) | def get_logits(self, image_features, text_features, logit_scale):
    method forward (line 121) | def forward(self, image_features, text_features, logit_scale, output_d...
  class PreferenceLoss (line 133) | class PreferenceLoss(nn.Module):
    method forward (line 135) | def forward(self, logits_per_image, num_images, labels):
  class HPSLoss (line 143) | class HPSLoss(nn.Module):
    method forward (line 145) | def forward(self, text_logits, labels):
  class RankingLoss (line 171) | class RankingLoss(nn.Module):
    method forward (line 173) | def forward(self, logits_per_image, num_images, labels, margin = 1.0):
  class CoCaLoss (line 192) | class CoCaLoss(ClipLoss):
    method __init__ (line 193) | def __init__(
    method forward (line 218) | def forward(self, image_features, text_features, logits, labels, logit...
  class DistillClipLoss (line 234) | class DistillClipLoss(ClipLoss):
    method dist_loss (line 236) | def dist_loss(self, teacher_logits, student_logits):
    method forward (line 239) | def forward(

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/model.py
  class CLIPVisionCfg (line 24) | class CLIPVisionCfg:
  class CLIPTextCfg (line 49) | class CLIPTextCfg:
  function get_cast_dtype (line 66) | def get_cast_dtype(precision: str):
  function _build_vision_tower (line 75) | def _build_vision_tower(
  function _build_text_tower (line 137) | def _build_text_tower(
  class CLIP (line 176) | class CLIP(nn.Module):
    method __init__ (line 179) | def __init__(
    method lock_image_tower (line 203) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method lock_text_tower (line 207) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
    method set_grad_checkpointing (line 224) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 228) | def encode_image(self, image, normalize: bool = False):
    method encode_text (line 232) | def encode_text(self, text, normalize: bool = False):
    method forward (line 246) | def forward(self, image, text):
  class CustomTextCLIP (line 258) | class CustomTextCLIP(nn.Module):
    method __init__ (line 261) | def __init__(
    method lock_image_tower (line 276) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method lock_text_tower (line 280) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
    method set_grad_checkpointing (line 284) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 288) | def encode_image(self, image, normalize: bool = False):
    method encode_text (line 292) | def encode_text(self, text, normalize: bool = False):
    method forward (line 296) | def forward(self, image, text):
  function convert_weights_to_lp (line 308) | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
  function convert_to_custom_text_state_dict (line 336) | def convert_to_custom_text_state_dict(state_dict: dict):
  function build_model_from_openai_state_dict (line 354) | def build_model_from_openai_state_dict(
  function trace_model (line 414) | def trace_model(model, batch_size=256, device=torch.device('cpu')):
  function resize_pos_embed (line 430) | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', ...

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
  class Bottleneck (line 10) | class Bottleneck(nn.Module):
    method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 42) | def forward(self, x: torch.Tensor):
  class AttentionPool2d (line 58) | class AttentionPool2d(nn.Module):
    method __init__ (line 59) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o...
    method forward (line 68) | def forward(self, x):
  class ModifiedResNet (line 95) | class ModifiedResNet(nn.Module):
    method __init__ (line 103) | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
    method _make_layer (line 132) | def _make_layer(self, planes, blocks, stride=1):
    method init_parameters (line 141) | def init_parameters(self):
    method lock (line 154) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 162) | def set_grad_checkpointing(self, enable=True):
    method stem (line 166) | def stem(self, x):
    method forward (line 173) | def forward(self, x):

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
  function list_openai_models (line 18) | def list_openai_models() -> List[str]:
  function load_openai_model (line 23) | def load_openai_model(

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
  function _pcfg (line 21) | def _pcfg(url='', hf_hub='', mean=None, std=None):
  function _clean_tag (line 235) | def _clean_tag(tag: str):
  function list_pretrained (line 240) | def list_pretrained(as_str: bool = False):
  function list_pretrained_models_by_tag (line 247) | def list_pretrained_models_by_tag(tag: str):
  function list_pretrained_tags_by_model (line 257) | def list_pretrained_tags_by_model(model: str):
  function is_pretrained_cfg (line 265) | def is_pretrained_cfg(model: str, tag: str):
  function get_pretrained_cfg (line 271) | def get_pretrained_cfg(model: str, tag: str):
  function get_pretrained_url (line 278) | def get_pretrained_url(model: str, tag: str):
  function download_pretrained_from_url (line 283) | def download_pretrained_from_url(
  function has_hf_hub (line 329) | def has_hf_hub(necessary=False):
  function download_pretrained_from_hf (line 337) | def download_pretrained_from_hf(
  function download_pretrained (line 348) | def download_pretrained(

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
  function save_config_for_hf (line 27) | def save_config_for_hf(
  function save_for_hf (line 45) | def save_for_hf(
  function push_to_hf_hub (line 65) | def push_to_hf_hub(
  function push_pretrained_to_hf_hub (line 124) | def push_pretrained_to_hf_hub(
  function generate_readme (line 163) | def generate_readme(model_card: dict, model_name: str):

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
  class TimmModel (line 28) | class TimmModel(nn.Module):
    method __init__ (line 33) | def __init__(
    method lock (line 85) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 118) | def set_grad_checkpointing(self, enable=True):
    method forward (line 124) | def forward(self, x):

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
  function default_bpe (line 21) | def default_bpe():
  function bytes_to_unicode (line 29) | def bytes_to_unicode():
  function get_pairs (line 51) | def get_pairs(word):
  function basic_clean (line 63) | def basic_clean(text):
  function whitespace_clean (line 69) | def whitespace_clean(text):
  class SimpleTokenizer (line 75) | class SimpleTokenizer(object):
    method __init__ (line 76) | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
    method bpe (line 101) | def bpe(self, token):
    method encode (line 142) | def encode(self, text):
    method decode (line 150) | def decode(self, tokens):
    method __call__ (line 155) | def __call__(self, texts: Union[str, List[str]], context_length: int =...
  class HFTokenizer (line 188) | class HFTokenizer:
    method __init__ (line 191) | def __init__(self, tokenizer_name: str):
    method save_pretrained (line 195) | def save_pretrained(self, dest):
    method __call__ (line 198) | def __call__(self, texts: Union[str, List[str]], context_length: int =...

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
  class AugmentationCfg (line 16) | class AugmentationCfg:
  class ResizeMaxSize (line 26) | class ResizeMaxSize(nn.Module):
    method __init__ (line 28) | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, ...
    method forward (line 37) | def forward(self, img):
  function _convert_to_rgb_or_rgba (line 52) | def _convert_to_rgb_or_rgba(image):
  class MaskAwareNormalize (line 66) | class MaskAwareNormalize(nn.Module):
    method __init__ (line 67) | def __init__(self, mean, std):
    method forward (line 71) | def forward(self, tensor):
  function image_transform (line 77) | def image_transform(

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
  class LayerNormFp32 (line 13) | class LayerNormFp32(nn.LayerNorm):
    method forward (line 16) | def forward(self, x: torch.Tensor):
  class LayerNorm (line 22) | class LayerNorm(nn.LayerNorm):
    method forward (line 25) | def forward(self, x: torch.Tensor):
  class QuickGELU (line 31) | class QuickGELU(nn.Module):
    method forward (line 33) | def forward(self, x: torch.Tensor):
  class LayerScale (line 37) | class LayerScale(nn.Module):
    method __init__ (line 38) | def __init__(self, dim, init_values=1e-5, inplace=False):
    method forward (line 43) | def forward(self, x):
  class PatchDropout (line 47) | class PatchDropout(nn.Module):
    method __init__ (line 52) | def __init__(self, prob, exclude_first_token=True):
    method forward (line 58) | def forward(self, x):
  class Attention (line 87) | class Attention(nn.Module):
    method __init__ (line 88) | def __init__(
    method forward (line 127) | def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
  class AttentionalPooler (line 163) | class AttentionalPooler(nn.Module):
    method __init__ (line 164) | def __init__(
    method forward (line 178) | def forward(self, x: torch.Tensor):
    method _repeat (line 185) | def _repeat(self, query, N: int):
  class ResidualAttentionBlock (line 189) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 190) | def __init__(
    method attention (line 217) | def attention(
    method forward (line 232) | def forward(
  class CustomResidualAttentionBlock (line 247) | class CustomResidualAttentionBlock(nn.Module):
    method __init__ (line 248) | def __init__(
    method forward (line 282) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
  class Transformer (line 288) | class Transformer(nn.Module):
    method __init__ (line 289) | def __init__(
    method get_cast_dtype (line 310) | def get_cast_dtype(self) -> torch.dtype:
    method forward (line 313) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
  class VisionTransformer (line 323) | class VisionTransformer(nn.Module):
    method __init__ (line 326) | def __init__(
    method lock (line 395) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method init_parameters (line 428) | def init_parameters(self):
    method set_grad_checkpointing (line 449) | def set_grad_checkpointing(self, enable=True):
    method _global_pool (line 452) | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.T...
    method forward (line 458) | def forward(self, x: torch.Tensor, skip_pool: bool = False):
  class TextTransformer (line 507) | class TextTransformer(nn.Module):
    method __init__ (line 510) | def __init__(
    method init_parameters (line 558) | def init_parameters(self):
    method set_grad_checkpointing (line 577) | def set_grad_checkpointing(self, enable=True):
    method build_attention_mask (line 580) | def build_attention_mask(self):
    method build_cls_mask (line 588) | def build_cls_mask(self, text, cast_dtype: torch.dtype):
    method _repeat (line 597) | def _repeat(self, t, N: int):
    method forward (line 600) | def forward(self, text):
  class MultimodalTransformer (line 635) | class MultimodalTransformer(Transformer):
    method __init__ (line 636) | def __init__(
    method init_parameters (line 677) | def init_parameters(self):
    method build_attention_mask (line 695) | def build_attention_mask(self):
    method forward (line 703) | def forward(self, image_embs, text_embs):
    method set_grad_checkpointing (line 726) | def set_grad_checkpointing(self, enable=True):

FILE: diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
  function freeze_batch_norm_2d (line 8) | def freeze_batch_norm_2d(module, module_match={}, name=''):
  function _ntuple (line 48) | def _ntuple(n):

FILE: diffsynth/extensions/ImageQualityMetric/pickscore.py
  class PickScore (line 8) | class PickScore(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, device: Union[str, torch.device], path: str = MODEL...
    method _calculate_score (line 22) | def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: ...
    method score (line 58) | def score(self, images: Union[str, List[str], Image.Image, List[Image....

FILE: diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
  class BaseModelConfig (line 6) | class BaseModelConfig:

FILE: diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
  class XCLIPModel (line 17) | class XCLIPModel(HFCLIPModel):
    method __init__ (line 18) | def __init__(self, config: CLIPConfig):
    method get_text_features (line 21) | def get_text_features(
    method get_image_features (line 61) | def get_image_features(
  class ClipModelConfig (line 93) | class ClipModelConfig(BaseModelConfig):
  class CLIPModel (line 98) | class CLIPModel(nn.Module):
    method __init__ (line 99) | def __init__(self, ckpt, config_file=False):
    method get_text_features (line 110) | def get_text_features(self, *args, **kwargs):
    method get_image_features (line 113) | def get_image_features(self, *args, **kwargs):
    method forward (line 116) | def forward(self, text_inputs=None, image_inputs=None, condition_input...
    method logit_scale (line 141) | def logit_scale(self):
    method save (line 144) | def save(self, path):

FILE: diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
  function exists (line 8) | def exists(val):
  function default (line 11) | def default(val, d):
  class LayerNorm (line 18) | class LayerNorm(nn.Module):
    method __init__ (line 19) | def __init__(self, dim):
    method forward (line 24) | def forward(self, x):
  class Residual (line 30) | class Residual(nn.Module):
    method __init__ (line 31) | def __init__(self, fn):
    method forward (line 35) | def forward(self, x, *args, **kwargs):
  class RotaryEmbedding (line 43) | class RotaryEmbedding(nn.Module):
    method __init__ (line 44) | def __init__(self, dim):
    method forward (line 49) | def forward(self, max_seq_len, *, device):
  function rotate_half (line 55) | def rotate_half(x):
  function apply_rotary_pos_emb (line 61) | def apply_rotary_pos_emb(pos, t):
  class SwiGLU (line 69) | class SwiGLU(nn.Module):
    method forward (line 70) | def forward(self, x):
  class ParallelTransformerBlock (line 78) | class ParallelTransformerBlock(nn.Module):
    method __init__ (line 79) | def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
    method get_rotary_embedding (line 102) | def get_rotary_embedding(self, n, device):
    method forward (line 110) | def forward(self, x, attn_mask=None):
  class CrossAttention (line 172) | class CrossAttention(nn.Module):
    method __init__ (line 173) | def __init__(
    method forward (line 207) | def forward(self, x, context, mask):
  class Cross_model (line 261) | class Cross_model(nn.Module):
    method __init__ (line 262) | def __init__(
    method forward (line 281) | def forward(
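
The `cross_modeling.py` section above lists the standard rotary-position-embedding helpers (`rotate_half`, `apply_rotary_pos_emb`). As a dependency-free sketch of the usual RoPE rotation these names refer to (one common half-split convention; the repo may use the interleaved pairwise variant instead):

```python
import math

def rotate_half(x):
    # Split the vector in half and rotate: (x1, x2) -> (-x2, x1).
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)

def apply_rotary_pos_emb(pos, t):
    # Standard RoPE: t * cos(pos) + rotate_half(t) * sin(pos), elementwise.
    r = rotate_half(t)
    return [ti * math.cos(p) + ri * math.sin(p)
            for ti, ri, p in zip(t, r, pos)]
```

At position angles of zero the embedding is the identity, which is a quick sanity check for either convention.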

FILE: diffsynth/extensions/RIFE/__init__.py
  function warp (line 8) | def warp(tenInput, tenFlow, device):
  function conv (line 26) | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dila...
  class IFBlock (line 34) | class IFBlock(nn.Module):
    method __init__ (line 35) | def __init__(self, in_planes, c=64):
    method forward (line 45) | def forward(self, x, flow, scale=1):
  class IFNet (line 60) | class IFNet(nn.Module):
    method __init__ (line 61) | def __init__(self, **kwargs):
    method forward (line 68) | def forward(self, x, scale_list=[4, 2, 1], training=False):
    method state_dict_converter (line 103) | def state_dict_converter():
  class IFNetStateDictConverter (line 107) | class IFNetStateDictConverter:
    method __init__ (line 108) | def __init__(self):
    method from_diffusers (line 111) | def from_diffusers(self, state_dict):
    method from_civitai (line 115) | def from_civitai(self, state_dict):
  class RIFEInterpolater (line 119) | class RIFEInterpolater:
    method __init__ (line 120) | def __init__(self, model, device="cuda"):
    method from_model_manager (line 127) | def from_model_manager(model_manager):
    method process_image (line 130) | def process_image(self, image):
    method process_images (line 139) | def process_images(self, images):
    method decode_images (line 144) | def decode_images(self, images):
    method add_interpolated_images (line 149) | def add_interpolated_images(self, images, interpolated_images):
    method interpolate_ (line 159) | def interpolate_(self, images, scale=1.0):
    method interpolate (line 171) | def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, pro...
  class RIFESmoother (line 200) | class RIFESmoother(RIFEInterpolater):
    method __init__ (line 201) | def __init__(self, model, device="cuda"):
    method from_model_manager (line 205) | def from_model_manager(model_manager):
    method process_tensors (line 208) | def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
    method __call__ (line 220) | def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=...

FILE: diffsynth/models/attention.py
  function low_version_attention (line 5) | def low_version_attention(query, key, value, attn_bias=None):
  class Attention (line 15) | class Attention(torch.nn.Module):
    method __init__ (line 17) | def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=Fal...
    method interact_with_ipadapter (line 29) | def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=...
    method torch_forward (line 37) | def torch_forward(self, hidden_states, encoder_hidden_states=None, att...
    method xformers_forward (line 64) | def xformers_forward(self, hidden_states, encoder_hidden_states=None, ...
    method forward (line 88) | def forward(self, hidden_states, encoder_hidden_states=None, attn_mask...

FILE: diffsynth/models/cog_dit.py
  class CogPatchify (line 11) | class CogPatchify(torch.nn.Module):
    method __init__ (line 12) | def __init__(self, dim_in, dim_out, patch_size) -> None:
    method forward (line 16) | def forward(self, hidden_states):
  class CogAdaLayerNorm (line 23) | class CogAdaLayerNorm(torch.nn.Module):
    method __init__ (line 24) | def __init__(self, dim, dim_cond, single=False):
    method forward (line 31) | def forward(self, hidden_states, prompt_emb, emb):
  class CogDiTBlock (line 45) | class CogDiTBlock(torch.nn.Module):
    method __init__ (line 46) | def __init__(self, dim, dim_cond, num_heads):
    method apply_rotary_emb (line 61) | def apply_rotary_emb(self, x, freqs_cis):
    method process_qkv (line 72) | def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
    method forward (line 80) | def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
  class CogDiT (line 108) | class CogDiT(torch.nn.Module):
    method __init__ (line 109) | def __init__(self):
    method get_resize_crop_region_for_grid (line 120) | def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
    method get_3d_rotary_pos_embed (line 138) | def get_3d_rotary_pos_embed(
    method prepare_rotary_positional_embeddings (line 202) | def prepare_rotary_positional_embeddings(
    method unpatchify (line 230) | def unpatchify(self, hidden_states, height, width):
    method build_mask (line 235) | def build_mask(self, T, H, W, dtype, device, is_bound):
    method tiled_forward (line 255) | def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size...
    method forward (line 286) | def forward(self, hidden_states, timestep, prompt_emb, image_rotary_em...
    method state_dict_converter (line 328) | def state_dict_converter():
    method from_pretrained (line 333) | def from_pretrained(file_path, torch_dtype=torch.bfloat16):
  class CogDiTStateDictConverter (line 342) | class CogDiTStateDictConverter:
    method __init__ (line 343) | def __init__(self):
    method from_diffusers (line 347) | def from_diffusers(self, state_dict):
    method from_civitai (line 407) | def from_civitai(self, state_dict):

FILE: diffsynth/models/cog_vae.py
  class Downsample3D (line 7) | class Downsample3D(torch.nn.Module):
    method __init__ (line 8) | def __init__(
    method forward (line 22) | def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
  class Upsample3D (line 57) | class Upsample3D(torch.nn.Module):
    method __init__ (line 58) | def __init__(
    method forward (line 71) | def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Ten...
  class CogVideoXSpatialNorm3D (line 103) | class CogVideoXSpatialNorm3D(torch.nn.Module):
    method __init__ (line 104) | def __init__(self, f_channels, zq_channels, groups):
    method forward (line 111) | def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
  class Resnet3DBlock (line 128) | class Resnet3DBlock(torch.nn.Module):
    method __init__ (line 129) | def __init__(self, in_channels, out_channels, spatial_norm_dim, groups...
    method forward (line 152) | def forward(self, hidden_states, zq):
  class CachedConv3d (line 169) | class CachedConv3d(torch.nn.Conv3d):
    method __init__ (line 170) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
    method clear_cache (line 175) | def clear_cache(self):
    method forward (line 179) | def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
  class CogVAEDecoder (line 189) | class CogVAEDecoder(torch.nn.Module):
    method __init__ (line 190) | def __init__(self):
    method forward (line 224) | def forward(self, sample):
    method decode_video (line 238) | def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_st...
    method decode_small_video (line 254) | def decode_small_video(self, sample):
    method state_dict_converter (line 273) | def state_dict_converter():
  class CogVAEEncoder (line 278) | class CogVAEEncoder(torch.nn.Module):
    method __init__ (line 279) | def __init__(self):
    method forward (line 309) | def forward(self, sample):
    method encode_video (line 323) | def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_st...
    method encode_small_video (line 339) | def encode_small_video(self, sample):
    method state_dict_converter (line 358) | def state_dict_converter():
  class CogVAEEncoderStateDictConverter (line 363) | class CogVAEEncoderStateDictConverter:
    method __init__ (line 364) | def __init__(self):
    method from_diffusers (line 368) | def from_diffusers(self, state_dict):
    method from_civitai (line 435) | def from_civitai(self, state_dict):
  class CogVAEDecoderStateDictConverter (line 440) | class CogVAEDecoderStateDictConverter:
    method __init__ (line 441) | def __init__(self):
    method from_diffusers (line 445) | def from_diffusers(self, state_dict):
    method from_civitai (line 516) | def from_civitai(self, state_dict):

FILE: diffsynth/models/downloader.py
  function download_from_modelscope (line 9) | def download_from_modelscope(model_id, origin_file_path, local_dir):
  function download_from_huggingface (line 24) | def download_from_huggingface(model_id, origin_file_path, local_dir):
  function download_customized_models (line 53) | def download_customized_models(
  function download_models (line 72) | def download_models(

FILE: diffsynth/models/flux_controlnet.py
  class FluxControlNet (line 8) | class FluxControlNet(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5...
    method prepare_image_ids (line 29) | def prepare_image_ids(self, latents):
    method patchify (line 46) | def patchify(self, hidden_states):
    method align_res_stack_to_original_blocks (line 51) | def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hi...
    method forward (line 59) | def forward(
    method state_dict_converter (line 106) | def state_dict_converter():
    method quantize (line 109) | def quantize(self):
  class FluxControlNetStateDictConverter (line 212) | class FluxControlNetStateDictConverter:
    method __init__ (line 213) | def __init__(self):
    method from_diffusers (line 216) | def from_diffusers(self, state_dict):
    method from_civitai (line 326) | def from_civitai(self, state_dict):

FILE: diffsynth/models/flux_dit.py
  function interact_with_ipadapter (line 7) | def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
  class RoPEEmbedding (line 15) | class RoPEEmbedding(torch.nn.Module):
    method __init__ (line 16) | def __init__(self, dim, theta, axes_dim):
    method rope (line 23) | def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    method forward (line 39) | def forward(self, ids):
  class FluxJointAttention (line 46) | class FluxJointAttention(torch.nn.Module):
    method __init__ (line 47) | def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
    method apply_rope (line 66) | def apply_rope(self, xq, xk, freqs_cis):
    method forward (line 73) | def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ...
  class FluxJointTransformerBlock (line 109) | class FluxJointTransformerBlock(torch.nn.Module):
    method __init__ (line 110) | def __init__(self, dim, num_attention_heads):
    method forward (line 132) | def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary...
  class FluxSingleAttention (line 153) | class FluxSingleAttention(torch.nn.Module):
    method __init__ (line 154) | def __init__(self, dim_a, dim_b, num_heads, head_dim):
    method apply_rope (line 165) | def apply_rope(self, xq, xk, freqs_cis):
    method forward (line 173) | def forward(self, hidden_states, image_rotary_emb):
  class AdaLayerNormSingle (line 190) | class AdaLayerNormSingle(torch.nn.Module):
    method __init__ (line 191) | def __init__(self, dim):
    method forward (line 198) | def forward(self, x, emb):
  class FluxSingleTransformerBlock (line 206) | class FluxSingleTransformerBlock(torch.nn.Module):
    method __init__ (line 207) | def __init__(self, dim, num_attention_heads):
    method apply_rope (line 221) | def apply_rope(self, xq, xk, freqs_cis):
    method process_attention (line 229) | def process_attention(self, hidden_states, image_rotary_emb, attn_mask...
    method forward (line 246) | def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary...
  class AdaLayerNormContinuous (line 263) | class AdaLayerNormContinuous(torch.nn.Module):
    method __init__ (line 264) | def __init__(self, dim):
    method forward (line 270) | def forward(self, x, conditioning):
  class FluxDiT (line 278) | class FluxDiT(torch.nn.Module):
    method __init__ (line 279) | def __init__(self, disable_guidance_embedder=False):
    method patchify (line 295) | def patchify(self, hidden_states):
    method unpatchify (line 300) | def unpatchify(self, hidden_states, height, width):
    method prepare_image_ids (line 305) | def prepare_image_ids(self, latents):
    method tiled_forward (line 322) | def tiled_forward(
    method construct_mask (line 341) | def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
    method process_entity_masks (line 376) | def process_entity_masks(self, hidden_states, prompt_emb, entity_promp...
    method forward (line 405) | def forward(
    method quantize (line 474) | def quantize(self):
    method state_dict_converter (line 553) | def state_dict_converter():
  class FluxDiTStateDictConverter (line 557) | class FluxDiTStateDictConverter:
    method __init__ (line 558) | def __init__(self):
    method from_diffusers (line 561) | def from_diffusers(self, state_dict):
    method from_civitai (line 662) | def from_civitai(self, state_dict):

FILE: diffsynth/models/flux_ipadapter.py
  class MLPProjModel (line 7) | class MLPProjModel(torch.nn.Module):
    method __init__ (line 8) | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num...
    method forward (line 21) | def forward(self, id_embeds):
  class IpAdapterModule (line 27) | class IpAdapterModule(torch.nn.Module):
    method __init__ (line 28) | def __init__(self, num_attention_heads, attention_head_dim, input_dim):
    method forward (line 38) | def forward(self, hidden_states):
  class FluxIpAdapter (line 50) | class FluxIpAdapter(torch.nn.Module):
    method __init__ (line 51) | def __init__(self, num_attention_heads=24, attention_head_dim=128, cro...
    method set_adapter (line 57) | def set_adapter(self):
    method forward (line 60) | def forward(self, hidden_states, scale=1.0):
    method state_dict_converter (line 75) | def state_dict_converter():
  class FluxIpAdapterStateDictConverter (line 79) | class FluxIpAdapterStateDictConverter:
    method __init__ (line 80) | def __init__(self):
    method from_diffusers (line 83) | def from_diffusers(self, state_dict):
    method from_civitai (line 93) | def from_civitai(self, state_dict):

FILE: diffsynth/models/flux_text_encoder.py
  class FluxTextEncoder2 (line 7) | class FluxTextEncoder2(T5EncoderModel):
    method __init__ (line 8) | def __init__(self, config):
    method forward (line 12) | def forward(self, input_ids):
    method state_dict_converter (line 18) | def state_dict_converter():
  class FluxTextEncoder2StateDictConverter (line 23) | class FluxTextEncoder2StateDictConverter():
    method __init__ (line 24) | def __init__(self):
    method from_diffusers (line 27) | def from_diffusers(self, state_dict):
    method from_civitai (line 31) | def from_civitai(self, state_dict):

FILE: diffsynth/models/flux_vae.py
  class FluxVAEEncoder (line 5) | class FluxVAEEncoder(SD3VAEEncoder):
    method __init__ (line 6) | def __init__(self):
    method state_dict_converter (line 12) | def state_dict_converter():
  class FluxVAEDecoder (line 16) | class FluxVAEDecoder(SD3VAEDecoder):
    method __init__ (line 17) | def __init__(self):
    method state_dict_converter (line 23) | def state_dict_converter():
  class FluxVAEEncoderStateDictConverter (line 27) | class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    method __init__ (line 28) | def __init__(self):
    method from_civitai (line 31) | def from_civitai(self, state_dict):
  class FluxVAEDecoderStateDictConverter (line 151) | class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    method __init__ (line 152) | def __init__(self):
    method from_civitai (line 155) | def from_civitai(self, state_dict):
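
Nearly every model file in this index pairs its module with a `StateDictConverter` exposing `from_diffusers` / `from_civitai`. At their core these map checkpoint key names from one format's naming scheme to this repo's. A minimal illustrative sketch of that key-renaming pattern (function and rule names are hypothetical; the real converters also handle tensor reshapes and splits):

```python
def convert_state_dict(state_dict, rename_rules):
    # Rename checkpoint keys by prefix match; keys matching no rule are dropped.
    # rename_rules maps source-format prefixes to target-format prefixes.
    out = {}
    for key, value in state_dict.items():
        for src, dst in rename_rules.items():
            if key.startswith(src):
                out[dst + key[len(src):]] = value
                break
    return out
```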

FILE: diffsynth/models/hunyuan_dit.py
  class HunyuanDiTRotaryEmbedding (line 7) | class HunyuanDiTRotaryEmbedding(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=T...
    method reshape_for_broadcast (line 16) | def reshape_for_broadcast(self, freqs_cis, x):
    method rotate_half (line 21) | def rotate_half(self, x):
    method apply_rotary_emb (line 25) | def apply_rotary_emb(self, xq, xk, freqs_cis):
    method forward (line 34) | def forward(self, q, k, v, freqs_cis_img, to_cache=False):
  class FP32_Layernorm (line 55) | class FP32_Layernorm(torch.nn.LayerNorm):
    method forward (line 56) | def forward(self, inputs):
  class FP32_SiLU (line 61) | class FP32_SiLU(torch.nn.SiLU):
    method forward (line 62) | def forward(self, inputs):
  class HunyuanDiTFinalLayer (line 67) | class HunyuanDiTFinalLayer(torch.nn.Module):
    method __init__ (line 68) | def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_s...
    method modulate (line 77) | def modulate(self, x, shift, scale):
    method forward (line 80) | def forward(self, hidden_states, condition_emb):
  class HunyuanDiTBlock (line 87) | class HunyuanDiTBlock(torch.nn.Module):
    method __init__ (line 89) | def __init__(
    method forward (line 118) | def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img...
  class AttentionPool (line 140) | class AttentionPool(torch.nn.Module):
    method __init__ (line 141) | def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
    method forward (line 150) | def forward(self, x):
  class PatchEmbed (line 176) | class PatchEmbed(torch.nn.Module):
    method __init__ (line 177) | def __init__(
    method forward (line 187) | def forward(self, x):
  function timestep_embedding (line 193) | def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
  class TimestepEmbedder (line 213) | class TimestepEmbedder(torch.nn.Module):
    method __init__ (line 214) | def __init__(self, hidden_size=1408, frequency_embedding_size=256):
    method forward (line 223) | def forward(self, t):
  class HunyuanDiT (line 229) | class HunyuanDiT(torch.nn.Module):
    method __init__ (line 230) | def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4...
    method prepare_text_emb (line 262) | def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_...
    method prepare_extra_emb (line 271) | def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, ba...
    method unpatchify (line 291) | def unpatchify(self, x, h, w):
    method build_mask (line 294) | def build_mask(self, data, is_bound):
    method tiled_block_forward (line 311) | def tiled_block_forward(self, block, hidden_states, condition_emb, tex...
    method forward (line 348) | def forward(
    method state_dict_converter (line 402) | def state_dict_converter():
  class HunyuanDiTStateDictConverter (line 407) | class HunyuanDiTStateDictConverter():
    method __init__ (line 408) | def __init__(self):
    method from_diffusers (line 411) | def from_diffusers(self, state_dict):
    method from_civitai (line 450) | def from_civitai(self, state_dict):
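
`hunyuan_dit.py` lists a `timestep_embedding(t, dim, max_period=10000, ...)` helper, the sinusoidal embedding familiar from DDPM/ADM-style diffusion models. A stdlib-only sketch of the standard formula (the exact cos/sin channel ordering varies between implementations):

```python
import math

def timestep_embedding(t, dim, max_period=10000):
    # Sinusoidal timestep embedding: half the channels are cosines, half sines,
    # at frequencies log-spaced from 1 down to 1/max_period.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]
```

At `t = 0` this yields ones followed by zeros, a convenient sanity check.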

FILE: diffsynth/models/hunyuan_dit_text_encoder.py
  class HunyuanDiTCLIPTextEncoder (line 6) | class HunyuanDiTCLIPTextEncoder(BertModel):
    method __init__ (line 7) | def __init__(self):
    method forward (line 43) | def forward(self, input_ids, attention_mask, clip_skip=1):
    method state_dict_converter (line 83) | def state_dict_converter():
  class HunyuanDiTT5TextEncoder (line 88) | class HunyuanDiTT5TextEncoder(T5EncoderModel):
    method __init__ (line 89) | def __init__(self):
    method forward (line 123) | def forward(self, input_ids, attention_mask, clip_skip=1):
    method state_dict_converter (line 136) | def state_dict_converter():
  class HunyuanDiTCLIPTextEncoderStateDictConverter (line 141) | class HunyuanDiTCLIPTextEncoderStateDictConverter():
    method __init__ (line 142) | def __init__(self):
    method from_diffusers (line 145) | def from_diffusers(self, state_dict):
    method from_civitai (line 149) | def from_civitai(self, state_dict):
  class HunyuanDiTT5TextEncoderStateDictConverter (line 153) | class HunyuanDiTT5TextEncoderStateDictConverter():
    method __init__ (line 154) | def __init__(self):
    method from_diffusers (line 157) | def from_diffusers(self, state_dict):
    method from_civitai (line 162) | def from_civitai(self, state_dict):

FILE: diffsynth/models/hunyuan_video_dit.py
  function HunyuanVideoRope (line 10) | def HunyuanVideoRope(latents):
  class PatchEmbed (line 196) | class PatchEmbed(torch.nn.Module):
    method __init__ (line 197) | def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
    method forward (line 201) | def forward(self, x):
  class IndividualTokenRefinerBlock (line 207) | class IndividualTokenRefinerBlock(torch.nn.Module):
    method __init__ (line 208) | def __init__(self, hidden_size=3072, num_heads=24):
    method forward (line 226) | def forward(self, x, c, attn_mask=None):
  class SingleTokenRefiner (line 242) | class SingleTokenRefiner(torch.nn.Module):
    method __init__ (line 243) | def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
    method forward (line 254) | def forward(self, x, t, mask=None):
  class ModulateDiT (line 275) | class ModulateDiT(torch.nn.Module):
    method __init__ (line 276) | def __init__(self, hidden_size, factor=6):
    method forward (line 281) | def forward(self, x):
  function modulate (line 285) | def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr...
  function reshape_for_broadcast (line 301) | def reshape_for_broadcast(
  function rotate_half (line 347) | def rotate_half(x):
  function apply_rotary_emb (line 354) | def apply_rotary_emb(
  function attention (line 387) | def attention(q, k, v):
  function apply_gate (line 394) | def apply_gate(x, gate, tr_gate=None, tr_token=None):
  class MMDoubleStreamBlockComponent (line 403) | class MMDoubleStreamBlockComponent(torch.nn.Module):
    method __init__ (line 404) | def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
    method forward (line 423) | def forward(self, hidden_states, conditioning, freqs_cis=None, token_r...
    method process_ff (line 444) | def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_...
  class MMDoubleStreamBlock (line 456) | class MMDoubleStreamBlock(torch.nn.Module):
    method __init__ (line 457) | def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
    method forward (line 462) | def forward(self, hidden_states_a, hidden_states_b, conditioning, freq...
  class MMSingleStreamBlockOriginal (line 478) | class MMSingleStreamBlockOriginal(torch.nn.Module):
    method __init__ (line 479) | def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
    method forward (line 496) | def forward(self, x, vec, freqs_cis=None, txt_len=256):
  class MMSingleStreamBlock (line 518) | class MMSingleStreamBlock(torch.nn.Module):
    method __init__ (line 519) | def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
    method forward (line 537) | def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len...
  class FinalLayer (line 573) | class FinalLayer(torch.nn.Module):
    method __init__ (line 574) | def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channel...
    method forward (line 582) | def forward(self, x, c):
  class HunyuanVideoDiT (line 589) | class HunyuanVideoDiT(torch.nn.Module):
    method __init__ (line 590) | def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, nu...
    method unpatchify (line 612) | def unpatchify(self, x, T, H, W):
    method enable_block_wise_offload (line 616) | def enable_block_wise_offload(self, warm_device="cuda", cold_device="c...
    method load_models_to_device (line 621) | def load_models_to_device(self, loadmodel_names=[], device="cpu"):
    method prepare_freqs (line 628) | def prepare_freqs(self, latents):
    method forward (line 631) | def forward(
    method enable_auto_offload (line 664) | def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
    method state_dict_converter (line 810) | def state_dict_converter():
  class HunyuanVideoDiTStateDictConverter (line 814) | class HunyuanVideoDiTStateDictConverter:
    method __init__ (line 815) | def __init__(self):
    method from_civitai (line 818) | def from_civitai(self, state_dict):

FILE: diffsynth/models/hunyuan_video_text_encoder.py
  class HunyuanVideoLLMEncoder (line 6) | class HunyuanVideoLLMEncoder(LlamaModel):
    method __init__ (line 8) | def __init__(self, config: LlamaConfig):
    method enable_auto_offload (line 12) | def enable_auto_offload(self, **kwargs):
    method forward (line 15) | def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
  class HunyuanVideoMLLMEncoder (line 52) | class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
    method __init__ (line 54) | def __init__(self, config):
    method enable_auto_offload (line 58) | def enable_auto_offload(self, **kwargs):
    method forward (line 62) | def forward(self, input_ids, pixel_values, attention_mask, hidden_stat...

FILE: diffsynth/models/hunyuan_video_vae_decoder.py
  class CausalConv3d (line 10) | class CausalConv3d(nn.Module):
    method __init__ (line 12) | def __init__(self, in_channel, out_channel, kernel_size, stride=1, dil...
    method forward (line 19) | def forward(self, x):
  class UpsampleCausal3D (line 24) | class UpsampleCausal3D(nn.Module):
    method __init__ (line 26) | def __init__(self, channels, use_conv=False, out_channels=None, kernel...
    method forward (line 36) | def forward(self, hidden_states):
  class ResnetBlockCausal3D (line 64) | class ResnetBlockCausal3D(nn.Module):
    method __init__ (line 66) | def __init__(self, in_channels, out_channels=None, dropout=0.0, groups...
    method forward (line 86) | def forward(self, input_tensor):
  function prepare_causal_attention_mask (line 107) | def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_si...
  class Attention (line 118) | class Attention(nn.Module):
    method __init__ (line 120) | def __init__(self,
    method forward (line 140) | def forward(self, input_tensor, attn_mask=None):
  class UNetMidBlockCausal3D (line 162) | class UNetMidBlockCausal3D(nn.Module):
    method __init__ (line 164) | def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, n...
    method forward (line 203) | def forward(self, hidden_states):
  class UpDecoderBlockCausal3D (line 216) | class UpDecoderBlockCausal3D(nn.Module):
    method __init__ (line 218) | def __init__(
    method forward (line 254) | def forward(self, hidden_states):
  class DecoderCausal3D (line 263) | class DecoderCausal3D(nn.Module):
    method __init__ (line 265) | def __init__(
    method forward (line 331) | def forward(self, hidden_states):
  class HunyuanVideoVAEDecoder (line 369) | class HunyuanVideoVAEDecoder(nn.Module):
    method __init__ (line 371) | def __init__(
    method forward (line 401) | def forward(self, latents):
    method build_1d_mask (line 408) | def build_1d_mask(self, length, left_bound, right_bound, border_width):
    method build_mask (line 417) | def build_mask(self, data, is_bound, border_width):
    method tile_forward (line 432) | def tile_forward(self, hidden_states, tile_size, tile_stride):
    method decode_video (line 488) | def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(1...
    method state_dict_converter (line 493) | def state_dict_converter():
  class HunyuanVideoVAEDecoderStateDictConverter (line 497) | class HunyuanVideoVAEDecoderStateDictConverter:
    method __init__ (line 499) | def __init__(self):
    method from_diffusers (line 502) | def from_diffusers(self, state_dict):
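
The `prepare_causal_attention_mask(n_frame, n_hw, ...)` helper listed above builds a frame-causal mask for the causal 3D attention blocks: each frame's `n_hw` spatial tokens may attend to their own frame and earlier frames only. A stdlib sketch of that block-causal masking pattern (the repo version additionally takes `dtype`, `device`, and `batch_size` and returns a tensor):

```python
def prepare_causal_attention_mask(n_frame, n_hw, neg=float("-inf")):
    # Additive block-causal mask over n_frame * n_hw tokens: entry (i, j) is 0
    # when token j belongs to the same or an earlier frame than token i,
    # and -inf (masked out) when it belongs to a later frame.
    size = n_frame * n_hw
    return [[0.0 if (j // n_hw) <= (i // n_hw) else neg
             for j in range(size)]
            for i in range(size)]
```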

FILE: diffsynth/models/hunyuan_video_vae_encoder.py
  class DownsampleCausal3D (line 10) | class DownsampleCausal3D(nn.Module):
    method __init__ (line 12) | def __init__(self, channels, out_channels, kernel_size=3, bias=True, s...
    method forward (line 16) | def forward(self, hidden_states):
  class DownEncoderBlockCausal3D (line 21) | class DownEncoderBlockCausal3D(nn.Module):
    method __init__ (line 23) | def __init__(
    method forward (line 57) | def forward(self, hidden_states):
  class EncoderCausal3D (line 68) | class EncoderCausal3D(nn.Module):
    method __init__ (line 70) | def __init__(
    method forward (line 129) | def forward(self, hidden_states):
  class HunyuanVideoVAEEncoder (line 167) | class HunyuanVideoVAEEncoder(nn.Module):
    method __init__ (line 169) | def __init__(
    method forward (line 199) | def forward(self, images):
    method build_1d_mask (line 207) | def build_1d_mask(self, length, left_bound, right_bound, border_width):
    method build_mask (line 216) | def build_mask(self, data, is_bound, border_width):
    method tile_forward (line 231) | def tile_forward(self, hidden_states, tile_size, tile_stride):
    method encode_video (line 287) | def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=...
    method state_dict_converter (line 293) | def state_dict_converter():
  class HunyuanVideoVAEEncoderStateDictConverter (line 297) | class HunyuanVideoVAEEncoderStateDictConverter:
    method __init__ (line 299) | def __init__(self):
    method from_diffusers (line 302) | def from_diffusers(self, state_dict):

FILE: diffsynth/models/kolors_text_encoder.py
  class Kernel (line 52) | class Kernel:
    method __init__ (line 53) | def __init__(self, code: bytes, function_names: List[str]):
  class W8A16Linear (line 78) | class W8A16Linear(torch.autograd.Function):
    method forward (line 80) | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: to...
    method backward (line 92) | def backward(ctx, grad_output: torch.Tensor):
  function compress_int4_weight (line 101) | def compress_int4_weight(weight: torch.Tensor):  # (n, m)
  function extract_weight_to_half (line 122) | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tenso...
  class QuantizedLinear (line 158) | class QuantizedLinear(torch.nn.Module):
    method __init__ (line 159) | def __init__(self, weight_bit_width: int, weight, bias=None, device="c...
    method forward (line 180) | def forward(self, input):
  function quantize (line 187) | def quantize(model, weight_bit_width, empty_init=False, device=None):
  class ChatGLMConfig (line 227) | class ChatGLMConfig(PretrainedConfig):
    method __init__ (line 229) | def __init__(
  function default_init (line 307) | def default_init(cls, *args, **kwargs):
  class InvalidScoreLogitsProcessor (line 311) | class InvalidScoreLogitsProcessor(LogitsProcessor):
    method __call__ (line 312) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  class PrefixEncoder (line 319) | class PrefixEncoder(torch.nn.Module):
    method __init__ (line 326) | def __init__(self, config: ChatGLMConfig):
    method forward (line 342) | def forward(self, prefix: torch.Tensor):
  function split_tensor_along_last_dim (line 351) | def split_tensor_along_last_dim(
  class RotaryEmbedding (line 379) | class RotaryEmbedding(nn.Module):
    method __init__ (line 380) | def __init__(self, dim, original_impl=False, device=None, dtype=None):
    method forward_impl (line 387) | def forward_impl(
    method forward (line 412) | def forward(self, max_seq_len, offset=0):
  function apply_rotary_pos_emb (line 419) | def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> t...
  class RMSNorm (line 439) | class RMSNorm(torch.nn.Module):
    method __init__ (line 440) | def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None...
    method forward (line 445) | def forward(self, hidden_states: torch.Tensor):
  class CoreAttention (line 453) | class CoreAttention(torch.nn.Module):
    method __init__ (line 454) | def __init__(self, config: ChatGLMConfig, layer_number):
    method forward (line 479) | def forward(self, query_layer, key_layer, value_layer, attention_mask):
  class SelfAttention (line 571) | class SelfAttention(torch.nn.Module):
    method __init__ (line 578) | def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    method _allocate_memory (line 607) | def _allocate_memory(self, inference_max_sequence_len, batch_size, dev...
    method forward (line 621) | def forward(
  function _config_to_kwargs (line 710) | def _config_to_kwargs(args):
  class MLP (line 717) | class MLP(torch.nn.Module):
    method __init__ (line 725) | def __init__(self, config: ChatGLMConfig, device=None):
    method forward (line 754) | def forward(self, hidden_states):
  class GLMBlock (line 763) | class GLMBlock(torch.nn.Module):
    method __init__ (line 770) | def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    method forward (line 794) | def forward(
  class GLMTransformer (line 837) | class GLMTransformer(torch.nn.Module):
    method __init__ (line 840) | def __init__(self, config: ChatGLMConfig, device=None):
    method _get_layer (line 863) | def _get_layer(self, layer_number):
    method forward (line 866) | def forward(
  class ChatGLMPreTrainedModel (line 919) | class ChatGLMPreTrainedModel(PreTrainedModel):
    method _init_weights (line 931) | def _init_weights(self, module: nn.Module):
    method get_masks (line 935) | def get_masks(self, input_ids, past_key_values, padding_mask=None):
    method get_position_ids (line 953) | def get_position_ids(self, input_ids, device):
    method _set_gradient_checkpointing (line 958) | def _set_gradient_checkpointing(self, module, value=False):
  class Embedding (line 963) | class Embedding(torch.nn.Module):
    method __init__ (line 966) | def __init__(self, config: ChatGLMConfig, device=None):
    method forward (line 979) | def forward(self, input_ids):
  class ChatGLMModel (line 991) | class ChatGLMModel(ChatGLMPreTrainedModel):
    method __init__ (line 992) | def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
    method get_input_embeddings (line 1026) | def get_input_embeddings(self):
    method get_prompt (line 1029) | def get_prompt(self, batch_size, device, dtype=torch.half):
    method forward (line 1044) | def forward(
    method quantize (line 1103) | def quantize(self, weight_bit_width: int):
  class ChatGLMForConditionalGeneration (line 1109) | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    method __init__ (line 1110) | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
    method _update_model_kwargs_for_generation (line 1121) | def _update_model_kwargs_for_generation(
    method prepare_inputs_for_generation (line 1152) | def prepare_inputs_for_generation(
    method forward (line 1178) | def forward(
    method _reorder_cache (line 1239) | def _reorder_cache(
    method process_response (line 1257) | def process_response(self, output, history):
    method chat (line 1279) | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] =...
    method stream_chat (line 1301) | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, ...
    method stream_generate (line 1342) | def stream_generate(
    method quantize (line 1449) | def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
  class ChatGLMForSequenceClassification (line 1468) | class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    method __init__ (line 1469) | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
    method forward (line 1485) | def forward(
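
Note: `compress_int4_weight` / `extract_weight_to_half` implement ChatGLM-style weight-only quantization via CUDA kernels (the `Kernel` class): two signed 4-bit values are packed per byte and unpacked against per-channel scales at matmul time. A hedged pure-Python sketch of the nibble-packing round trip (the real code operates on whole tensors on GPU):

```python
def pack_int4(a, b):
    # Pack two signed 4-bit ints (range -8..7) into one byte:
    # `a` in the high nibble, `b` in the low nibble.
    return ((a & 0xF) << 4) | (b & 0xF)

def unpack_int4(byte):
    # Recover both values, sign-extending each nibble (values >= 8
    # represent negatives in two's complement).
    sign = lambda v: v - 16 if v >= 8 else v
    return sign((byte >> 4) & 0xF), sign(byte & 0xF)
```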

FILE: diffsynth/models/lora.py
  class LoRAFromCivitai (line 15) | class LoRAFromCivitai:
    method __init__ (line 16) | def __init__(self):
    method convert_state_dict (line 23) | def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alp...
    method convert_state_dict_up_down (line 30) | def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_une...
    method convert_state_dict_AB (line 53) | def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0,...
    method load (line 76) | def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_r...
    method match (line 99) | def match(self, model, state_dict_lora):
  class SDLoRAFromCivitai (line 125) | class SDLoRAFromCivitai(LoRAFromCivitai):
    method __init__ (line 126) | def __init__(self):
  class SDXLLoRAFromCivitai (line 152) | class SDXLLoRAFromCivitai(LoRAFromCivitai):
    method __init__ (line 153) | def __init__(self):
  class FluxLoRAFromCivitai (line 181) | class FluxLoRAFromCivitai(LoRAFromCivitai):
    method __init__ (line 182) | def __init__(self):
  class GeneralLoRAFromPeft (line 200) | class GeneralLoRAFromPeft:
    method __init__ (line 201) | def __init__(self):
    method get_name_dict (line 205) | def get_name_dict(self, lora_state_dict):
    method match (line 221) | def match(self, model: torch.nn.Module, state_dict_lora):
    method fetch_device_and_dtype (line 231) | def fetch_device_and_dtype(self, state_dict):
    method load (line 246) | def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, mode...
  class HunyuanVideoLoRAFromCivitai (line 267) | class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
    method __init__ (line 268) | def __init__(self):
  class FluxLoRAConverter (line 275) | class FluxLoRAConverter:
    method __init__ (line 276) | def __init__(self):
    method align_to_opensource_format (line 280) | def align_to_opensource_format(state_dict, alpha=1.0):
    method align_to_diffsynth_format (line 323) | def align_to_diffsynth_format(state_dict):
  function get_lora_loaders (line 370) | def get_lora_loaders():
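
Note: all the loaders in this file reduce to the same update rule — the delta for a linear weight is the product of the low-rank `up`/`down` (a.k.a. `B`/`A`) factors scaled by `alpha`, added onto the base weight. A minimal sketch with plain lists (illustrative; the repo does this with torch tensors on the model's device and dtype):

```python
def merge_lora(weight, up, down, alpha=1.0):
    # weight: (out, in); up: (out, r); down: (r, in)
    # Returns weight + alpha * (up @ down) without modifying the input.
    rows, cols, rank = len(weight), len(weight[0]), len(down)
    merged = [row[:] for row in weight]
    for i in range(rows):
        for j in range(cols):
            merged[i][j] += alpha * sum(up[i][k] * down[k][j] for k in range(rank))
    return merged
```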

FILE: diffsynth/models/model_manager.py
  function load_model_from_single_file (line 56) | def load_model_from_single_file(state_dict, model_names, model_classes, ...
  function load_model_from_huggingface_folder (line 82) | def load_model_from_huggingface_folder(file_path, model_names, model_cla...
  function load_single_patch_model_from_single_file (line 100) | def load_single_patch_model_from_single_file(state_dict, model_name, mod...
  function load_patch_model_from_single_file (line 112) | def load_patch_model_from_single_file(state_dict, model_names, model_cla...
  class ModelDetectorTemplate (line 136) | class ModelDetectorTemplate:
    method __init__ (line 137) | def __init__(self):
    method match (line 140) | def match(self, file_path="", state_dict={}):
    method load (line 143) | def load(self, file_path="", state_dict={}, device="cuda", torch_dtype...
  class ModelDetectorFromSingleFile (line 148) | class ModelDetectorFromSingleFile:
    method __init__ (line 149) | def __init__(self, model_loader_configs=[]):
    method add_model_metadata (line 156) | def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_na...
    method match (line 162) | def match(self, file_path="", state_dict={}):
    method load (line 176) | def load(self, file_path="", state_dict={}, device="cuda", torch_dtype...
  class ModelDetectorFromSplitedSingleFile (line 199) | class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    method __init__ (line 200) | def __init__(self, model_loader_configs=[]):
    method match (line 204) | def match(self, file_path="", state_dict={}):
    method load (line 216) | def load(self, file_path="", state_dict={}, device="cuda", torch_dtype...
  class ModelDetectorFromHuggingfaceFolder (line 236) | class ModelDetectorFromHuggingfaceFolder:
    method __init__ (line 237) | def __init__(self, model_loader_configs=[]):
    method add_model_metadata (line 243) | def add_model_metadata(self, architecture, huggingface_lib, model_name...
    method match (line 247) | def match(self, file_path="", state_dict={}):
    method load (line 260) | def load(self, file_path="", state_dict={}, device="cuda", torch_dtype...
  class ModelDetectorFromPatchedSingleFile (line 277) | class ModelDetectorFromPatchedSingleFile:
    method __init__ (line 278) | def __init__(self, model_loader_configs=[]):
    method add_model_metadata (line 284) | def add_model_metadata(self, keys_hash_with_shape, model_name, model_c...
    method match (line 288) | def match(self, file_path="", state_dict={}):
    method load (line 299) | def load(self, file_path="", state_dict={}, device="cuda", torch_dtype...
  class ModelManager (line 316) | class ModelManager:
    method __init__ (line 317) | def __init__(
    method load_model_from_single_file (line 340) | def load_model_from_single_file(self, file_path="", state_dict={}, mod...
    method load_model_from_huggingface_folder (line 352) | def load_model_from_huggingface_folder(self, file_path="", model_names...
    method load_patch_model_from_single_file (line 362) | def load_patch_model_from_single_file(self, file_path="", state_dict={...
    method load_lora (line 373) | def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
    method load_model (line 395) | def load_model(self, file_path, model_names=None, device=None, torch_d...
    method load_models (line 424) | def load_models(self, file_path_list, model_names=None, device=None, t...
    method fetch_model (line 429) | def fetch_model(self, model_name, file_path=None, require_model_path=F...
    method to (line 451) | def to(self, device):
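
Note: `ModelDetectorFromSingleFile` matches checkpoints by hashing their state-dict keys (optionally with tensor shapes, via `keys_hash_with_shape`) and looking the hash up in a table of known architectures. A sketch of that fingerprinting idea with a hypothetical helper (pure Python; the hash function and string format here are assumptions, not the repo's exact scheme):

```python
import hashlib

def keys_hash(state_dict_keys, shapes=None):
    # Fingerprint a checkpoint by its sorted parameter names and,
    # optionally, their shapes — enough to distinguish architectures
    # without loading any tensor data.
    items = sorted(
        f"{k}:{tuple(shapes[k])}" if shapes else k
        for k in state_dict_keys
    )
    return hashlib.md5(",".join(items).encode()).hexdigest()
```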

FILE: diffsynth/models/omnigen.py
  class Phi3Transformer (line 20) | class Phi3Transformer(Phi3Model):
    method prefetch_layer (line 27) | def prefetch_layer(self, layer_idx: int, device: torch.device):
    method evict_previous_layer (line 34) | def evict_previous_layer(self, layer_idx: int):
    method get_offlaod_layer (line 40) | def get_offlaod_layer(self, layer_idx: int, device: torch.device):
    method forward (line 56) | def forward(
  function modulate (line 191) | def modulate(x, shift, scale):
  class TimestepEmbedder (line 195) | class TimestepEmbedder(nn.Module):
    method __init__ (line 199) | def __init__(self, hidden_size, frequency_embedding_size=256):
    method timestep_embedding (line 209) | def timestep_embedding(t, dim, max_period=10000):
    method forward (line 229) | def forward(self, t, dtype=torch.float32):
  class FinalLayer (line 235) | class FinalLayer(nn.Module):
    method __init__ (line 239) | def __init__(self, hidden_size, patch_size, out_channels):
    method forward (line 248) | def forward(self, x, c):
  function get_2d_sincos_pos_embed (line 255) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra...
  function get_2d_sincos_pos_embed_from_grid (line 275) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
  function get_1d_sincos_pos_embed_from_grid (line 286) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  class PatchEmbedMR (line 307) | class PatchEmbedMR(nn.Module):
    method __init__ (line 310) | def __init__(
    method forward (line 320) | def forward(self, x):
  class OmniGenOriginalModel (line 326) | class OmniGenOriginalModel(nn.Module):
    method __init__ (line 330) | def __init__(
    method from_pretrained (line 364) | def from_pretrained(cls, model_name):
    method initialize_weights (line 380) | def initialize_weights(self):
    method unpatchify (line 413) | def unpatchify(self, x, h, w):
    method cropped_pos_embed (line 426) | def cropped_pos_embed(self, height, width):
    method patch_multiple_resolutions (line 451) | def patch_multiple_resolutions(self, latents, padding_latent=None, is_...
    method forward (line 489) | def forward(self, x, timestep, input_ids, input_img_latents, input_ima...
    method forward_with_cfg (line 534) | def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, ...
    method forward_with_separate_cfg (line 550) | def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_...
  class OmniGenTransformer (line 580) | class OmniGenTransformer(OmniGenOriginalModel):
    method __init__ (line 581) | def __init__(self):
    method forward (line 717) | def forward(self, x, timestep, input_ids, input_img_latents, input_ima...
    method forward_with_separate_cfg (line 760) | def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_...
    method state_dict_converter (line 790) | def state_dict_converter():
  class OmniGenTransformerStateDictConverter (line 795) | class OmniGenTransformerStateDictConverter:
    method __init__ (line 796) | def __init__(self):
    method from_diffusers (line 799) | def from_diffusers(self, state_dict):
    method from_civitai (line 802) | def from_civitai(self, state_dict):
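
Note: `TimestepEmbedder.timestep_embedding` is the standard DiT-style sinusoidal embedding — the timestep is projected onto a bank of log-spaced frequencies and the cos/sin values are concatenated before an MLP. A pure-Python sketch of the math (the repo computes this with torch tensors and handles odd `dim` with zero padding):

```python
import math

def timestep_embedding(t, dim, max_period=10000):
    # Half the channels carry cos, half sin, over frequencies spaced
    # geometrically from 1 down to ~1/max_period.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]
```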

FILE: diffsynth/models/sd3_dit.py
  class RMSNorm (line 8) | class RMSNorm(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, dim, eps, elementwise_affine=True):
    method forward (line 17) | def forward(self, hidden_states):
  class PatchEmbed (line 28) | class PatchEmbed(torch.nn.Module):
    method __init__ (line 29) | def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_e...
    method cropped_pos_embed (line 37) | def cropped_pos_embed(self, height, width):
    method forward (line 45) | def forward(self, latent):
  class TimestepEmbeddings (line 54) | class TimestepEmbeddings(torch.nn.Module):
    method __init__ (line 55) | def __init__(self, dim_in, dim_out, computation_device=None):
    method forward (line 62) | def forward(self, timestep, dtype):
  class AdaLayerNorm (line 69) | class AdaLayerNorm(torch.nn.Module):
    method __init__ (line 70) | def __init__(self, dim, single=False, dual=False):
    method forward (line 77) | def forward(self, x, emb):
  class JointAttention (line 96) | class JointAttention(torch.nn.Module):
    method __init__ (line 97) | def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False...
    method process_qkv (line 122) | def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
    method forward (line 134) | def forward(self, hidden_states_a, hidden_states_b):
  class SingleAttention (line 156) | class SingleAttention(torch.nn.Module):
    method __init__ (line 157) | def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
    method process_qkv (line 173) | def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
    method forward (line 185) | def forward(self, hidden_states_a):
  class DualTransformerBlock (line 197) | class DualTransformerBlock(torch.nn.Module):
    method __init__ (line 198) | def __init__(self, dim, num_attention_heads, use_rms_norm=False):
    method forward (line 221) | def forward(self, hidden_states_a, hidden_states_b, temb):
  class JointTransformerBlock (line 243) | class JointTransformerBlock(torch.nn.Module):
    method __init__ (line 244) | def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=...
    method forward (line 268) | def forward(self, hidden_states_a, hidden_states_b, temb):
  class JointTransformerFinalBlock (line 294) | class JointTransformerFinalBlock(torch.nn.Module):
    method __init__ (line 295) | def __init__(self, dim, num_attention_heads, use_rms_norm=False):
    method forward (line 310) | def forward(self, hidden_states_a, hidden_states_b, temb):
  class SD3DiT (line 326) | class SD3DiT(torch.nn.Module):
    method __init__ (line 327) | def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, ...
    method tiled_forward (line 339) | def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_pr...
    method forward (line 351) | def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_e...
    method state_dict_converter (line 381) | def state_dict_converter():
  class SD3DiTStateDictConverter (line 386) | class SD3DiTStateDictConverter:
    method __init__ (line 387) | def __init__(self):
    method infer_architecture (line 390) | def infer_architecture(self, state_dict):
    method from_diffusers (line 408) | def from_diffusers(self, state_dict):
    method from_civitai (line 472) | def from_civitai(self, state_dict):
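
Note: this file re-declares `RMSNorm` (it also appears in `kolors_text_encoder.py`). Unlike LayerNorm it skips mean-centring and rescales by the root-mean-square of the activations, with an optional learned per-channel weight (`elementwise_affine=True`). The math, as a pure-Python sketch:

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), optionally scaled per channel.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v / rms for v in x]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    return out
```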

FILE: diffsynth/models/sd3_text_encoder.py
  class SD3TextEncoder1 (line 7) | class SD3TextEncoder1(SDTextEncoder):
    method __init__ (line 8) | def __init__(self, vocab_size=49408):
    method forward (line 11) | def forward(self, input_ids, clip_skip=2, extra_mask=None):
    method state_dict_converter (line 26) | def state_dict_converter():
  class SD3TextEncoder2 (line 31) | class SD3TextEncoder2(SDXLTextEncoder2):
    method __init__ (line 32) | def __init__(self):
    method state_dict_converter (line 36) | def state_dict_converter():
  class SD3TextEncoder3 (line 40) | class SD3TextEncoder3(T5EncoderModel):
    method __init__ (line 41) | def __init__(self):
    method forward (line 75) | def forward(self, input_ids):
    method state_dict_converter (line 81) | def state_dict_converter():
  class SD3TextEncoder1StateDictConverter (line 86) | class SD3TextEncoder1StateDictConverter:
    method __init__ (line 87) | def __init__(self):
    method from_diffusers (line 90) | def from_diffusers(self, state_dict):
    method from_civitai (line 122) | def from_civitai(self, state_dict):
  class SD3TextEncoder2StateDictConverter (line 337) | class SD3TextEncoder2StateDictConverter(SDXLTextEncoder2StateDictConvert...
    method __init__ (line 338) | def __init__(self):
    method from_diffusers (line 341) | def from_diffusers(self, state_dict):
    method from_civitai (line 344) | def from_civitai(self, state_dict):
  class SD3TextEncoder3StateDictConverter (line 880) | class SD3TextEncoder3StateDictConverter():
    method __init__ (line 881) | def __init__(self):
    method from_diffusers (line 884) | def from_diffusers(self, state_dict):
    method from_civitai (line 888) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sd3_vae_decoder.py
  class SD3VAEDecoder (line 8) | class SD3VAEDecoder(torch.nn.Module):
    method __init__ (line 9) | def __init__(self):
    method tiled_forward (line 45) | def tiled_forward(self, sample, tile_size=64, tile_stride=32):
    method forward (line 56) | def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, *...
    method state_dict_converter (line 80) | def state_dict_converter():

FILE: diffsynth/models/sd3_vae_encoder.py
  class SD3VAEEncoder (line 8) | class SD3VAEEncoder(torch.nn.Module):
    method __init__ (line 9) | def __init__(self):
    method tiled_forward (line 41) | def tiled_forward(self, sample, tile_size=64, tile_stride=32):
    method forward (line 52) | def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, *...
    method encode_video (line 76) | def encode_video(self, sample, batch_size=8):
    method state_dict_converter (line 94) | def state_dict_converter():

FILE: diffsynth/models/sd_controlnet.py
  class ControlNetConditioningLayer (line 6) | class ControlNetConditioningLayer(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
    method forward (line 19) | def forward(self, conditioning):
  class SDControlNet (line 25) | class SDControlNet(torch.nn.Module):
    method __init__ (line 26) | def __init__(self, global_pool=False):
    method forward (line 96) | def forward(
    method state_dict_converter (line 139) | def state_dict_converter():
  class SDControlNetStateDictConverter (line 143) | class SDControlNetStateDictConverter:
    method __init__ (line 144) | def __init__(self):
    method from_diffusers (line 147) | def from_diffusers(self, state_dict):
    method from_civitai (line 236) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sd_ipadapter.py
  class IpAdapterCLIPImageEmbedder (line 7) | class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
    method __init__ (line 8) | def __init__(self):
    method forward (line 12) | def forward(self, image):
  class SDIpAdapter (line 18) | class SDIpAdapter(torch.nn.Module):
    method __init__ (line 19) | def __init__(self):
    method set_full_adapter (line 26) | def set_full_adapter(self):
    method set_less_adapter (line 30) | def set_less_adapter(self):
    method forward (line 34) | def forward(self, hidden_states, scale=1.0):
    method state_dict_converter (line 51) | def state_dict_converter():
  class SDIpAdapterStateDictConverter (line 55) | class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
    method __init__ (line 56) | def __init__(self):

FILE: diffsynth/models/sd_motion.py
  class TemporalTransformerBlock (line 6) | class TemporalTransformerBlock(torch.nn.Module):
    method __init__ (line 8) | def __init__(self, dim, num_attention_heads, attention_head_dim, max_p...
    method forward (line 27) | def forward(self, hidden_states, batch_size=1):
  class TemporalBlock (line 52) | class TemporalBlock(torch.nn.Module):
    method __init__ (line 54) | def __init__(self, num_attention_heads, attention_head_dim, in_channel...
    method forward (line 72) | def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_...
  class SDMotionModel (line 94) | class SDMotionModel(torch.nn.Module):
    method __init__ (line 95) | def __init__(self):
    method forward (line 144) | def forward(self):
    method state_dict_converter (line 148) | def state_dict_converter():
  class SDMotionModelStateDictConverter (line 152) | class SDMotionModelStateDictConverter:
    method __init__ (line 153) | def __init__(self):
    method from_diffusers (line 156) | def from_diffusers(self, state_dict):
    method from_civitai (line 198) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sd_text_encoder.py
  class CLIPEncoderLayer (line 5) | class CLIPEncoderLayer(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, embed_dim, intermediate_size, num_heads=12, head_di...
    method quickGELU (line 16) | def quickGELU(self, x):
    method forward (line 19) | def forward(self, hidden_states, attn_mask=None):
  class SDTextEncoder (line 39) | class SDTextEncoder(torch.nn.Module):
    method __init__ (line 40) | def __init__(self, embed_dim=768, vocab_size=49408, max_position_embed...
    method attention_mask (line 58) | def attention_mask(self, length):
    method forward (line 64) | def forward(self, input_ids, clip_skip=1):
    method state_dict_converter (line 75) | def state_dict_converter():
  class SDTextEncoderStateDictConverter (line 79) | class SDTextEncoderStateDictConverter:
    method __init__ (line 80) | def __init__(self):
    method from_diffusers (line 83) | def from_diffusers(self, state_dict):
    method from_civitai (line 115) | def from_civitai(self, state_dict):
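
Note: `SDTextEncoder.attention_mask` builds a CLIP-style causal mask so each token attends only to itself and earlier tokens. A minimal pure-Python sketch of the pattern (the repo builds this as a torch tensor filled with `-inf` above the diagonal):

```python
def causal_mask(length, neg_inf=float("-inf")):
    # Additive attention mask: 0 where attention is allowed
    # (j <= i), -inf where token i would peek at a future token j.
    return [[0.0 if j <= i else neg_inf for j in range(length)]
            for i in range(length)]
```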

FILE: diffsynth/models/sd_unet.py
  class Timesteps (line 6) | class Timesteps(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, num_channels):
    method forward (line 11) | def forward(self, timesteps):
  class GEGLU (line 20) | class GEGLU(torch.nn.Module):
    method __init__ (line 22) | def __init__(self, dim_in, dim_out):
    method forward (line 26) | def forward(self, hidden_states):
  class BasicTransformerBlock (line 31) | class BasicTransformerBlock(torch.nn.Module):
    method __init__ (line 33) | def __init__(self, dim, num_attention_heads, attention_head_dim, cross...
    method forward (line 50) | def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwar...
  class DownSampler (line 70) | class DownSampler(torch.nn.Module):
    method __init__ (line 71) | def __init__(self, channels, padding=1, extra_padding=False):
    method forward (line 76) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class UpSampler (line 83) | class UpSampler(torch.nn.Module):
    method __init__ (line 84) | def __init__(self, channels):
    method forward (line 88) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class ResnetBlock (line 94) | class ResnetBlock(torch.nn.Module):
    method __init__ (line 95) | def __init__(self, in_channels, out_channels, temb_channels=None, grou...
    method forward (line 108) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class AttentionBlock (line 126) | class AttentionBlock(torch.nn.Module):
    method __init__ (line 128) | def __init__(self, num_attention_heads, attention_head_dim, in_channel...
    method forward (line 148) | def forward(
  class PushBlock (line 211) | class PushBlock(torch.nn.Module):
    method __init__ (line 212) | def __init__(self):
    method forward (line 215) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class PopBlock (line 220) | class PopBlock(torch.nn.Module):
    method __init__ (line 221) | def __init__(self):
    method forward (line 224) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class SDUNet (line 230) | class SDUNet(torch.nn.Module):
    method __init__ (line 231) | def __init__(self):
    method forward (line 324) | def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
    method state_dict_converter (line 346) | def state_dict_converter():
  class SDUNetStateDictConverter (line 350) | class SDUNetStateDictConverter:
    method __init__ (line 351) | def __init__(self):
    method from_diffusers (line 354) | def from_diffusers(self, state_dict):
    method from_civitai (line 412) | def from_civitai(self, state_dict):
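
Note: `GEGLU` is the gated-GELU feed-forward activation used in `BasicTransformerBlock` — the input projection produces `2 * dim_out` channels, which are split into a value half and a gate half, and the output is `value * gelu(gate)`. A scalar sketch of the gating (pure Python; the repo uses a `torch.nn.Linear` followed by `chunk`):

```python
import math

def gelu(x):
    # Exact GELU via the Gaussian error function.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(value, gate):
    # Gated GELU: the gate half modulates the value half.
    return value * gelu(gate)
```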

FILE: diffsynth/models/sd_vae_decoder.py
  class VAEAttentionBlock (line 7) | class VAEAttentionBlock(torch.nn.Module):
    method __init__ (line 9) | def __init__(self, num_attention_heads, attention_head_dim, in_channel...
    method forward (line 27) | def forward(self, hidden_states, time_emb, text_emb, res_stack):
  class SDVAEDecoder (line 44) | class SDVAEDecoder(torch.nn.Module):
    method __init__ (line 45) | def __init__(self):
    method tiled_forward (line 81) | def tiled_forward(self, sample, tile_size=64, tile_stride=32):
    method forward (line 92) | def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, *...
    method state_dict_converter (line 120) | def state_dict_converter():
  class SDVAEDecoderStateDictConverter (line 124) | class SDVAEDecoderStateDictConverter:
    method __init__ (line 125) | def __init__(self):
    method from_diffusers (line 128) | def from_diffusers(self, state_dict):
    method from_civitai (line 186) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sd_vae_encoder.py
  class SDVAEEncoder (line 8) | class SDVAEEncoder(torch.nn.Module):
    method __init__ (line 9) | def __init__(self):
    method tiled_forward (line 41) | def tiled_forward(self, sample, tile_size=64, tile_stride=32):
    method forward (line 52) | def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, *...
    method encode_video (line 80) | def encode_video(self, sample, batch_size=8):
    method state_dict_converter (line 98) | def state_dict_converter():
  class SDVAEEncoderStateDictConverter (line 102) | class SDVAEEncoderStateDictConverter:
    method __init__ (line 103) | def __init__(self):
    method from_diffusers (line 106) | def from_diffusers(self, state_dict):
    method from_civitai (line 164) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sdxl_controlnet.py
  class QuickGELU (line 10) | class QuickGELU(torch.nn.Module):
    method forward (line 12) | def forward(self, x: torch.Tensor):
  class ResidualAttentionBlock (line 17) | class ResidualAttentionBlock(torch.nn.Module):
    method __init__ (line 19) | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor ...
    method attention (line 32) | def attention(self, x: torch.Tensor):
    method forward (line 36) | def forward(self, x: torch.Tensor):
  class SDXLControlNetUnion (line 43) | class SDXLControlNetUnion(torch.nn.Module):
    method __init__ (line 44) | def __init__(self, global_pool=False):
    method fuse_condition_to_input (line 137) | def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
    method forward (line 151) | def forward(
    method state_dict_converter (line 217) | def state_dict_converter():
  class SDXLControlNetUnionStateDictConverter (line 222) | class SDXLControlNetUnionStateDictConverter:
    method __init__ (line 223) | def __init__(self):
    method from_diffusers (line 226) | def from_diffusers(self, state_dict):
    method from_civitai (line 317) | def from_civitai(self, state_dict):
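
Note: `QuickGELU` here is the sigmoid-based GELU approximation used in CLIP-style `ResidualAttentionBlock`s: `x * sigmoid(1.702 * x)`. A one-line pure-Python sketch:

```python
import math

def quick_gelu(x):
    # CLIP's fast GELU approximation: x * sigmoid(1.702 * x).
    return x * (1.0 / (1.0 + math.exp(-1.702 * x)))
```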

FILE: diffsynth/models/sdxl_ipadapter.py
  class IpAdapterXLCLIPImageEmbedder (line 6) | class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
    method __init__ (line 7) | def __init__(self):
    method forward (line 11) | def forward(self, image):
  class IpAdapterImageProjModel (line 17) | class IpAdapterImageProjModel(torch.nn.Module):
    method __init__ (line 18) | def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280,...
    method forward (line 25) | def forward(self, image_embeds):
  class IpAdapterModule (line 31) | class IpAdapterModule(torch.nn.Module):
    method __init__ (line 32) | def __init__(self, input_dim, output_dim):
    method forward (line 37) | def forward(self, hidden_states):
  class SDXLIpAdapter (line 43) | class SDXLIpAdapter(torch.nn.Module):
    method __init__ (line 44) | def __init__(self):
    method set_full_adapter (line 51) | def set_full_adapter(self):
    method set_less_adapter (line 67) | def set_less_adapter(self):
    method forward (line 83) | def forward(self, hidden_states, scale=1.0):
    method state_dict_converter (line 100) | def state_dict_converter():
  class SDXLIpAdapterStateDictConverter (line 104) | class SDXLIpAdapterStateDictConverter:
    method __init__ (line 105) | def __init__(self):
    method from_diffusers (line 108) | def from_diffusers(self, state_dict):
    method from_civitai (line 120) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sdxl_motion.py
  class SDXLMotionModel (line 6) | class SDXLMotionModel(torch.nn.Module):
    method __init__ (line 7) | def __init__(self):
    method forward (line 49) | def forward(self):
    method state_dict_converter (line 53) | def state_dict_converter():
  class SDMotionModelStateDictConverter (line 57) | class SDMotionModelStateDictConverter:
    method __init__ (line 58) | def __init__(self):
    method from_diffusers (line 61) | def from_diffusers(self, state_dict):
    method from_civitai (line 103) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sdxl_text_encoder.py
  class SDXLTextEncoder (line 5) | class SDXLTextEncoder(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, embed_dim=768, vocab_size=49408, max_position_embed...
    method attention_mask (line 24) | def attention_mask(self, length):
    method forward (line 30) | def forward(self, input_ids, clip_skip=1):
    method state_dict_converter (line 40) | def state_dict_converter():
  class SDXLTextEncoder2 (line 44) | class SDXLTextEncoder2(torch.nn.Module):
    method __init__ (line 45) | def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embe...
    method attention_mask (line 66) | def attention_mask(self, length):
    method forward (line 72) | def forward(self, input_ids, clip_skip=2):
    method state_dict_converter (line 85) | def state_dict_converter():
  class SDXLTextEncoderStateDictConverter (line 89) | class SDXLTextEncoderStateDictConverter:
    method __init__ (line 90) | def __init__(self):
    method from_diffusers (line 93) | def from_diffusers(self, state_dict):
    method from_civitai (line 125) | def from_civitai(self, state_dict):
  class SDXLTextEncoder2StateDictConverter (line 316) | class SDXLTextEncoder2StateDictConverter:
    method __init__ (line 317) | def __init__(self):
    method from_diffusers (line 320) | def from_diffusers(self, state_dict):
    method from_civitai (line 353) | def from_civitai(self, state_dict):

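The `attention_mask(length)` methods listed for both SDXL text encoders build the causal mask used by CLIP-style transformers. As an illustrative sketch (not the repo's exact code), the standard construction fills everything above the diagonal with `-inf` so each token attends only to itself and earlier tokens:

```python
import torch

def causal_attention_mask(length):
    # -inf above the diagonal: token i may attend only to tokens 0..i.
    # Added to attention logits before softmax, these entries become 0 weight.
    mask = torch.full((length, length), float("-inf"))
    return torch.triu(mask, diagonal=1)
```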
FILE: diffsynth/models/sdxl_unet.py
  class SDXLUNet (line 5) | class SDXLUNet(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, is_kolors=False):
    method forward (line 88) | def forward(
    method state_dict_converter (line 139) | def state_dict_converter():
  class SDXLUNetStateDictConverter (line 143) | class SDXLUNetStateDictConverter:
    method __init__ (line 144) | def __init__(self):
    method from_diffusers (line 147) | def from_diffusers(self, state_dict):
    method from_civitai (line 208) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sdxl_vae_decoder.py
  class SDXLVAEDecoder (line 4) | class SDXLVAEDecoder(SDVAEDecoder):
    method __init__ (line 5) | def __init__(self, upcast_to_float32=True):
    method state_dict_converter (line 10) | def state_dict_converter():
  class SDXLVAEDecoderStateDictConverter (line 14) | class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    method __init__ (line 15) | def __init__(self):
    method from_diffusers (line 18) | def from_diffusers(self, state_dict):
    method from_civitai (line 22) | def from_civitai(self, state_dict):

FILE: diffsynth/models/sdxl_vae_encoder.py
  class SDXLVAEEncoder (line 4) | class SDXLVAEEncoder(SDVAEEncoder):
    method __init__ (line 5) | def __init__(self, upcast_to_float32=True):
    method state_dict_converter (line 10) | def state_dict_converter():
  class SDXLVAEEncoderStateDictConverter (line 14) | class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    method __init__ (line 15) | def __init__(self):
    method from_diffusers (line 18) | def from_diffusers(self, state_dict):
    method from_civitai (line 22) | def from_civitai(self, state_dict):

FILE: diffsynth/models/stepvideo_dit.py
  class RMSNorm (line 20) | class RMSNorm(nn.Module):
    method __init__ (line 21) | def __init__(
    method _norm (line 47) | def _norm(self, x):
    method forward (line 60) | def forward(self, x):
  function get_activation (line 86) | def get_activation(act_fn: str) -> nn.Module:
  function get_timestep_embedding (line 103) | def get_timestep_embedding(
  class Timesteps (line 146) | class Timesteps(nn.Module):
    method __init__ (line 147) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
    method forward (line 153) | def forward(self, timesteps):
  class TimestepEmbedding (line 163) | class TimestepEmbedding(nn.Module):
    method __init__ (line 164) | def __init__(
    method forward (line 210) | def forward(self, sample, condition=None):
  class PixArtAlphaCombinedTimestepSizeEmbeddings (line 225) | class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    method __init__ (line 226) | def __init__(self, embedding_dim, size_emb_dim, use_additional_conditi...
    method forward (line 240) | def forward(self, timestep, resolution=None, nframe=None, fps=None):
  class AdaLayerNormSingle (line 264) | class AdaLayerNormSingle(nn.Module):
    method __init__ (line 274) | def __init__(self, embedding_dim: int, use_additional_conditions: bool...
    method forward (line 286) | def forward(
  class PixArtAlphaTextProjection (line 298) | class PixArtAlphaTextProjection(nn.Module):
    method __init__ (line 305) | def __init__(self, in_features, hidden_size):
    method forward (line 319) | def forward(self, caption):
  class Attention (line 326) | class Attention(nn.Module):
    method __init__ (line 327) | def __init__(self):
    method attn_processor (line 330) | def attn_processor(self, attn_type):
    method torch_attn_func (line 338) | def torch_attn_func(
  class RoPE1D (line 366) | class RoPE1D:
    method __init__ (line 367) | def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
    method get_cos_sin (line 373) | def get_cos_sin(self, D, seq_len, device, dtype):
    method rotate_half (line 385) | def rotate_half(x):
    method apply_rope1d (line 389) | def apply_rope1d(self, tokens, pos1d, cos, sin):
    method __call__ (line 395) | def __call__(self, tokens, positions):
  class RoPE3D (line 410) | class RoPE3D(RoPE1D):
    method __init__ (line 411) | def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
    method get_mesh_3d (line 415) | def get_mesh_3d(self, rope_positions, bsz):
    method __call__ (line 425) | def __call__(self, tokens, rope_positions, ch_split, parallel=False):
  class SelfAttention (line 451) | class SelfAttention(Attention):
    method __init__ (line 452) | def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, w...
    method apply_rope3d (line 473) | def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
    method forward (line 477) | def forward(
  class CrossAttention (line 512) | class CrossAttention(Attention):
    method __init__ (line 513) | def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True...
    method forward (line 529) | def forward(
  class GELU (line 560) | class GELU(nn.Module):
    method __init__ (line 571) | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none...
    method gelu (line 576) | def gelu(self, gate: torch.Tensor) -> torch.Tensor:
    method forward (line 579) | def forward(self, hidden_states):
  class FeedForward (line 585) | class FeedForward(nn.Module):
    method __init__ (line 586) | def __init__(
    method forward (line 604) | def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> tor...
  function modulate (line 610) | def modulate(x, scale, shift):
  function gate (line 615) | def gate(x, gate):
  class StepVideoTransformerBlock (line 620) | class StepVideoTransformerBlock(nn.Module):
    method __init__ (line 655) | def __init__(
    method forward (line 677) | def forward(
  class PatchEmbed (line 715) | class PatchEmbed(nn.Module):
    method __init__ (line 718) | def __init__(
    method forward (line 736) | def forward(self, latent):
  class StepVideoModel (line 746) | class StepVideoModel(torch.nn.Module):
    method __init__ (line 747) | def __init__(
    method patchfy (line 812) | def patchfy(self, hidden_states):
    method prepare_attn_mask (line 817) | def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_sta...
    method block_forward (line 826) | def block_forward(
    method forward (line 848) | def forward(
    method state_dict_converter (line 925) | def state_dict_converter():
  class StepVideoDiTStateDictConverter (line 929) | class StepVideoDiTStateDictConverter:
    method __init__ (line 930) | def __init__(self):
    method from_diffusers (line 933) | def from_diffusers(self, state_dict):
    method from_civitai (line 936) | def from_civitai(self, state_dict):

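The `RoPE1D` class above (with `get_cos_sin`, `rotate_half`, `apply_rope1d`) implements rotary position embeddings; `RoPE3D` extends the same idea over frame/height/width axes. A minimal self-contained sketch of the 1D case, with the helper signature simplified relative to the repo's version:

```python
import torch

def rotate_half(x):
    # Split the last dim in two and rotate each pair: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope1d(tokens, positions, dim, freq=1e4):
    # tokens: (..., seq, dim); positions: (seq,) integer positions.
    # Each channel pair is rotated by position * inv_freq, a per-pair angle.
    inv_freq = 1.0 / (freq ** (torch.arange(0, dim, 2).float() / dim))
    angles = positions.float()[:, None] * inv_freq[None, :]    # (seq, dim/2)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)      # (seq, dim)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return tokens * cos + rotate_half(tokens) * sin
```

Because each pair is a 2D rotation, the embedding preserves token norms and makes query-key dot products depend only on relative position.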
FILE: diffsynth/models/stepvideo_text_encoder.py
  class EmptyInitOnDevice (line 30) | class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    method __init__ (line 31) | def __init__(self, device=None):
    method __torch_function__ (line 34) | def __torch_function__(self, func, types, args=(), kwargs=None):
  function with_empty_init (line 46) | def with_empty_init(func):
  class LLaMaEmbedding (line 55) | class LLaMaEmbedding(nn.Module):
    method __init__ (line 69) | def __init__(self,
    method forward (line 82) | def forward(self, input_ids):
  class StepChatTokenizer (line 105) | class StepChatTokenizer:
    method __init__ (line 108) | def __init__(
    method vocab (line 166) | def vocab(self):
    method inv_vocab (line 170) | def inv_vocab(self):
    method vocab_size (line 174) | def vocab_size(self):
    method tokenize (line 177) | def tokenize(self, text: str) -> List[int]:
    method detokenize (line 180) | def detokenize(self, token_ids: List[int]) -> str:
  class Tokens (line 184) | class Tokens:
    method __init__ (line 185) | def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens...
    method to (line 191) | def to(self, device):
  class Wrapped_StepChatTokenizer (line 198) | class Wrapped_StepChatTokenizer(StepChatTokenizer):
    method __call__ (line 199) | def __call__(self, text, max_length=320, padding="max_length", truncat...
  function flash_attn_func (line 245) | def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=T...
  class FlashSelfAttention (line 256) | class FlashSelfAttention(torch.nn.Module):
    method __init__ (line 257) | def __init__(
    method forward (line 265) | def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
  function safediv (line 275) | def safediv(n, d):
  class MultiQueryAttention (line 281) | class MultiQueryAttention(nn.Module):
    method __init__ (line 282) | def __init__(self, cfg, layer_id=None):
    method forward (line 311) | def forward(
  class FeedForward (line 373) | class FeedForward(nn.Module):
    method __init__ (line 374) | def __init__(
    method forward (line 401) | def forward(self, x):
  class TransformerBlock (line 408) | class TransformerBlock(nn.Module):
    method __init__ (line 409) | def __init__(
    method forward (line 438) | def forward(
  class Transformer (line 455) | class Transformer(nn.Module):
    method __init__ (line 456) | def __init__(
    method _build_layers (line 465) | def _build_layers(self, config):
    method forward (line 476) | def forward(
  class Step1Model (line 497) | class Step1Model(PreTrainedModel):
    method __init__ (line 500) | def __init__(
    method forward (line 508) | def forward(
  class STEP1TextEncoder (line 524) | class STEP1TextEncoder(torch.nn.Module):
    method __init__ (line 525) | def __init__(self, model_dir, max_length=320):
    method from_pretrained (line 533) | def from_pretrained(path, torch_dtype=torch.bfloat16):
    method forward (line 538) | def forward(self, prompts, with_mask=True, max_length=None, device="cu...

FILE: diffsynth/models/stepvideo_vae.py
  class BaseGroupNorm (line 21) | class BaseGroupNorm(nn.GroupNorm):
    method __init__ (line 22) | def __init__(self, num_groups, num_channels):
    method forward (line 25) | def forward(self, x, zero_pad=False, **kwargs):
  function base_group_norm (line 32) | def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
  function base_conv2d (line 62) | def base_conv2d(x, conv_layer, channel_last=False, residual=None):
  function base_conv3d (line 74) | def base_conv3d(x, conv_layer, channel_last=False, residual=None, only_r...
  function cal_outsize (line 90) | def cal_outsize(input_sizes, kernel_sizes, stride, padding):
  function calc_out_ (line 115) | def calc_out_(in_size, padding, dilation, kernel, stride):
  function base_conv3d_channel_last (line 120) | def base_conv3d_channel_last(x, conv_layer, residual=None):
  class Upsample2D (line 160) | class Upsample2D(nn.Module):
    method __init__ (line 161) | def __init__(self,
    method forward (line 178) | def forward(self, x, output_size=None):
  class Downsample2D (line 198) | class Downsample2D(nn.Module):
    method __init__ (line 199) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
    method forward (line 213) | def forward(self, x):
  class CausalConv (line 226) | class CausalConv(nn.Module):
    method __init__ (line 227) | def __init__(self,
    method forward (line 252) | def forward(self, x, is_init=True, residual=None):
  class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D (line 262) | class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module):
    method __init__ (line 263) | def __init__(
    method forward (line 276) | def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
  class ConvPixelShuffleUpSampleLayer3D (line 284) | class ConvPixelShuffleUpSampleLayer3D(nn.Module):
    method __init__ (line 285) | def __init__(
    method forward (line 301) | def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
    method pixel_shuffle_3d (line 307) | def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
  class ConvPixelUnshuffleDownSampleLayer3D (line 320) | class ConvPixelUnshuffleDownSampleLayer3D(nn.Module):
    method __init__ (line 321) | def __init__(
    method forward (line 338) | def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
    method pixel_unshuffle_3d (line 344) | def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
  class PixelUnshuffleChannelAveragingDownSampleLayer3D (line 353) | class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):
    method __init__ (line 354) | def __init__(
    method forward (line 367) | def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
    method __init__ (line 378) | def __init__(
    method forward (line 391) | def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
  function base_group_norm_with_zero_pad (line 405) | def base_group_norm_with_zero_pad(x, norm_layer, act_silu=True, pad_size...
  class CausalConvChannelLast (line 414) | class CausalConvChannelLast(CausalConv):
    method __init__ (line 415) | def __init__(self,
    method forward (line 427) | def forward(self, x, is_init=True, residual=None):
  class CausalConvAfterNorm (line 438) | class CausalConvAfterNorm(CausalConv):
    method __init__ (line 439) | def __init__(self,
    method forward (line 454) | def forward(self, x, is_init=True, residual=None):
  class AttnBlock (line 466) | class AttnBlock(nn.Module):
    method __init__ (line 467) | def __init__(self,
    method attention (line 478) | def attention(self, x, is_init=True):
    method forward (line 491) | def forward(self, x):
  class Resnet3DBlock (line 498) | class Resnet3DBlock(nn.Module):
    method __init__ (line 499) | def __init__(self,
    method forward (line 527) | def forward(self, x, temb=None, is_init=True):
  class Downsample3D (line 544) | class Downsample3D(nn.Module):
    method __init__ (line 545) | def __init__(self,
    method forward (line 556) | def forward(self, x, is_init=True):
  class VideoEncoder (line 563) | class VideoEncoder(nn.Module):
    method __init__ (line 564) | def __init__(self,
    method forward (line 626) | def forward(self, x, video_frame_num, is_init=True):
  class Res3DBlockUpsample (line 676) | class Res3DBlockUpsample(nn.Module):
    method __init__ (line 677) | def __init__(self,
    method forward (line 706) | def forward(self, x, is_init=False):
  class Upsample3D (line 729) | class Upsample3D(nn.Module):
    method __init__ (line 730) | def __init__(self,
    method forward (line 742) | def forward(self, x, is_init=True, is_split=True):
  class VideoDecoder (line 757) | class VideoDecoder(nn.Module):
    method __init__ (line 758) | def __init__(self,
    method forward (line 823) | def forward(self, z, is_init=True):
  function rms_norm (line 871) | def rms_norm(input, normalized_shape, eps=1e-6):
  class DiagonalGaussianDistribution (line 878) | class DiagonalGaussianDistribution(object):
    method __init__ (line 879) | def __init__(self, parameters, deterministic=False, rms_norm_mean=Fals...
    method sample (line 895) | def sample(self, generator=None):
  class StepVideoVAE (line 908) | class StepVideoVAE(nn.Module):
    method __init__ (line 909) | def __init__(self,
    method init_from_ckpt (line 948) | def init_from_ckpt(self, model_path):
    method load_from_dict (line 959) | def load_from_dict(self, p):
    method convert_channel_last (line 962) | def convert_channel_last(self):
    method naive_encode (line 966) | def naive_encode(self, x, is_init_image=True):
    method encode (line 973) | def encode(self, x):
    method decode_naive (line 983) | def decode_naive(self, z, is_init=True):
    method decode_original (line 989) | def decode_original(self, z):
    method mix (line 1014) | def mix(self, x, smooth_scale = 0.6):
    method single_decode (line 1025) | def single_decode(self, hidden_states, device):
    method build_1d_mask (line 1032) | def build_1d_mask(self, length, left_bound, right_bound, border_width):
    method build_mask (line 1040) | def build_mask(self, data, is_bound, border_width):
    method tiled_decode (line 1052) | def tiled_decode(self, hidden_states, device, tile_size=(34, 34), tile...
    method decode (line 1103) | def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34...
    method state_dict_converter (line 1113) | def state_dict_converter():
  class StepVideoVAEStateDictConverter (line 1117) | class StepVideoVAEStateDictConverter:
    method __init__ (line 1118) | def __init__(self):
    method from_diffusers (line 1121) | def from_diffusers(self, state_dict):
    method from_civitai (line 1124) | def from_civitai(self, state_dict):

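The `DiagonalGaussianDistribution` class listed above is the standard VAE latent distribution: the encoder emits mean and log-variance concatenated on the channel axis, and `sample` draws via the reparameterization trick. A stripped-down sketch (the repo's version also carries `deterministic` and `rms_norm_mean` options not shown here):

```python
import torch

class DiagonalGaussianDistribution:
    def __init__(self, parameters):
        # Encoder output: [mean | logvar] concatenated along channels.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self, generator=None):
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I).
        noise = torch.randn(self.mean.shape, generator=generator)
        return self.mean + self.std * noise
```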
FILE: diffsynth/models/svd_image_encoder.py
  class CLIPVisionEmbeddings (line 5) | class CLIPVisionEmbeddings(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_...
    method forward (line 18) | def forward(self, pixel_values):
  class SVDImageEncoder (line 27) | class SVDImageEncoder(torch.nn.Module):
    method __init__ (line 28) | def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_la...
    method forward (line 38) | def forward(self, pixel_values):
    method state_dict_converter (line 48) | def state_dict_converter():
  class SVDImageEncoderStateDictConverter (line 52) | class SVDImageEncoderStateDictConverter:
    method __init__ (line 53) | def __init__(self):
    method from_diffusers (line 56) | def from_diffusers(self, state_dict):
    method from_civitai (line 94) | def from_civitai(self, state_dict):

FILE: diffsynth/models/svd_unet.py
  class TemporalResnetBlock (line 6) | class TemporalResnetBlock(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, in_channels, out_channels, temb_channels=None, grou...
    method forward (line 20) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  function get_timestep_embedding (line 40) | def get_timestep_embedding(
  class TemporalTimesteps (line 84) | class TemporalTimesteps(torch.nn.Module):
    method __init__ (line 85) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
    method forward (line 92) | def forward(self, timesteps):
  class TrainableTemporalTimesteps (line 103) | class TrainableTemporalTimesteps(torch.nn.Module):
    method __init__ (line 104) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
    method forward (line 110) | def forward(self, timesteps):
  class PositionalID (line 115) | class PositionalID(torch.nn.Module):
    method __init__ (line 116) | def __init__(self, max_id=25, repeat_length=20):
    method frame_id_to_position_id (line 121) | def frame_id_to_position_id(self, frame_id):
    method forward (line 132) | def forward(self, num_frames, pivot_frame_id=0):
  class TemporalAttentionBlock (line 138) | class TemporalAttentionBlock(torch.nn.Module):
    method __init__ (line 140) | def __init__(self, num_attention_heads, attention_head_dim, in_channel...
    method forward (line 180) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class PopMixBlock (line 217) | class PopMixBlock(torch.nn.Module):
    method __init__ (line 218) | def __init__(self, in_channels=None):
    method forward (line 225) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class SVDUNet (line 238) | class SVDUNet(torch.nn.Module):
    method __init__ (line 239) | def __init__(self, add_positional_conv=None):
    method build_mask (line 316) | def build_mask(self, data, is_bound):
    method tiled_forward (line 337) | def tiled_forward(
    method forward (line 373) | def forward(self, sample, timestep, encoder_hidden_states, add_time_id...
    method state_dict_converter (line 414) | def state_dict_converter():
  class SVDUNetStateDictConverter (line 419) | class SVDUNetStateDictConverter:
    method __init__ (line 420) | def __init__(self):
    method get_block_name (line 423) | def get_block_name(self, names):
    method from_diffusers (line 435) | def from_diffusers(self, state_dict):
    method from_civitai (line 555) | def from_civitai(self, state_dict, add_positional_conv=None):

FILE: diffsynth/models/svd_vae_decoder.py
  class VAEAttentionBlock (line 8) | class VAEAttentionBlock(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, num_attention_heads, attention_head_dim, in_channel...
    method forward (line 28) | def forward(self, hidden_states, time_emb, text_emb, res_stack):
  class TemporalResnetBlock (line 45) | class TemporalResnetBlock(torch.nn.Module):
    method __init__ (line 47) | def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
    method forward (line 56) | def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwar...
  class SVDVAEDecoder (line 71) | class SVDVAEDecoder(torch.nn.Module):
    method __init__ (line 72) | def __init__(self):
    method forward (line 123) | def forward(self, sample):
    method build_mask (line 144) | def build_mask(self, data, is_bound):
    method decode_video (line 165) | def decode_video(
    method state_dict_converter (line 203) | def state_dict_converter():
  class SVDVAEDecoderStateDictConverter (line 207) | class SVDVAEDecoderStateDictConverter:
    method __init__ (line 208) | def __init__(self):
    method from_diffusers (line 211) | def from_diffusers(self, state_dict):
    method from_civitai (line 302) | def from_civitai(self, state_dict):

FILE: diffsynth/models/svd_vae_encoder.py
  class SVDVAEEncoder (line 4) | class SVDVAEEncoder(SDVAEEncoder):
    method __init__ (line 5) | def __init__(self):
    method state_dict_converter (line 10) | def state_dict_converter():
  class SVDVAEEncoderStateDictConverter (line 14) | class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    method __init__ (line 15) | def __init__(self):
    method from_diffusers (line 18) | def from_diffusers(self, state_dict):
    method from_civitai (line 21) | def from_civitai(self, state_dict):

FILE: diffsynth/models/tiler.py
  class TileWorker (line 5) | class TileWorker:
    method __init__ (line 6) | def __init__(self):
    method mask (line 10) | def mask(self, height, width, border_width):
    method tile (line 20) | def tile(self, model_input, tile_size, tile_stride, tile_device, tile_...
    method tiled_inference (line 34) | def tiled_inference(self, forward_fn, model_input, tile_batch_size, in...
    method io_scale (line 57) | def io_scale(self, model_output, tile_size):
    method untile (line 64) | def untile(self, model_output, height, width, tile_size, tile_stride, ...
    method tiled_forward (line 83) | def tiled_forward(self, forward_fn, model_input, tile_size, tile_strid...
  class FastTileWorker (line 110) | class FastTileWorker:
    method __init__ (line 111) | def __init__(self):
    method build_mask (line 115) | def build_mask(self, data, is_bound):
    method tiled_forward (line 133) | def tiled_forward(self, forward_fn, model_input, tile_size, tile_strid...
  class TileWorker2Dto3D (line 164) | class TileWorker2Dto3D:
    method __init__ (line 168) | def __init__(self):
    method build_mask (line 172) | def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
    method tiled_forward (line 192) | def tiled_forward(

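The tile workers above share one core idea: run the model on overlapping tiles, weight each tile with a mask that ramps down near interior borders, and normalize by the accumulated weights so overlaps cross-fade without seams. A 1D sketch of that blending scheme (mask shape and border handling are simplified assumptions, not the repo's exact `build_mask`):

```python
import torch

def build_mask_1d(length, left_bound, right_bound, border_width):
    # Weight is 1 in the interior and ramps down toward non-boundary edges.
    x = torch.arange(length).float()
    ones = torch.ones(length)
    left = ones if left_bound else (x + 1) / border_width
    right = ones if right_bound else (length - x) / border_width
    return torch.minimum(ones, torch.minimum(left, right))

def blend_tiles(tiles, positions, out_len, tile_len, border_width=4):
    # tiles: list of (tile_len,) outputs; positions: their start offsets.
    value = torch.zeros(out_len)
    weight = torch.zeros(out_len)
    for tile, start in zip(tiles, positions):
        mask = build_mask_1d(tile_len, start == 0,
                             start + tile_len == out_len, border_width)
        value[start:start + tile_len] += tile * mask
        weight[start:start + tile_len] += mask
    # Normalizing by the summed masks makes overlapping regions a weighted
    # average of the contributing tiles.
    return value / weight
```

For tiles that agree on their overlap, the normalization reproduces the input exactly; disagreements fade linearly across the border region.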
FILE: diffsynth/models/utils.py
  function init_weights_on_device (line 7) | def init_weights_on_device(device = torch.device("meta"), include_buffer...
  function load_state_dict_from_folder (line 55) | def load_state_dict_from_folder(file_path, torch_dtype=None):
  function load_state_dict (line 65) | def load_state_dict(file_path, torch_dtype=None):
  function load_state_dict_from_safetensors (line 72) | def load_state_dict_from_safetensors(file_path, torch_dtype=None):
  function load_state_dict_from_bin (line 82) | def load_state_dict_from_bin(file_path, torch_dtype=None):
  function search_for_embeddings (line 91) | def search_for_embeddings(state_dict):
  function search_parameter (line 101) | def search_parameter(param, state_dict):
  function build_rename_dict (line 113) | def build_rename_dict(source_state_dict, target_state_dict, split_qkv=Fa...
  function search_for_files (line 135) | def search_for_files(folder, extensions):
  function convert_state_dict_keys_to_single_str (line 148) | def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
  function split_state_dict_with_prefix (line 164) | def split_state_dict_with_prefix(state_dict):
  function hash_state_dict_keys (line 179) | def hash_state_dict_keys(state_dict, with_shape=True):

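`hash_state_dict_keys` and `convert_state_dict_keys_to_single_str` above implement checkpoint fingerprinting: the set of parameter names (optionally with shapes) is serialized and hashed, and the digest identifies which model architecture a loaded state dict belongs to. A sketch of the idea, where the separator, sorting, and choice of MD5 are assumptions rather than the repo's exact serialization:

```python
import hashlib
import torch

def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    # Serialize the key set (optionally with tensor shapes) into one string.
    parts = []
    for key, value in state_dict.items():
        if with_shape and isinstance(value, torch.Tensor):
            parts.append(f"{key}:{tuple(value.shape)}")
        else:
            parts.append(key)
    return ",".join(sorted(parts))

def hash_state_dict_keys(state_dict, with_shape=True):
    # The digest acts as a fingerprint for architecture detection: two
    # checkpoints of the same model hash identically regardless of weights.
    text = convert_state_dict_keys_to_single_str(state_dict, with_shape)
    return hashlib.md5(text.encode("utf-8")).hexdigest()
```

Sorting the parts makes the fingerprint independent of key insertion order, and including shapes distinguishes variants that share names but differ in layer sizes.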
FILE: diffsynth/models/wan_video_dit.py
  function flash_attention (line 28) | def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, n...
  function modulate (line 62) | def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
  function sinusoidal_embedding_1d (line 66) | def sinusoidal_embedding_1d(dim, position):
  function precompute_freqs_cis_3d (line 73) | def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10...
  function precompute_freqs_cis (line 81) | def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000...
  function rope_apply (line 90) | def rope_apply(x, freqs, num_heads):
  class RMSNorm (line 98) | class RMSNorm(nn.Module):
    method __init__ (line 99) | def __init__(self, dim, eps=1e-5):
    method norm (line 104) | def norm(self, x):
    method forward (line 107) | def forward(self, x):
  class AttentionModule (line 112) | class AttentionModule(nn.Module):
    method __init__ (line 113) | def __init__(self, num_heads):
    method forward (line 117) | def forward(self, q, k, v):
  class SelfAttention (line 122) | class SelfAttention(nn.Module):
    method __init__ (line 123) | def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
    method forward (line 138) | def forward(self, x, freqs):
  class CrossAttention (line 148) | class CrossAttention(nn.Module):
    method __init__ (line 149) | def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_im...
    method forward (line 169) | def forward(self, x: torch.Tensor, y: torch.Tensor):
  class DiTBlock (line 187) | class DiTBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, has_image_input: bool, dim: int, num_heads: int, ff...
    method forward (line 204) | def forward(self, x, context, cam_emb, t_mod, freqs, freqs_mvs):
  class MLP (line 243) | class MLP(torch.nn.Module):
    method __init__ (line 244) | def __init__(self, in_dim, out_dim):
    method forward (line 254) | def forward(self, x):
  class Head (line 258) | class Head(nn.Module):
    method __init__ (line 259) | def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int,...
    method forward (line 267) | def forward(self, x, t_mod):
  class WanModel (line 273) | class WanModel(torch.nn.Module):
    method __init__ (line 274) | def __init__(
    method patchify (line 319) | def patchify(self, x: torch.Tensor):
    method unpatchify (line 325) | def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
    method forward (line 332) | def forward(self,
    method state_dict_converter (line 396) | def state_dict_converter():
  class WanModelStateDictConverter (line 400) | class WanModelStateDictConverter:
    method __init__ (line 401) | def __init__(self):
    method from_diffusers (line 404) | def from_diffusers(self, state_dict):
    method from_civitai (line 481) | def from_civitai(self, state_dict):

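`sinusoidal_embedding_1d(dim, position)` above maps diffusion timesteps to the usual transformer-style sinusoidal features. An illustrative stand-alone version (the half-sin/half-cos layout and `theta` base are the conventional choice; the repo's exact ordering may differ):

```python
import torch

def sinusoidal_embedding_1d(dim, position, theta=10000.0):
    # position: (N,) tensor of scalar positions/timesteps -> (N, dim) features.
    # Frequencies form a geometric series; half the channels get sin, half cos.
    assert dim % 2 == 0
    half = dim // 2
    freqs = torch.pow(theta, -torch.arange(half).float() / half)
    angles = position.float()[:, None] * freqs[None, :]      # (N, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```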
FILE: diffsynth/models/wan_video_image_encoder.py
  class SelfAttention (line 14) | class SelfAttention(nn.Module):
    method __init__ (line 16) | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
    method forward (line 31) | def forward(self, x, mask):
  class AttentionBlock (line 53) | class AttentionBlock(nn.Module):
    method __init__ (line 55) | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
    method forward (line 70) | def forward(self, x, mask):
  class XLMRoberta (line 80) | class XLMRoberta(nn.Module):
    method __init__ (line 85) | def __init__(self,
    method forward (line 122) | def forward(self, ids):
  function xlm_roberta_large (line 150) | def xlm_roberta_large(pretrained=False,
  function pos_interpolate (line 203) | def pos_interpolate(pos, seq_len):
  class QuickGELU (line 222) | class QuickGELU(nn.Module):
    method forward (line 224) | def forward(self, x):
  class LayerNorm (line 228) | class LayerNorm(nn.LayerNorm):
    method forward (line 230) | def forward(self, x):
  class SelfAttention (line 234) | class SelfAttention(nn.Module):
    method __init__ (line 236) | def __init__(self,
    method forward (line 255) | def forward(self, x):
  class SwiGLU (line 271) | class SwiGLU(nn.Module):
    method __init__ (line 273) | def __init__(self, dim, mid_dim):
    method forward (line 283) | def forward(self, x):
  class AttentionBlock (line 289) | class AttentionBlock(nn.Module):
    method __init__ (line 291) | def __init__(self,
    method forward (line 323) | def forward(self, x):
  class AttentionPool (line 333) | class AttentionPool(nn.Module):
    method __init__ (line 335) | def __init__(self,
    method forward (line 363) | def forward(self, x):
  class VisionTransformer (line 386) | class VisionTransformer(nn.Module):
    method __init__ (line 388) | def __init__(self,
    method forward (line 456) | def forward(self, x, interpolation=False, use_31_block=False):
  class CLIP (line 481) | class CLIP(nn.Module):
    method __init__ (line 483) | def __init__(self,
    method forward (line 571) | def forward(self, imgs, txt_ids):
    method init_weights (line 582) | def init_weights(self):
    method param_groups (line 601) | def param_groups(self):
  class XLMRobertaWithHead (line 617) | class XLMRobertaWithHead(XLMRoberta):
    method __init__ (line 619) | def __init__(self, **kwargs):
    method forward (line 629) | def forward(self, ids):
  class XLMRobertaCLIP (line 642) | class XLMRobertaCLIP(nn.Module):
    method __init__ (line 644) | def __init__(self,
    method forward (line 710) | def forward(self, imgs, txt_ids):
    method param_groups (line 722) | def param_groups(self):
  function _clip (line 738) | def _clip(pretrained=False,
  function clip_xlm_roberta_vit_h_14 (line 822) | def clip_xlm_roberta_vit_h_14(
  class WanImageEncoder (line 852) | class WanImageEncoder(torch.nn.Module):
    method __init__ (line 854) | def __init__(self):
    method encode_image (line 864) | def encode_image(self, videos):
    method state_dict_converter (line 883) | def state_dict_converter():
  class WanImageEncoderStateDictConverter (line 887) | class WanImageEncoderStateDictConverter:
    method __init__ (line 888) | def __init__(self):
    method from_diffusers (line 891) | def from_diffusers(self, state_dict):
    method from_civitai (line 894) | def from_civitai(self, state_dict):
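The `WanImageEncoderStateDictConverter` above follows the converter pattern used throughout `diffsynth/models`: rename checkpoint keys from an external naming scheme (Diffusers, Civitai) to the library's internal one. A minimal sketch of the idea, with a purely hypothetical prefix mapping (the real mappings live in `from_diffusers`/`from_civitai`):

```python
def convert_state_dict(state_dict, prefix_map):
    """Rename checkpoint keys by longest-prefix match.

    `prefix_map` is an {old_prefix: new_prefix} dict; keys matching no
    prefix are kept unchanged. Tensor values are untouched.
    """
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        # Try the longest old prefix first so nested prefixes win.
        for old in sorted(prefix_map, key=len, reverse=True):
            if key.startswith(old):
                new_key = prefix_map[old] + key[len(old):]
                break
        converted[new_key] = value
    return converted
```

Example (mapping names are illustrative, not the repo's actual keys): `convert_state_dict({"vision_model.proj.weight": w}, {"vision_model.": "visual."})` yields `{"visual.proj.weight": w}`.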

FILE: diffsynth/models/wan_video_text_encoder.py
  function fp16_clamp (line 8) | def fp16_clamp(x):
  class GELU (line 15) | class GELU(nn.Module):
    method forward (line 17) | def forward(self, x):
  class T5LayerNorm (line 22) | class T5LayerNorm(nn.Module):
    method __init__ (line 24) | def __init__(self, dim, eps=1e-6):
    method forward (line 30) | def forward(self, x):
  class T5Attention (line 38) | class T5Attention(nn.Module):
    method __init__ (line 40) | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
    method forward (line 55) | def forward(self, x, context=None, mask=None, pos_bias=None):
  class T5FeedForward (line 92) | class T5FeedForward(nn.Module):
    method __init__ (line 94) | def __init__(self, dim, dim_ffn, dropout=0.1):
    method forward (line 105) | def forward(self, x):
  class T5SelfAttention (line 113) | class T5SelfAttention(nn.Module):
    method __init__ (line 115) | def __init__(self,
    method forward (line 139) | def forward(self, x, mask=None, pos_bias=None):
  class T5RelativeEmbedding (line 147) | class T5RelativeEmbedding(nn.Module):
    method __init__ (line 149) | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
    method forward (line 159) | def forward(self, lq, lk):
    method _relative_position_bucket (line 171) | def _relative_position_bucket(self, rel_pos):
  function init_weights (line 192) | def init_weights(m):
  class WanTextEncoder (line 209) | class WanTextEncoder(torch.nn.Module):
    method __init__ (line 211) | def __init__(self,
    method forward (line 245) | def forward(self, ids, mask=None):
    method state_dict_converter (line 257) | def state_dict_converter():
  class WanTextEncoderStateDictConverter (line 261) | class WanTextEncoderStateDictConverter:
    method __init__ (line 262) | def __init__(self):
    method from_diffusers (line 265) | def from_diffusers(self, state_dict):
    method from_civitai (line 268) | def from_civitai(self, state_dict):
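`T5LayerNorm` in this file is the T5-style layer norm: RMS normalization with a learned scale, no mean centering and no bias. A pure-Python sketch of the formula (the module itself operates on tensors):

```python
import math

def t5_layer_norm(x, weight, eps=1e-6):
    """T5-style layer norm: x * weight / sqrt(mean(x**2) + eps)."""
    variance = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(variance + eps)
    return [w * v * scale for v, w in zip(x, weight)]
```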

FILE: diffsynth/models/wan_video_vae.py
  function check_is_instance (line 11) | def check_is_instance(model, module_class):
  function block_causal_mask (line 19) | def block_causal_mask(x, block_size):
  class CausalConv3d (line 33) | class CausalConv3d(nn.Conv3d):
    method __init__ (line 38) | def __init__(self, *args, **kwargs):
    method forward (line 44) | def forward(self, x, cache_x=None):
  class RMS_norm (line 55) | class RMS_norm(nn.Module):
    method __init__ (line 57) | def __init__(self, dim, channel_first=True, images=True, bias=False):
    method forward (line 67) | def forward(self, x):
  class Upsample (line 73) | class Upsample(nn.Upsample):
    method forward (line 75) | def forward(self, x):
  class Resample (line 82) | class Resample(nn.Module):
    method __init__ (line 84) | def __init__(self, dim, mode):
    method forward (line 120) | def forward(self, x, feat_cache=None, feat_idx=[0]):
    method init_weight (line 176) | def init_weight(self, conv):
    method init_weight2 (line 187) | def init_weight2(self, conv):
  class ResidualBlock (line 198) | class ResidualBlock(nn.Module):
    method __init__ (line 200) | def __init__(self, in_dim, out_dim, dropout=0.0):
    method forward (line 214) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class AttentionBlock (line 235) | class AttentionBlock(nn.Module):
    method __init__ (line 240) | def __init__(self, dim):
    method forward (line 252) | def forward(self, x):
  class Encoder3d (line 276) | class Encoder3d(nn.Module):
    method __init__ (line 278) | def __init__(self,
    method forward (line 328) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  class Decoder3d (line 379) | class Decoder3d(nn.Module):
    method __init__ (line 381) | def __init__(self,
    method forward (line 432) | def forward(self, x, feat_cache=None, feat_idx=[0]):
  function count_conv3d (line 484) | def count_conv3d(model):
  class VideoVAE_ (line 492) | class VideoVAE_(nn.Module):
    method __init__ (line 494) | def __init__(self,
    method forward (line 519) | def forward(self, x):
    method encode (line 525) | def encode(self, x, scale):
    method decode (line 552) | def decode(self, z, scale):
    method reparameterize (line 577) | def reparameterize(self, mu, log_var):
    method sample (line 582) | def sample(self, imgs, deterministic=False):
    method clear_cache (line 589) | def clear_cache(self):
  class WanVideoVAE (line 599) | class WanVideoVAE(nn.Module):
    method __init__ (line 601) | def __init__(self, z_dim=16):
    method build_1d_mask (line 621) | def build_1d_mask(self, length, left_bound, right_bound, border_width):
    method build_mask (line 630) | def build_mask(self, data, is_bound, border_width):
    method tiled_decode (line 643) | def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
    method tiled_encode (line 695) | def tiled_encode(self, video, device, tile_size, tile_stride):
    method single_encode (line 746) | def single_encode(self, video, device):
    method single_decode (line 752) | def single_decode(self, hidden_state, device):
    method encode (line 758) | def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile...
    method decode (line 776) | def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34...
    method state_dict_converter (line 792) | def state_dict_converter():
  class WanVideoVAEStateDictConverter (line 796) | class WanVideoVAEStateDictConverter:
    method __init__ (line 798) | def __init__(self):
    method from_civitai (line 801) | def from_civitai(self, state_dict):
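`tiled_encode`/`tiled_decode` in `WanVideoVAE` process the video in overlapping tiles and blend the results back with feathered masks built by `build_1d_mask`/`build_mask`. A hedged sketch of the per-axis ramp such a mask presumably uses: flat inside the tile, fading linearly over the overlap, except on sides that touch the clip border:

```python
def build_1d_mask(length, left_bound, right_bound, border_width):
    """Feathering weights along one axis of a tile.

    Weights are 1.0 in the interior and ramp down linearly over
    `border_width` samples; sides flagged as clip boundaries
    (`left_bound`/`right_bound` True) stay at full weight.
    """
    mask = [1.0] * length
    if not left_bound:
        for i in range(min(border_width, length)):
            mask[i] = (i + 1) / (border_width + 1)
    if not right_bound:
        for i in range(min(border_width, length)):
            j = length - 1 - i
            mask[j] = min(mask[j], (i + 1) / (border_width + 1))
    return mask
```

Multiplying the per-axis ramps gives the 2D/3D blend weight for each tile, so overlapping tiles sum to roughly uniform weight.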

FILE: diffsynth/pipelines/base.py
  class BasePipeline (line 8) | class BasePipeline(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, device="cuda", torch_dtype=torch.float16, height_di...
    method check_resize_height_width (line 20) | def check_resize_height_width(self, height, width):
    method preprocess_image (line 30) | def preprocess_image(self, image):
    method preprocess_images (line 35) | def preprocess_images(self, images):
    method vae_output_to_image (line 39) | def vae_output_to_image(self, vae_output):
    method vae_output_to_video (line 45) | def vae_output_to_video(self, vae_output):
    method merge_latents (line 51) | def merge_latents(self, value, latents, masks, scales, blur_kernel_siz...
    method control_noise_via_local_prompts (line 66) | def control_noise_via_local_prompts(self, prompt_emb_global, prompt_em...
    method extend_prompt (line 79) | def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
    method enable_cpu_offload (line 91) | def enable_cpu_offload(self):
    method load_models_to_device (line 95) | def load_models_to_device(self, loadmodel_names=[]):
    method generate_noise (line 124) | def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.f...
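`BasePipeline.generate_noise` takes a `seed` so sampling runs are reproducible. A pure-Python stand-in for the idea (the pipeline itself would draw from `torch.randn` with a seeded generator):

```python
import random

def generate_noise(shape, seed=None):
    """Flat list of N(0, 1) samples for `shape`, reproducible by seed."""
    rng = random.Random(seed)
    numel = 1
    for dim in shape:
        numel *= dim
    return [rng.gauss(0.0, 1.0) for _ in range(numel)]
```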

FILE: diffsynth/pipelines/cog_video.py
  class CogVideoPipeline (line 13) | class CogVideoPipeline(BasePipeline):
    method __init__ (line 15) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method fetch_models (line 26) | def fetch_models(self, model_manager: ModelManager, prompt_refiner_cla...
    method from_model_manager (line 36) | def from_model_manager(model_manager: ModelManager, prompt_refiner_cla...
    method tensor2video (line 45) | def tensor2video(self, frames):
    method encode_prompt (line 52) | def encode_prompt(self, prompt, positive=True):
    method prepare_extra_input (line 57) | def prepare_extra_input(self, latents):
    method __call__ (line 62) | def __call__(

FILE: diffsynth/pipelines/dancer.py
  function lets_dance (line 7) | def lets_dance(
  function lets_dance_xl (line 119) | def lets_dance_xl(

FILE: diffsynth/pipelines/flux_image.py
  class FluxImagePipeline (line 19) | class FluxImagePipeline(BasePipeline):
    method __init__ (line 21) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method enable_vram_management (line 37) | def enable_vram_management(self, num_persistent_param_in_dit=None):
    method denoising_model (line 136) | def denoising_model(self):
    method fetch_models (line 140) | def fetch_models(self, model_manager: ModelManager, controlnet_config_...
    method from_model_manager (line 167) | def from_model_manager(model_manager: ModelManager, controlnet_config_...
    method encode_image (line 176) | def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    method decode_image (line 181) | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=...
    method encode_prompt (line 187) | def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
    method prepare_extra_input (line 194) | def prepare_extra_input(self, latents=None, guidance=1.0):
    method apply_controlnet_mask_on_latents (line 200) | def apply_controlnet_mask_on_latents(self, latents, mask):
    method apply_controlnet_mask_on_image (line 209) | def apply_controlnet_mask_on_image(self, image, mask):
    method prepare_controlnet_input (line 218) | def prepare_controlnet_input(self, controlnet_image, controlnet_inpain...
    method prepare_ipadapter_inputs (line 242) | def prepare_ipadapter_inputs(self, images, height=384, width=384):
    method inpaint_fusion (line 248) | def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask...
    method preprocess_masks (line 260) | def preprocess_masks(self, masks, height, width, dim):
    method prepare_entity_inputs (line 269) | def prepare_entity_inputs(self, entity_prompts, entity_masks, width, h...
    method prepare_latents (line 283) | def prepare_latents(self, input_image, height, width, seed, tiled, til...
    method prepare_ipadapter (line 296) | def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
    method prepare_controlnet (line 309) | def prepare_controlnet(self, controlnet_image, masks, controlnet_inpai...
    method prepare_eligen (line 324) | def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, elige...
    method prepare_prompts (line 340) | def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t...
    method __call__ (line 353) | def __call__(
  class TeaCache (line 472) | class TeaCache:
    method __init__ (line 473) | def __init__(self, num_inference_steps, rel_l1_thresh):
    method check (line 482) | def check(self, dit: FluxDiT, hidden_states, conditioning):
    method store (line 506) | def store(self, hidden_states):
    method update (line 510) | def update(self, hidden_states):
  function lets_dance_flux (line 515) | def lets_dance_flux(

FILE: diffsynth/pipelines/hunyuan_image.py
  class ImageSizeManager (line 15) | class ImageSizeManager:
    method __init__ (line 16) | def __init__(self):
    method _to_tuple (line 20) | def _to_tuple(self, x):
    method get_fill_resize_and_crop (line 27) | def get_fill_resize_and_crop(self, src, tgt):
    method get_meshgrid (line 48) | def get_meshgrid(self, start, *args):
    method get_2d_rotary_pos_embed (line 74) | def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=Tr...
    method get_2d_rotary_pos_embed_from_grid (line 81) | def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=...
    method get_1d_rotary_pos_embed (line 97) | def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000....
    method calc_rope (line 112) | def calc_rope(self, height, width):
  class HunyuanDiTImagePipeline (line 125) | class HunyuanDiTImagePipeline(BasePipeline):
    method __init__ (line 127) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method denoising_model (line 141) | def denoising_model(self):
    method fetch_models (line 145) | def fetch_models(self, model_manager: ModelManager, prompt_refiner_cla...
    method from_model_manager (line 157) | def from_model_manager(model_manager: ModelManager, prompt_refiner_cla...
    method encode_image (line 166) | def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    method decode_image (line 171) | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=...
    method encode_prompt (line 177) | def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=T...
    method prepare_extra_input (line 193) | def prepare_extra_input(self, latents=None, tiled=False, tile_size=64,...
    method __call__ (line 210) | def __call__(
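`ImageSizeManager.get_1d_rotary_pos_embed` builds the cos/sin tables behind HunyuanDiT's 2D rotary position embedding, with per-channel frequencies `theta**(-2i/dim)`. A sketch of the 1D table construction:

```python
import math

def get_1d_rotary_pos_embed(dim, positions, theta=10000.0):
    """cos/sin rotary tables; channel i rotates at rate theta**(-2i/dim)."""
    assert dim % 2 == 0, "rotary embeddings pair channels, so dim must be even"
    freqs = [theta ** (-(2 * i) / dim) for i in range(dim // 2)]
    cos = [[math.cos(p * f) for f in freqs] for p in positions]
    sin = [[math.sin(p * f) for f in freqs] for p in positions]
    return cos, sin
```

`calc_rope` would apply this along the height and width axes of the latent grid and concatenate the two tables.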

FILE: diffsynth/pipelines/hunyuan_video.py
  class HunyuanVideoPipeline (line 15) | class HunyuanVideoPipeline(BasePipeline):
    method __init__ (line 17) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method enable_vram_management (line 30) | def enable_vram_management(self):
    method fetch_models (line 37) | def fetch_models(self, model_manager: ModelManager):
    method from_model_manager (line 47) | def from_model_manager(model_manager: ModelManager, torch_dtype=None, ...
    method generate_crop_size_list (line 56) | def generate_crop_size_list(self, base_size=256, patch_size=32, max_ra...
    method get_closest_ratio (line 71) | def get_closest_ratio(self, height: float, width: float, ratios: list,...
    method prepare_vae_images_inputs (line 78) | def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="7...
    method encode_prompt (line 105) | def encode_prompt(self, prompt, positive=True, clip_sequence_length=77...
    method prepare_extra_input (line 112) | def prepare_extra_input(self, latents=None, guidance=1.0):
    method tensor2video (line 118) | def tensor2video(self, frames):
    method encode_video (line 125) | def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12...
    method __call__ (line 133) | def __call__(
  class TeaCache (line 251) | class TeaCache:
    method __init__ (line 252) | def __init__(self, num_inference_steps, rel_l1_thresh):
    method check (line 261) | def check(self, dit: HunyuanVideoDiT, img, vec):
    method store (line 287) | def store(self, hidden_states):
    method update (line 291) | def update(self, hidden_states):
  function lets_dance_hunyuan_video (line 297) | def lets_dance_hunyuan_video(
  function lets_dance_hunyuan_video_i2v (line 343) | def lets_dance_hunyuan_video_i2v(

FILE: diffsynth/pipelines/omnigen_image.py
  class OmniGenCache (line 15) | class OmniGenCache(DynamicCache):
    method __init__ (line 16) | def __init__(self,
    method prefetch_layer (line 28) | def prefetch_layer(self, layer_idx: int):
    method evict_previous_layer (line 38) | def evict_previous_layer(self, layer_idx: int):
    method __getitem__ (line 50) | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
    method update (line 74) | def update(
  class OmnigenImagePipeline (line 122) | class OmnigenImagePipeline(BasePipeline):
    method __init__ (line 124) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method denoising_model (line 135) | def denoising_model(self):
    method fetch_models (line 139) | def fetch_models(self, model_manager: ModelManager, prompt_refiner_cla...
    method from_model_manager (line 148) | def from_model_manager(model_manager: ModelManager, prompt_refiner_cla...
    method encode_image (line 157) | def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    method encode_images (line 162) | def encode_images(self, images, tiled=False, tile_size=64, tile_stride...
    method decode_image (line 167) | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=...
    method encode_prompt (line 173) | def encode_prompt(self, prompt, clip_skip=1, positive=True):
    method prepare_extra_input (line 178) | def prepare_extra_input(self, latents=None):
    method crop_position_ids_for_cache (line 182) | def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
    method crop_attention_mask_for_cache (line 191) | def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for...
    method __call__ (line 198) | def __call__(
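`OmniGenCache` extends `DynamicCache` with layer-wise offloading: before layer *i* runs, its cached keys/values are prefetched to the accelerator, and layer *i−1*'s cache is evicted back to host memory, so roughly one layer's KV cache is resident at a time. A device-agnostic sketch of that bookkeeping (device moves mocked as string tags rather than real tensor transfers):

```python
class OffloadingLayerCache:
    """Keep at most ~one transformer layer's KV cache on the accelerator."""

    def __init__(self, num_layers):
        # Where each layer's KV tensors currently live.
        self.device = ["cpu"] * num_layers

    def prefetch_layer(self, layer_idx):
        """Move the upcoming layer's cache to the accelerator."""
        self.device[layer_idx] = "cuda"

    def evict_previous_layer(self, layer_idx):
        """Send the previous layer's cache back to host memory."""
        if layer_idx > 0:
            self.device[layer_idx - 1] = "cpu"

    def __getitem__(self, layer_idx):
        # Accessing layer i triggers prefetch of i and eviction of i-1.
        self.prefetch_layer(layer_idx)
        self.evict_previous_layer(layer_idx)
        return self.device[layer_idx]
```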

FILE: diffsynth/pipelines/pipeline_runner.py
  class SDVideoPipelineRunner (line 8) | class SDVideoPipelineRunner:
    method __init__ (line 9) | def __init__(self, in_streamlit=False):
    method load_pipeline (line 13) | def load_pipeline(self, model_list, textual_inversion_folder, device, ...
    method load_smoother (line 35) | def load_smoother(self, model_manager, smoother_configs):
    method synthesize_video (line 40) | def synthesize_video(self, model_manager, pipe, seed, smoother, **pipe...
    method load_video (line 53) | def load_video(self, video_file, image_folder, height, width, start_fr...
    method add_data_to_pipeline_inputs (line 63) | def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
    method save_output (line 72) | def save_output(self, video, output_folder, fps, config):
    method run (line 82) | def run(self, config):

FILE: diffsynth/pipelines/sd3_image.py
  class SD3ImagePipeline (line 10) | class SD3ImagePipeline(BasePipeline):
    method __init__ (line 12) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method denoising_model (line 26) | def denoising_model(self):
    method fetch_models (line 30) | def fetch_models(self, model_manager: ModelManager, prompt_refiner_cla...
    method from_model_manager (line 42) | def from_model_manager(model_manager: ModelManager, prompt_refiner_cla...
    method encode_image (line 51) | def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    method decode_image (line 56) | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=...
    method encode_prompt (line 62) | def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
    method prepare_extra_input (line 69) | def prepare_extra_input(self, latents=None):
    method __call__ (line 74) | def __call__(

FILE: diffsynth/pipelines/sd_image.py
  class SDImagePipeline (line 14) | class SDImagePipeline(BasePipeline):
    method __init__ (line 16) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method denoising_model (line 31) | def denoising_model(self):
    method fetch_models (line 35) | def fetch_models(self, model_manager: ModelManager, controlnet_config_...
    method from_model_manager (line 61) | def from_model_manager(model_manager: ModelManager, controlnet_config_...
    method encode_image (line 70) | def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    method decode_image (line 75) | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=...
    method encode_prompt (line 81) | def encode_prompt(self, prompt, clip_skip=1, positive=True):
    method prepare_extra_input (line 86) | def prepare_extra_input(self, latents=None):
    method __call__ (line 91) | def __call__(

FILE: diffsynth/pipelines/sd_video.py
  function lets_dance_with_long_video (line 14) | def lets_dance_with_long_video(
  class SDVideoPipeline (line 68) | class SDVideoPipeline(SDImagePipeline):
    method __init__ (line 70) | def __init__(self, device="cuda", torch_dtype=torch.float16, use_origi...
    method fetch_models (line 85) | def fetch_models(self, model_manager: ModelManager, controlnet_config_...
    method from_model_manager (line 116) | def from_model_manager(model_manager: ModelManager, controlnet_config_...
    method decode_video (line 125) | def decode_video(self, latents, tiled=False, tile_size=64, tile_stride...
    method encode_video (line 133) | def encode_video(self, processed_images, tiled=False, tile_size=64, ti...
    method __call__ (line 144) | def __call__(

FILE: diffsynth/pipelines/sdxl_image.py
  class SDXLImagePipeline (line 16) | class SDXLImagePipeline(BasePipeline):
    method __init__ (line 18) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method denoising_model (line 35) | def denoising_model(self):
    method fetch_models (line 39) | def fetch_models(self, model_manager: ModelManager, controlnet_config_...
    method from_model_manager (line 75) | def from_model_manager(model_manager: ModelManager, controlnet_config_...
    method encode_image (line 84) | def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
    method decode_image (line 89) | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=...
    method encode_prompt (line 95) | def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=T...
    method prepare_extra_input (line 105) | def prepare_extra_input(self, latents=None):
    method __call__ (line 112) | def __call__(

FILE: diffsynth/pipelines/sdxl_video.py
  class SDXLVideoPipeline (line 15) | class SDXLVideoPipeline(SDXLImagePipeline):
    method __init__ (line 17) | def __init__(self, device="cuda", torch_dtype=torch.float16, use_origi...
    method fetch_models (line 34) | def fetch_models(self, model_manager: ModelManager, controlnet_config_...
    method from_model_manager (line 69) | def from_model_manager(model_manager: ModelManager, controlnet_config_...
    method decode_video (line 78) | def decode_video(self, latents, tiled=False, tile_size=64, tile_stride...
    method encode_video (line 86) | def encode_video(self, processed_images, tiled=False, tile_size=64, ti...
    method __call__ (line 97) | def __call__(

FILE: diffsynth/pipelines/step_video.py
  class StepVideoPipeline (line 20) | class StepVideoPipeline(BasePipeline):
    method __init__ (line 22) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method enable_vram_management (line 33) | def enable_vram_management(self, num_persistent_param_in_dit=None):
    method fetch_models (line 118) | def fetch_models(self, model_manager: ModelManager):
    method from_model_manager (line 127) | def from_model_manager(model_manager: ModelManager, torch_dtype=None, ...
    method encode_prompt (line 135) | def encode_prompt(self, prompt, positive=True):
    method tensor2video (line 143) | def tensor2video(self, frames):
    method __call__ (line 151) | def __call__(

FILE: diffsynth/pipelines/svd_video.py
  class SVDVideoPipeline (line 12) | class SVDVideoPipeline(BasePipeline):
    method __init__ (line 14) | def __init__(self, device="cuda", torch_dtype=torch.float16):
    method fetch_models (line 24) | def fetch_models(self, model_manager: ModelManager):
    method from_model_manager (line 32) | def from_model_manager(model_manager: ModelManager, **kwargs):
    method encode_image_with_clip (line 41) | def encode_image_with_clip(self, image):
    method encode_image_with_vae (line 52) | def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
    method encode_video_with_vae (line 60) | def encode_video_with_vae(self, video):
    method tensor2video (line 69) | def tensor2video(self, frames):
    method calculate_noise_pred (line 76) | def calculate_noise_pred(
    method post_process_latents (line 102) | def post_process_latents(self, latents, post_normalize=True, contrast_...
    method __call__ (line 111) | def __call__(
  class SVDCLIPImageProcessor (line 192) | class SVDCLIPImageProcessor:
    method __init__ (line 193) | def __init__(self):
    method resize_with_antialiasing (line 196) | def resize_with_antialiasing(self, input, size, interpolation="bicubic...
    method _compute_padding (line 225) | def _compute_padding(self, kernel_size):
    method _filter2d (line 248) | def _filter2d(self, input, kernel):
    method _gaussian (line 271) | def _gaussian(self, window_size: int, sigma):
    method _gaussian_blur2d (line 287) | def _gaussian_blur2d(self, input, kernel_size, sigma):
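`SVDCLIPImageProcessor._gaussian` builds the normalized 1-D Gaussian window that `_gaussian_blur2d` applies separably along each axis for antialiased resizing. A sketch of the window:

```python
import math

def gaussian_kernel(window_size, sigma):
    """Normalized 1-D Gaussian window centered on the middle tap."""
    center = (window_size - 1) / 2.0
    weights = [math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2))
               for i in range(window_size)]
    total = sum(weights)
    return [w / total for w in weights]
```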

FILE: diffsynth/pipelines/wan_video.py
  class WanVideoPipeline (line 23) | class WanVideoPipeline(BasePipeline):
    method __init__ (line 25) | def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer...
    method enable_vram_management (line 38) | def enable_vram_management(self, num_persistent_param_in_dit=None):
    method fetch_models (line 126) | def fetch_models(self, model_manager: ModelManager):
    method from_model_manager (line 138) | def from_model_manager(model_manager: ModelManager, torch_dtype=None, ...
    method denoising_model (line 146) | def denoising_model(self):
    method encode_prompt (line 150) | def encode_prompt(self, prompt, positive=True):
    method encode_image (line 155) | def encode_image(self, image, num_frames, height, width):
    method tensor2video (line 173) | def tensor2video(self, frames):
    method prepare_extra_input (line 180) | def prepare_extra_input(self, latents=None):
    method encode_video (line 184) | def encode_video(self, input_video, tiled=True, tile_size=(34, 34), ti...
    method decode_video (line 189) | def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_s...
    method __call__ (line 195) | def __call__(
  class TeaCache (line 288) | class TeaCache:
    method __init__ (line 289) | def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
    method check (line 309) | def check(self, dit: WanModel, x, t_mod):
    method store (line 331) | def store(self, hidden_states):
    method update (line 335) | def update(self, hidden_states):
  function model_fn_wan_video (line 341) | def model_fn_wan_video(
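The `TeaCache` helper in this pipeline (and its SynCamMaster variant below) speeds up sampling by tracking how much the timestep-modulated conditioning drifts between denoising steps; when the accumulated relative-L1 change stays under `rel_l1_thresh`, the cached DiT residual is reused instead of rerunning the model. A skeletal sketch of the skip decision (the real `check` also rescales the distance with a model-specific polynomial keyed by `model_id`):

```python
class TeaCacheSketch:
    """Accumulate relative-L1 drift of the conditioning; skip when small."""

    def __init__(self, rel_l1_thresh):
        self.rel_l1_thresh = rel_l1_thresh
        self.prev = None
        self.accumulated = 0.0

    def check(self, cond):
        """Return True when the cached output can be reused for `cond`."""
        if self.prev is None:
            self.prev, self.accumulated = cond, 0.0
            return False  # first step always runs the model
        num = sum(abs(a - b) for a, b in zip(cond, self.prev))
        den = sum(abs(b) for b in self.prev) or 1.0
        self.accumulated += num / den
        self.prev = cond
        if self.accumulated < self.rel_l1_thresh:
            return True  # reuse the cached residual
        self.accumulated = 0.0  # run the model, reset the drift counter
        return False
```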

FILE: diffsynth/pipelines/wan_video_syncammaster.py
  class WanVideoSynCamMasterPipeline (line 23) | class WanVideoSynCamMasterPipeline(BasePipeline):
    method __init__ (line 25) | def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer...
    method enable_vram_management (line 38) | def enable_vram_management(self, num_persistent_param_in_dit=None):
    method fetch_models (line 126) | def fetch_models(self, model_manager: ModelManager):
    method from_model_manager (line 138) | def from_model_manager(model_manager: ModelManager, torch_dtype=None, ...
    method denoising_model (line 146) | def denoising_model(self):
    method encode_prompt (line 150) | def encode_prompt(self, prompt, positive=True):
    method encode_image (line 155) | def encode_image(self, image, num_frames, height, width):
    method tensor2video (line 173) | def tensor2video(self, frames):
    method prepare_extra_input (line 180) | def prepare_extra_input(self, latents=None):
    method encode_video (line 184) | def encode_video(self, input_video, tiled=True, tile_size=(34, 34), ti...
    method decode_video (line 189) | def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_s...
    method __call__ (line 195) | def __call__(
  class TeaCache (line 297) | class TeaCache:
    method __init__ (line 298) | def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
    method check (line 318) | def check(self, dit: WanModel, x, t_mod):
    method store (line 340) | def store(self, hidden_states):
    method update (line 344) | def update(self, hidden_states):
  function model_fn_wan_video (line 350) | def model_fn_wan_video(

FILE: diffsynth/processors/FastBlend.py
  class FastBlendSmoother (line 10) | class FastBlendSmoother(VideoProcessor):
    method __init__ (line 11) | def __init__(
    method from_model_manager (line 30) | def from_model_manager(model_manager, **kwargs):
    method inference_fast (line 34) | def inference_fast(self, frames_guide, frames_style):
    method inference_balanced (line 61) | def inference_balanced(self, frames_guide, frames_style):
    method inference_accurate (line 98) | def inference_accurate(self, frames_guide, frames_style):
    method release_vram (line 124) | def release_vram(self):
    method __call__ (line 130) | def __call__(self, rendered_frames, original_frames=None, **kwargs):

FILE: diffsynth/processors/PILEditor.py
  class ContrastEditor (line 5) | class ContrastEditor(VideoProcessor):
    method __init__ (line 6) | def __init__(self, rate=1.5):
    method from_model_manager (line 10) | def from_model_manager(model_manager, **kwargs):
    method __call__ (line 13) | def __call__(self, rendered_frames, **kwargs):
  class SharpnessEditor (line 18) | class SharpnessEditor(VideoProcessor):
    method __init__ (line 19) | def __init__(self, rate=1.5):
    method from_model_manager (line 23) | def from_model_manager(model_manager, **kwargs):
    method __call__ (line 26) | def __call__(self, rendered_frames, **kwargs):

FILE: diffsynth/processors/RIFE.py
  class RIFESmoother (line 7) | class RIFESmoother(VideoProcessor):
    method __init__ (line 8) | def __init__(self, model, device="cuda", scale=1.0, batch_size=4, inte...
    method from_model_manager (line 21) | def from_model_manager(model_manager, **kwargs):
    method process_image (line 24) | def process_image(self, image):
    method process_images (line 33) | def process_images(self, images):
    method decode_images (line 38) | def decode_images(self, images):
    method process_tensors (line 43) | def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
    method __call__ (line 55) | def __call__(self, rendered_frames, **kwargs):

FILE: diffsynth/processors/base.py
  class VideoProcessor (line 1) | class VideoProcessor:
    method __init__ (line 2) | def __init__(self):
    method __call__ (line 5) | def __call__(self):

FILE: diffsynth/processors/sequencial_processor.py
  class AutoVideoProcessor (line 4) | class AutoVideoProcessor(VideoProcessor):
    method __init__ (line 5) | def __init__(self):
    method from_model_manager (line 9) | def from_model_manager(model_manager, processor_type, **kwargs):
  class SequencialProcessor (line 26) | class SequencialProcessor(VideoProcessor):
    method __init__ (line 27) | def __init__(self, processors=[]):
    method from_model_manager (line 31) | def from_model_manager(model_manager, configs):
    method __call__ (line 38) | def __call__(self, rendered_frames, **kwargs):
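`SequencialProcessor` composes a list of `VideoProcessor`s, feeding each one's output frames into the next. The composition itself is just a left fold over the processors:

```python
class SequentialProcessorSketch:
    """Apply a list of frame processors in order."""

    def __init__(self, processors):
        self.processors = processors

    def __call__(self, frames, **kwargs):
        for processor in self.processors:
            frames = processor(frames, **kwargs)
        return frames
```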

FILE: diffsynth/prompters/base_prompter.py
  function tokenize_long_prompt (line 6) | def tokenize_long_prompt(tokenizer, prompt, max_length=None):
  class BasePrompter (line 39) | class BasePrompter:
    method __init__ (line 40) | def __init__(self):
    method load_prompt_refiners (line 45) | def load_prompt_refiners(self, model_manager: ModelManager, refiner_cl...
    method load_prompt_extenders (line 50) | def load_prompt_extenders(self,model_manager:ModelManager,extender_cla...
    method process_prompt (line 57) | def process_prompt(self, prompt, positive=True):
    method extend_prompt (line 66) | def extend_prompt(self, prompt:str, positive=True):
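`tokenize_long_prompt` handles prompts longer than the text encoder's context window by splitting the token ids into fixed-size chunks that are encoded separately and concatenated. A sketch of the chunking step, with padding on the final chunk (`pad_id` is a placeholder; the real function pads with the tokenizer's pad token and re-inserts BOS/EOS):

```python
def chunk_token_ids(token_ids, max_length, pad_id=0):
    """Split ids into max_length-sized chunks, padding the final chunk."""
    chunks = [token_ids[i:i + max_length]
              for i in range(0, max(len(token_ids), 1), max_length)]
    chunks[-1] = chunks[-1] + [pad_id] * (max_length - len(chunks[-1]))
    return chunks
```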

FILE: diffsynth/prompters/cog_prompter.py
  class CogPrompter (line 7) | class CogPrompter(BasePrompter):
    method __init__ (line 8) | def __init__(
    method fetch_models (line 20) | def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
    method encode_prompt_using_t5 (line 24) | def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_...
    method encode_prompt (line 38) | def encode_prompt(

FILE: diffsynth/prompters/flux_prompter.py
  class FluxPrompter (line 8) | class FluxPrompter(BasePrompter):
    method __init__ (line 9) | def __init__(
    method fetch_models (line 27) | def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_en...
    method encode_prompt_using_clip (line 32) | def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, ma...
    method encode_prompt_using_t5 (line 44) | def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_...
    method encode_prompt (line 56) | def encode_prompt(

FILE: diffsynth/prompters/hunyuan_dit_prompter.py
  class HunyuanDiTPrompter (line 8) | class HunyuanDiTPrompter(BasePrompter):
    method __init__ (line 9) | def __init__(
    method fetch_models (line 29) | def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None,...
    method encode_prompt_using_signle_model (line 34) | def encode_prompt_using_signle_model(self, prompt, text_encoder, token...
    method encode_prompt (line 53) | def encode_prompt(

FILE: diffsynth/prompters/hunyuan_video_prompter.py
  class HunyuanVideoPrompter (line 70) | class HunyuanVideoPrompter(BasePrompter):
    method __init__ (line 72) | def __init__(
    method fetch_models (line 94) | def fetch_models(self,
    method apply_text_to_template (line 109) | def apply_text_to_template(self, text, template):
    method encode_prompt_using_clip (line 119) | def encode_prompt_using_clip(self, prompt, max_length, device):
    method encode_prompt_using_llm (line 132) | def encode_prompt_using_llm(self,
    method encode_prompt_using_mllm (line 156) | def encode_prompt_using_mllm(self,
    method encode_prompt (line 236) | def encode_prompt(self,

FILE: diffsynth/prompters/kolors_prompter.py
  class SPTokenizer (line 12) | class SPTokenizer:
    method __init__ (line 13) | def __init__(self, model_path: str):
    method tokenize (line 35) | def tokenize(self, s: str, encode_special_tokens=False):
    method encode (line 50) | def encode(self, s: str, bos: bool = False, eos: bool = False) -> List...
    method decode (line 59) | def decode(self, t: List[int]) -> str:
    method decode_tokens (line 73) | def decode_tokens(self, tokens: List[str]) -> str:
    method convert_token_to_id (line 77) | def convert_token_to_id(self, token):
    method convert_id_to_token (line 83) | def convert_id_to_token(self, index):
  class ChatGLMTokenizer (line 93) | class ChatGLMTokenizer(PreTrainedTokenizer):
    method __init__ (line 98) | def __init__(self, vocab_file, padding_side="left", clean_up_tokenizat...
    method get_command (line 114) | def get_command(self, token):
    method unk_token (line 121) | def unk_token(self) -> str:
    method pad_token (line 125) | def pad_token(self) -> str:
    method pad_token_id (line 129) | def pad_token_id(self):
    method eos_token (line 133) | def eos_token(self) -> str:
    method eos_token_id (line 137) | def eos_token_id(self):
    method vocab_size (line 141) | def vocab_size(self):
    method get_vocab (line 144) | def get_vocab(self):
    method _tokenize (line 150) | def _tokenize(self, text, **kwargs):
    method _convert_token_to_id (line 153) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 157) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 161) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method save_vocabulary (line 164) | def save_vocabulary(self, save_directory, filename_prefix=None):
    method get_prefix_tokens (line 192) | def get_prefix_tokens(self):
    method build_single_message (line 196) | def build_single_message(self, role, metadata, message):
    method build_chat_input (line 203) | def build_chat_input(self, query, history=None, role="user"):
    method build_inputs_with_special_tokens (line 216) | def build_inputs_with_special_tokens(
    method _pad (line 241) | def _pad(
  class KolorsPrompter (line 307) | class KolorsPrompter(BasePrompter):
    method __init__ (line 308) | def __init__(
    method fetch_models (line 320) | def fetch_models(self, text_encoder: ChatGLMModel = None):
    method encode_prompt_using_ChatGLM (line 324) | def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer,...
    method encode_prompt (line 343) | def encode_prompt(

FILE: diffsynth/prompters/omnigen_prompter.py
  function crop_arr (line 14) | def crop_arr(pil_image, max_image_size):
  class OmniGenPrompter (line 44) | class OmniGenPrompter:
    method __init__ (line 45) | def __init__(self,
    method from_pretrained (line 61) | def from_pretrained(cls, model_name):
    method process_image (line 72) | def process_image(self, image):
    method process_multi_modal_prompt (line 75) | def process_multi_modal_prompt(self, text, input_images):
    method add_prefix_instruction (line 112) | def add_prefix_instruction(self, prompt):
    method __call__ (line 121) | def __call__(self,
  class OmniGenCollator (line 172) | class OmniGenCollator:
    method __init__ (line 173) | def __init__(self, pad_token_id=2, hidden_size=3072):
    method create_position (line 177) | def create_position(self, attention_mask, num_tokens_for_output_images):
    method create_mask (line 187) | def create_mask(self, attention_mask, num_tokens_for_output_images):
    method adjust_attention_for_input_images (line 226) | def adjust_attention_for_input_images(self, attention_mask, image_sizes):
    method pad_input_ids (line 233) | def pad_input_ids(self, input_ids, image_sizes):
    method process_mllm_input (line 259) | def process_mllm_input(self, mllm_inputs, target_img_size):
    method __call__ (line 287) | def __call__(self, features):
  class OmniGenSeparateCollator (line 314) | class OmniGenSeparateCollator(OmniGenCollator):
    method __call__ (line 315) | def __call__(self, features):

FILE: diffsynth/prompters/omost.py
  function safe_str (line 95) | def safe_str(x):
  function closest_name (line 98) | def closest_name(input_str, options):
  class Canvas (line 110) | class Canvas:
    method from_bot_response (line 112) | def from_bot_response(response: str):
    method __init__ (line 124) | def __init__(self):
    method set_global_description (line 132) | def set_global_description(self, description: str, detailed_descriptio...
    method add_local_description (line 153) | def add_local_description(self, location: str, offset: str, area: str,...
    method process (line 198) | def process(self):
  class OmostPromter (line 235) | class OmostPromter(torch.nn.Module):
    method __init__ (line 237) | def __init__(self,model = None,tokenizer = None, template = "",device=...
    method from_model_manager (line 259) | def from_model_manager(model_manager: ModelManager):
    method __call__ (line 270) | def __call__(self,prompt_dict:dict):

FILE: diffsynth/prompters/prompt_refiners.py
  class BeautifulPrompt (line 6) | class BeautifulPrompt(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, tokenizer_path=None, model=None, template=""):
    method from_model_manager (line 15) | def from_model_manager(model_manager: ModelManager):
    method __call__ (line 32) | def __call__(self, raw_prompt, positive=True, **kwargs):
  class QwenPrompt (line 57) | class QwenPrompt(torch.nn.Module):
    method __init__ (line 60) | def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
    method from_model_manager (line 68) | def from_model_manager(model_nameger: ModelManager):
    method __call__ (line 79) | def __call__(self, raw_prompt, positive=True, **kwargs):
  class Translator (line 111) | class Translator(torch.nn.Module):
    method __init__ (line 112) | def __init__(self, tokenizer_path=None, model=None):
    method from_model_manager (line 119) | def from_model_manager(model_manager: ModelManager):
    method __call__ (line 125) | def __call__(self, prompt, **kwargs):

FILE: diffsynth/prompters/sd3_prompter.py
  class SD3Prompter (line 8) | class SD3Prompter(BasePrompter):
    method __init__ (line 9) | def __init__(
    method fetch_models (line 33) | def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_en...
    method encode_prompt_using_clip (line 39) | def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, ma...
    method encode_prompt_using_t5 (line 51) | def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_...
    method encode_prompt (line 66) | def encode_prompt(

FILE: diffsynth/prompters/sd_prompter.py
  class SDPrompter (line 9) | class SDPrompter(BasePrompter):
    method __init__ (line 10) | def __init__(self, tokenizer_path=None):
    method fetch_models (line 21) | def fetch_models(self, text_encoder: SDTextEncoder = None):
    method add_textual_inversions_to_model (line 25) | def add_textual_inversions_to_model(self, textual_inversion_dict, text...
    method add_textual_inversions_to_tokenizer (line 39) | def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, ...
    method load_textual_inversions (line 48) | def load_textual_inversions(self, model_paths):
    method encode_prompt (line 63) | def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=T...

FILE: diffsynth/prompters/sdxl_prompter.py
  class SDXLPrompter (line 9) | class SDXLPrompter(BasePrompter):
    method __init__ (line 10) | def __init__(
    method fetch_models (line 28) | def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_enco...
    method encode_prompt (line 33) | def encode_prompt(

FILE: diffsynth/prompters/stepvideo_prompter.py
  class StepVideoPrompter (line 8) | class StepVideoPrompter(BasePrompter):
    method __init__ (line 10) | def __init__(
    method fetch_models (line 21) | def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = Non...
    method encode_prompt_using_clip (line 25) | def encode_prompt_using_clip(self, prompt, max_length, device):
    method encode_prompt_using_llm (line 40) | def encode_prompt_using_llm(self, prompt, max_length, device):
    method encode_prompt (line 44) | def encode_prompt(self,

FILE: diffsynth/prompters/wan_prompter.py
  function basic_clean (line 11) | def basic_clean(text):
  function whitespace_clean (line 17) | def whitespace_clean(text):
  function canonicalize (line 23) | def canonicalize(text, keep_punctuation_exact_string=None):
  class HuggingfaceTokenizer (line 36) | class HuggingfaceTokenizer:
    method __init__ (line 38) | def __init__(self, name, seq_len=None, clean=None, **kwargs):
    method __call__ (line 48) | def __call__(self, sequence, **kwargs):
    method _clean (line 74) | def _clean(self, text):
  class WanPrompter (line 84) | class WanPrompter(BasePrompter):
    method __init__ (line 86) | def __init__(self, tokenizer_path=None, text_len=512):
    method fetch_tokenizer (line 92) | def fetch_tokenizer(self, tokenizer_path=None):
    method fetch_models (line 96) | def fetch_models(self, text_encoder: WanTextEncoder = None):
    method encode_prompt (line 99) | def encode_prompt(self, prompt, positive=True, device="cuda"):
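
The `basic_clean` / `whitespace_clean` / `canonicalize` helpers at the top of `wan_prompter.py` point at standard prompt normalization. A hedged sketch of what such helpers typically do, mirroring the `keep_punctuation_exact_string` parameter from the listing (the exact behavior is an assumption):

```python
import re
import string

# Hedged sketch of the prompt-cleaning helpers named in wan_prompter.py:
# whitespace_clean collapses runs of whitespace, and canonicalize lowercases
# and strips punctuation, optionally preserving one exact substring
# (mirroring keep_punctuation_exact_string). Details are assumptions.

def whitespace_clean(text):
    return re.sub(r"\s+", " ", text).strip()

def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    strip = str.maketrans("", "", string.punctuation)
    if keep_punctuation_exact_string:
        # Strip punctuation only in the pieces around the preserved string.
        parts = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(p.translate(strip) for p in parts)
    else:
        text = text.translate(strip)
    return whitespace_clean(text.lower())


cleaned = canonicalize("A  Cinematic,  wide-angle   SHOT!!")
```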

FILE: diffsynth/schedulers/continuous_ode.py
  class ContinuousODEScheduler (line 4) | class ContinuousODEScheduler():
    method __init__ (line 6) | def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min...
    method set_timesteps (line 13) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
    method step (line 21) | def step(self, model_output, timestep, sample, to_final=False):
    method return_to_timestep (line 36) | def return_to_timestep(self, timestep, sample, sample_stablized):
    method add_noise (line 41) | def add_noise(self, original_samples, noise, timestep):
    method training_target (line 48) | def training_target(self, sample, noise, timestep):
    method training_weight (line 55) | def training_weight(self, timestep):

FILE: diffsynth/schedulers/ddim.py
  class EnhancedDDIMScheduler (line 4) | class EnhancedDDIMScheduler():
    method __init__ (line 6) | def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_...
    method rescale_zero_terminal_snr (line 22) | def rescale_zero_terminal_snr(self, alphas_cumprod):
    method set_timesteps (line 41) | def set_timesteps(self, num_inference_steps, denoising_strength=1.0, *...
    method denoise (line 53) | def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
    method step (line 67) | def step(self, model_output, timestep, sample, to_final=False):
    method return_to_timestep (line 81) | def return_to_timestep(self, timestep, sample, sample_stablized):
    method add_noise (line 87) | def add_noise(self, original_samples, noise, timestep):
    method training_target (line 94) | def training_target(self, sample, noise, timestep):
    method training_weight (line 104) | def training_weight(self, timestep):
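
`EnhancedDDIMScheduler.denoise` takes `alpha_prod_t` and `alpha_prod_t_prev`, which matches the standard deterministic DDIM update under epsilon prediction. A scalar sketch of that update (assuming the model predicts the added noise; the repository's variant may add rescaling or other options):

```python
import math

# Hedged sketch of a deterministic DDIM update matching the
# denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev) signature,
# assuming epsilon prediction:
#   x0_hat  = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
#   x_prev  = sqrt(a_prev) * x0_hat + sqrt(1 - a_prev) * eps

def ddim_denoise(eps, sample, alpha_prod_t, alpha_prod_t_prev):
    x0_hat = (sample - math.sqrt(1 - alpha_prod_t) * eps) / math.sqrt(alpha_prod_t)
    return math.sqrt(alpha_prod_t_prev) * x0_hat + math.sqrt(1 - alpha_prod_t_prev) * eps


# Sanity check: stepping to alpha_prod_t_prev = 1 recovers x0 exactly
# when eps is the true noise used in the forward process.
x0, eps, a_t = 2.0, 0.5, 0.64
x_t = math.sqrt(a_t) * x0 + math.sqrt(1 - a_t) * eps
recovered = ddim_denoise(eps, x_t, a_t, 1.0)
```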

FILE: diffsynth/schedulers/flow_match.py
  class FlowMatchScheduler (line 5) | class FlowMatchScheduler():
    method __init__ (line 7) | def __init__(self, num_inference_steps=100, num_train_timesteps=1000, ...
    method set_timesteps (line 18) | def set_timesteps(self, num_inference_steps=100, denoising_strength=1....
    method step (line 40) | def step(self, model_output, timestep, sample, to_final=False, **kwargs):
    method return_to_timestep (line 53) | def return_to_timestep(self, timestep, sample, sample_stablized):
    method add_noise (line 62) | def add_noise(self, original_samples, noise, timestep):
    method training_target (line 71) | def training_target(self, sample, noise, timestep):
    method training_weight (line 76) | def training_weight(self, timestep):
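
The `FlowMatchScheduler` method names (`add_noise`, `training_target`, `step`) fit the usual rectified-flow parameterization. A scalar sketch of those three quantities under the standard linear-interpolation convention (an assumption; the repository may use shifted sigmas or other details):

```python
# Hedged sketch of rectified-flow quantities consistent with the
# FlowMatchScheduler method names, assuming linear interpolation:
#   add_noise:       x_t = (1 - sigma) * x0 + sigma * noise
#   training_target: v   = noise - x0
#   step (Euler):    x  <- x - (sigma - sigma_next) * v

def add_noise(x0, noise, sigma):
    return (1 - sigma) * x0 + sigma * noise

def training_target(x0, noise):
    return noise - x0

def step(sample, v, sigma, sigma_next):
    return sample - (sigma - sigma_next) * v


# One exact Euler step with the true velocity lands on the interpolant
# at sigma_next, since the velocity is constant along the straight path.
x0, noise = 1.0, -3.0
x_t = add_noise(x0, noise, sigma=0.8)
x_s = step(x_t, training_target(x0, noise), sigma=0.8, sigma_next=0.3)
```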

FILE: diffsynth/trainers/text_to_image.py
  class LightningModelForT2ILoRA (line 10) | class LightningModelForT2ILoRA(pl.LightningModule):
    method __init__ (line 11) | def __init__(
    method load_models (line 25) | def load_models(self):
    method freeze_parameters (line 30) | def freeze_parameters(self):
    method add_lora_to_model (line 37) | def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_tar...
    method training_step (line 67) | def training_step(self, batch, batch_idx):
    method configure_optimizers (line 98) | def configure_optimizers(self):
    method on_save_checkpoint (line 104) | def on_save_checkpoint(self, checkpoint):
  function add_general_parsers (line 119) | def add_general_parsers(parser):
  function launch_training_task (line 267) | def launch_training_task(model, args):
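
`add_lora_to_model` exposes `lora_rank` and `lora_alpha`, the two knobs of a standard LoRA update: a frozen weight `W` is augmented by a low-rank pair scaled by `alpha / rank`, so the effective weight is `W + (alpha / rank) * B @ A`. A pure-Python sketch of that forward pass (illustrative only; the repository injects this into attention projections via PyTorch):

```python
# Hedged sketch of the LoRA update implied by lora_rank / lora_alpha:
# y = W @ x + (alpha / rank) * B @ (A @ x), where A is (rank x in) and
# B is (out x rank). Pure-python nested-list matrices for illustration.

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def lora_forward(W, A, B, x, alpha, rank):
    base = matmul(W, x)
    delta = matmul(B, matmul(A, x))
    scale = alpha / rank
    return [[b + scale * d for b, d in zip(br, dr)] for br, dr in zip(base, delta)]


# rank-1 example: with B zero-initialized the output equals W @ x,
# matching the usual LoRA init that leaves the model unchanged at step 0.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, 0.5]]           # 1 x 2
B = [[0.0], [0.0]]         # 2 x 1, zero init
x = [[2.0], [4.0]]
y = lora_forward(W, A, B, x, alpha=4, rank=1)
```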

FILE: diffsynth/vram_management/layers.py
  function cast_to (line 5) | def cast_to(weight, dtype, device):
  class AutoWrappedModule (line 11) | class AutoWrappedModule(torch.nn.Module):
    method __init__ (line 12) | def __init__(self, module: torch.nn.Module, offload_dtype, offload_dev...
    method offload (line 23) | def offload(self):
    method onload (line 28) | def onload(self):
    method forward (line 33) | def forward(self, *args, **kwargs):
  class AutoWrappedLinear (line 41) | class AutoWrappedLinear(torch.nn.Linear):
    method __init__ (line 42) | def __init__(self, module: torch.nn.Linear, offload_dtype, offload_dev...
    method offload (line 55) | def offload(self):
    method onload (line 60) | def onload(self):
    method forward (line 65) | def forward(self, x, *args, **kwargs):
  function enable_vram_management_recursively (line 74) | def enable_vram_management_recursively(model: torch.nn.Module, module_ma...
  function enable_vram_management (line 92) | def enable_vram_management(model: torch.nn.Module, module_map: dict, mod...
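
The `offload` / `onload` / `forward` triple on `AutoWrappedModule` and `AutoWrappedLinear` describes a weight-residency lifecycle: parameters rest on a cheap device, move to the compute device for the forward pass, and move back afterwards. A torch-free sketch of that pattern, with devices as plain strings (purely illustrative; the real classes move tensors and also convert dtypes via `cast_to`):

```python
# Hedged, torch-free sketch of the offload/onload lifecycle behind
# AutoWrappedModule: weights start on the offload device, are brought to
# the compute device before forward, and are released right after use.

class AutoWrapped:
    def __init__(self, weights, offload_device="cpu", onload_device="cuda"):
        self.weights = weights
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.device = offload_device  # start offloaded

    def offload(self):
        self.device = self.offload_device

    def onload(self):
        self.device = self.onload_device

    def forward(self, x):
        self.onload()                 # weights must be resident to compute
        try:
            return [w * x for w in self.weights]
        finally:
            self.offload()            # free the compute device immediately


layer = AutoWrapped([1.0, 2.0])
out = layer.forward(3.0)
```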

FILE: inference_syncammaster.py
  class Camera (line 16) | class Camera(object):
    method __init__ (line 17) | def __init__(self, c2w):
  class TextCameraDataset (line 22) | class TextCameraDataset(torch.utils.data.Dataset):
    method __init__ (line 23) | def __init__(self, base_path, metadata_path, args, max_num_frames=81, ...
    method crop_and_resize (line 44) | def crop_and_resize(self, image):
    method load_frames_using_imageio (line 55) | def load_frames_using_imageio(self, file_path, max_num_frames, start_f...
    method is_image (line 82) | def is_image(self, file_path):
    method load_video (line 89) | def load_video(self, file_path):
    method parse_matrix (line 95) | def parse_matrix(self, matrix_str):
    method get_relative_pose (line 104) | def get_relative_pose(self, cam_params):
    method __getitem__ (line 121) | def __getitem__(self, data_id):
    method __len__ (line 163) | def __len__(self):
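
`get_relative_pose` in the inference dataset points at the standard normalization for multi-camera conditioning: re-express every camera-to-world matrix relative to the first camera, `rel_i = inv(c2w_0) @ c2w_i`, so the reference camera becomes the identity. A pure-Python sketch using the rigid-transform inverse `inv([R | t]) = [R^T | -R^T t]` (the exact convention used by the repository is an assumption):

```python
# Hedged sketch of the relative-pose computation suggested by
# get_relative_pose: rel_i = inv(c2w_0) @ c2w_i, with the closed-form
# rigid inverse inv([R | t]) = [R^T | -R^T t]. 4x4 nested-list matrices.

def matmul4(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)]
            for i in range(4)]

def rigid_inverse(m):
    R = [row[:3] for row in m[:3]]
    t = [m[0][3], m[1][3], m[2][3]]
    Rt = [[R[j][i] for j in range(3)] for i in range(3)]              # R^T
    mt = [-sum(Rt[i][j] * t[j] for j in range(3)) for i in range(3)]  # -R^T t
    return [Rt[0] + [mt[0]], Rt[1] + [mt[1]], Rt[2] + [mt[2]],
            [0.0, 0.0, 0.0, 1.0]]

def get_relative_pose(c2ws):
    inv0 = rigid_inverse(c2ws[0])
    return [matmul4(inv0, c2w) for c2w in c2ws]


# Two translated cameras: the first becomes identity, the second keeps
# only its offset relative to the first.
cam0 = [[1.0, 0, 0, 2.0], [0, 1.0, 0, 0], [0, 0, 1.0, 0], [0, 0, 0, 1.0]]
cam1 = [[1.0, 0, 0, 5.0], [0, 1.0, 0, 0], [0, 0, 1.0, 0], [0, 0, 0, 1.0]]
rel = get_relative_pose([cam0, cam1])
```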
  function parse_args (line 166) | def parse_args():

FILE: train_syncammaster.py
  class TextVideoDataset (line 21) | class TextVideoDataset(torch.utils.data.Dataset):
    method __init__ (line 22) | def __init__(self, base_path, metadata_path, max_num_frames=81, frame_...
    method crop_and_resize (line 42) | def crop_and_resize(self, image):
    method load_frames_using_imageio (line 53) | def load_frames_using_imageio(self, file_path, max_num_frames, start_f...
    method load_video (line 80) | def load_video(self, file_path):
    method is_image (line 86) | def is_image(self, file_path):
    method load_image (line 93) | def load_image(self, file_path):
    method __getitem__ (line 102) | def __getitem__(self, data_id):
    method __len__ (line 124) | def __len__(self):
  class LightningModelForDataProcess (line 129) | class LightningModelForDataProcess(pl.LightningModule):
    method __init__ (line 130) | def __init__(self, text_encoder_path, vae_path, image_encoder_path=Non...
    method test_step (line 141) | def test_step(self, batch, batch_idx):
  class Camera (line 165) | class Camera(object):
    method __init__ (line 166) | def __init__(self, c2w):
  class TensorDataset (line 173) | class TensorDataset(torch.utils.data.Dataset):
    method __init__ (line 174) | def __init__(self, base_path, metadata_path, steps_per_epoch):
    method parse_matrix (line 184) | def parse_matrix(self, matrix_str):
    method get_relative_pose (line 193) | def get_relative_pose(self, cam_params):
    method __getitem__ (line 208) | def __getitem__(self, index):
    method __len__ (line 266) | def __len__(self):
  class LightningModelForTrain (line 271) | class LightningModelForTrain(pl.LightningModule):
    method __init__ (line 272) | def __init__(
    method freeze_parameters (line 327) | def freeze_parameters(self):
    method training_step (line 334) | def training_step(self, batch, batch_idx):
    method configure_optimizers (line 373) | def configure_optimizers(self):
    method on_save_checkpoint (line 379) | def on_save_checkpoint(self, checkpoint):
  function parse_args (line 393) | def parse_args():
  function data_process (line 561) | def data_process(args):
  function train (line 594) | def train(args):

FILE: vis_cam.py
  class CameraPoseVisualizer (line 9) | class CameraPoseVisualizer:
    method __init__ (line 10) | def __init__(self, xlim, ylim, zlim):
    method extrinsic2pyramid (line 23) | def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16,...
    method customize_legend (line 41) | def customize_legend(self, list_label):
    method colorbar (line 49) | def colorbar(self, max_frame_length):
    method show (line 54) | def show(self):
  function get_args (line 60) | def get_args():
  function get_c2w (line 74) | def get_c2w(w2cs, transform_matrix, relative_c2w=True):
  function parse_matrix (line 89) | def parse_matrix(matrix_str):
Condensed preview — 222 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (19,142K chars).
[
  {
    "path": ".gitignore",
    "chars": 26,
    "preview": "*__pycache__\n*.ckpt\nWan-AI"
  },
  {
    "path": "README.md",
    "chars": 12225,
    "preview": "# SynCamMaster: Synchronizing Multi-Camera Video Generation from Diverse Viewpoints\n\n<div align=\"center\">\n<div align=\"ce"
  },
  {
    "path": "diffsynth/__init__.py",
    "chars": 145,
    "preview": "from .data import *\nfrom .models import *\nfrom .prompters import *\nfrom .schedulers import *\nfrom .pipelines import *\nfr"
  },
  {
    "path": "diffsynth/configs/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "diffsynth/configs/model_config.py",
    "chars": 50114,
    "preview": "from typing_extensions import Literal, TypeAlias\n\nfrom ..models.sd_text_encoder import SDTextEncoder\nfrom ..models.sd_un"
  },
  {
    "path": "diffsynth/controlnets/__init__.py",
    "chars": 152,
    "preview": "from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager\nfr"
  },
  {
    "path": "diffsynth/controlnets/controlnet_unit.py",
    "chars": 3628,
    "preview": "import torch\nimport numpy as np\nfrom .processors import Processor_id\n\n\nclass ControlNetConfigUnit:\n    def __init__(self"
  },
  {
    "path": "diffsynth/controlnets/processors.py",
    "chars": 2952,
    "preview": "from typing_extensions import Literal, TypeAlias\n\n\nProcessor_id: TypeAlias = Literal[\n    \"canny\", \"depth\", \"softedge\", "
  },
  {
    "path": "diffsynth/data/__init__.py",
    "chars": 54,
    "preview": "from .video import VideoData, save_video, save_frames\n"
  },
  {
    "path": "diffsynth/data/simple_text_image.py",
    "chars": 1775,
    "preview": "import torch, os, torchvision\nfrom torchvision import transforms\nimport pandas as pd\nfrom PIL import Image\n\n\n\nclass Text"
  },
  {
    "path": "diffsynth/data/video.py",
    "chars": 4703,
    "preview": "import imageio, os\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\n\n\nclass LowMemoryVideo:\n    def __init"
  },
  {
    "path": "diffsynth/extensions/ESRGAN/__init__.py",
    "chars": 5096,
    "preview": "import torch\nfrom einops import repeat\nfrom PIL import Image\nimport numpy as np\n\n\nclass ResidualDenseBlock(torch.nn.Modu"
  },
  {
    "path": "diffsynth/extensions/FastBlend/__init__.py",
    "chars": 2823,
    "preview": "from .runners.fast import TableManager, PyramidPatchMatcher\nfrom PIL import Image\nimport numpy as np\nimport cupy as cp\n\n"
  },
  {
    "path": "diffsynth/extensions/FastBlend/api.py",
    "chars": 20207,
    "preview": "from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeS"
  },
  {
    "path": "diffsynth/extensions/FastBlend/cupy_kernels.py",
    "chars": 4430,
    "preview": "import cupy as cp\n\nremapping_kernel = cp.RawKernel(r'''\nextern \"C\" __global__\nvoid remap(\n    const int height,\n    cons"
  },
  {
    "path": "diffsynth/extensions/FastBlend/data.py",
    "chars": 4094,
    "preview": "import imageio, os\nimport numpy as np\nfrom PIL import Image\n\n\ndef read_video(file_name):\n    reader = imageio.get_reader"
  },
  {
    "path": "diffsynth/extensions/FastBlend/patch_match.py",
    "chars": 13837,
    "preview": "from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel\nimport numpy as np\nimport cu"
  },
  {
    "path": "diffsynth/extensions/FastBlend/runners/__init__.py",
    "chars": 202,
    "preview": "from .accurate import AccurateModeRunner\nfrom .fast import FastModeRunner\nfrom .balanced import BalancedModeRunner\nfrom "
  },
  {
    "path": "diffsynth/extensions/FastBlend/runners/accurate.py",
    "chars": 1531,
    "preview": "from ..patch_match import PyramidPatchMatcher\nimport os\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\n\n"
  },
  {
    "path": "diffsynth/extensions/FastBlend/runners/balanced.py",
    "chars": 2142,
    "preview": "from ..patch_match import PyramidPatchMatcher\nimport os\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\n\n"
  },
  {
    "path": "diffsynth/extensions/FastBlend/runners/fast.py",
    "chars": 6648,
    "preview": "from ..patch_match import PyramidPatchMatcher\nimport functools, os\nimport numpy as np\nfrom PIL import Image\nfrom tqdm im"
  },
  {
    "path": "diffsynth/extensions/FastBlend/runners/interpolation.py",
    "chars": 5286,
    "preview": "from ..patch_match import PyramidPatchMatcher\nimport os\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\n\n"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py",
    "chars": 29,
    "preview": "from .blip_pretrain import *\n"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/BLIP/blip.py",
    "chars": 3425,
    "preview": "'''\n * Adapted from BLIP (https://github.com/salesforce/BLIP)\n'''\n\nimport warnings\nwarnings.filterwarnings(\"ignore\")\n\nim"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py",
    "chars": 1555,
    "preview": "'''\n * Adapted from BLIP (https://github.com/salesforce/BLIP)\n'''\n\nimport transformers\ntransformers.logging.set_verbosit"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/BLIP/med.py",
    "chars": 41498,
    "preview": "'''\n * Adapted from BLIP (https://github.com/salesforce/BLIP)\n * Based on huggingface code base\n * https://github.com/hu"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/BLIP/vit.py",
    "chars": 14069,
    "preview": "'''\n * Adapted from BLIP (https://github.com/salesforce/BLIP)\n * Based on timm code base\n * https://github.com/rwightman"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/__init__.py",
    "chars": 6126,
    "preview": "from modelscope import snapshot_download\nfrom typing_extensions import Literal, TypeAlias\nimport os\nfrom diffsynth.exten"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/aesthetic.py",
    "chars": 5452,
    "preview": "from typing import List, Optional\nfrom PIL import Image\nimport torch\nfrom transformers import AutoProcessor, AutoModel\nf"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/clip.py",
    "chars": 3849,
    "preview": "from typing import List, Union\nfrom PIL import Image\nimport torch\nfrom .open_clip import create_model_and_transforms, ge"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/config.py",
    "chars": 1060,
    "preview": "import os\n\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nproject_root = os.path.abspath(os.path.join(current_"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/hps.py",
    "chars": 4777,
    "preview": "from typing import List, Union\nfrom PIL import Image\nimport torch\nfrom .open_clip import create_model_and_transforms, ge"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/imagereward.py",
    "chars": 8729,
    "preview": "import os\nimport torch\nfrom PIL import Image\nfrom typing import List, Union\nfrom torchvision.transforms import Compose, "
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/mps.py",
    "chars": 5952,
    "preview": "import numpy as np\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\nfrom tqdm.auto import tqdm\nfrom transformer"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py",
    "chars": 985,
    "preview": "from .coca_model import CoCa\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\nfrom .factory import create_"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py",
    "chars": 17439,
    "preview": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nf"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/constants.py",
    "chars": 116,
    "preview": "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\nOPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\n"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/factory.py",
    "chars": 15868,
    "preview": "import json\nimport logging\nimport os\nimport pathlib\nimport re\nfrom copy import deepcopy\nfrom pathlib import Path\n# from "
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py",
    "chars": 1675,
    "preview": "# HF architecture dict:\narch_dict = {\n    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n    \"robe"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py",
    "chars": 6298,
    "preview": "\"\"\" huggingface model adapter\n\nWraps HuggingFace transformers (https://github.com/huggingface/transformers) models for u"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/loss.py",
    "chars": 10711,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torch.nn.utils.rnn import pad_sequence\n\ntry"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/model.py",
    "chars": 18298,
    "preview": "\"\"\" CLIP Model\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nfrom"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json",
    "chars": 324,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py",
    "chars": 7026,
    "preview": "from collections import OrderedDict\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .utils"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/openai.py",
    "chars": 5446,
    "preview": "\"\"\" OpenAI pretrained model functions\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py",
    "chars": 14144,
    "preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom functools import partial\nfrom typing import Dict, Union\n\nfro"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py",
    "chars": 7660,
    "preview": "import argparse\nimport json\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\nfrom typing import Optional"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py",
    "chars": 5077,
    "preview": "\"\"\" timm model adapter\n\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower "
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py",
    "chars": 7541,
    "preview": "\"\"\" CLIP tokenizer\n\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\ni"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/transform.py",
    "chars": 7874,
    "preview": "import warnings\nfrom dataclasses import dataclass, asdict\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py",
    "chars": 27164,
    "preview": "from collections import OrderedDict\nimport math\nfrom typing import Callable, Optional, Sequence, Tuple\n\nimport torch\nfro"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/utils.py",
    "chars": 2223,
    "preview": "from itertools import repeat\nimport collections.abc\n\nfrom torch import nn as nn\nfrom torchvision.ops.misc import FrozenB"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/open_clip/version.py",
    "chars": 23,
    "preview": "__version__ = '2.16.0'\n"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/pickscore.py",
    "chars": 4513,
    "preview": "import torch\nfrom PIL import Image\nfrom transformers import AutoProcessor, AutoModel\nfrom typing import List, Union\nimpo"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/trainer/__init__.py",
    "chars": 21,
    "preview": "from .models import *"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py",
    "chars": 81,
    "preview": "from .base_model import *\nfrom .clip_model import *\nfrom .cross_modeling import *"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py",
    "chars": 80,
    "preview": "from dataclasses import dataclass\n\n\n\n@dataclass\nclass BaseModelConfig:\n    pass\n"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py",
    "chars": 5453,
    "preview": "from dataclasses import dataclass\nfrom transformers import CLIPModel as HFCLIPModel\nfrom transformers import AutoTokeniz"
  },
  {
    "path": "diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py",
    "chars": 7956,
    "preview": "import torch\nfrom torch import einsum, nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\n# helper"
  },
  {
    "path": "diffsynth/extensions/RIFE/__init__.py",
    "chars": 10736,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\n\n\ndef warp(t"
  },
  {
    "path": "diffsynth/extensions/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "diffsynth/models/__init__.py",
    "chars": 29,
    "preview": "from .model_manager import *\n"
  },
  {
    "path": "diffsynth/models/attention.py",
    "chars": 3957,
    "preview": "import torch\nfrom einops import rearrange\n\n\ndef low_version_attention(query, key, value, attn_bias=None):\n    scale = 1 "
  },
  {
    "path": "diffsynth/models/cog_dit.py",
    "chars": 17475,
    "preview": "import torch\nfrom einops import rearrange, repeat\nfrom .sd3_dit import TimestepEmbeddings\nfrom .attention import Attenti"
  },
  {
    "path": "diffsynth/models/cog_vae.py",
    "chars": 23126,
    "preview": "import torch\nfrom einops import rearrange, repeat\nfrom .tiler import TileWorker2Dto3D\n\n\n\nclass Downsample3D(torch.nn.Mod"
  },
  {
    "path": "diffsynth/models/downloader.py",
    "chars": 4763,
    "preview": "from huggingface_hub import hf_hub_download\nfrom modelscope import snapshot_download\nimport os, shutil\nfrom typing_exten"
  },
  {
    "path": "diffsynth/models/flux_controlnet.py",
    "chars": 16429,
    "preview": "import torch\nfrom einops import rearrange, repeat\nfrom .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTran"
  },
  {
    "path": "diffsynth/models/flux_dit.py",
    "chars": 33956,
    "preview": "import torch\nfrom .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm\nfrom einops import rearrange\nfrom .tiler imp"
  },
  {
    "path": "diffsynth/models/flux_ipadapter.py",
    "chars": 3697,
    "preview": "from .svd_image_encoder import SVDImageEncoder\nfrom .sd3_dit import RMSNorm\nfrom transformers import CLIPImageProcessor\n"
  },
  {
    "path": "diffsynth/models/flux_text_encoder.py",
    "chars": 775,
    "preview": "import torch\nfrom transformers import T5EncoderModel, T5Config\nfrom .sd_text_encoder import SDTextEncoder\n\n\n\nclass FluxT"
  },
  {
    "path": "diffsynth/models/flux_vae.py",
    "chars": 19658,
    "preview": "from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter\nfrom .sd3_vae_decoder import SD3VAEDecoder, S"
  },
  {
    "path": "diffsynth/models/hunyuan_dit.py",
    "chars": 20154,
    "preview": "from .attention import Attention\nfrom einops import repeat, rearrange\nimport math\nimport torch\n\n\nclass HunyuanDiTRotaryE"
  },
  {
    "path": "diffsynth/models/hunyuan_dit_text_encoder.py",
    "chars": 5565,
    "preview": "from transformers import BertModel, BertConfig, T5EncoderModel, T5Config\nimport torch\n\n\n\nclass HunyuanDiTCLIPTextEncoder"
  },
  {
    "path": "diffsynth/models/hunyuan_video_dit.py",
    "chars": 42030,
    "preview": "import torch\nfrom .sd3_dit import TimestepEmbeddings, RMSNorm\nfrom .utils import init_weights_on_device\nfrom einops impo"
  },
  {
    "path": "diffsynth/models/hunyuan_video_text_encoder.py",
    "chars": 2715,
    "preview": "from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration\nfrom copy import deepcopy\n"
  },
  {
    "path": "diffsynth/models/hunyuan_video_vae_decoder.py",
    "chars": 19320,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nimport numpy as np\nfrom "
  },
  {
    "path": "diffsynth/models/hunyuan_video_vae_encoder.py",
    "chars": 11066,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nimport numpy as "
  },
  {
    "path": "diffsynth/models/kolors_text_encoder.py",
    "chars": 72734,
    "preview": "\"\"\"\nThis model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.\nWe didn't modify this mod"
  },
  {
    "path": "diffsynth/models/lora.py",
    "chars": 18187,
    "preview": "import torch\nfrom .sd_unet import SDUNet\nfrom .sdxl_unet import SDXLUNet\nfrom .sd_text_encoder import SDTextEncoder\nfrom"
  },
  {
    "path": "diffsynth/models/model_manager.py",
    "chars": 20790,
    "preview": "import os, torch, json, importlib\nfrom typing import List\n\nfrom .downloader import download_models, download_customized_"
  },
  {
    "path": "diffsynth/models/omnigen.py",
    "chars": 33069,
    "preview": "# The code is revised from DiT\nimport os\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport math\nfrom safetens"
  },
  {
    "path": "diffsynth/models/sd3_dit.py",
    "chars": 27845,
    "preview": "import torch\nfrom einops import rearrange\nfrom .svd_unet import TemporalTimesteps\nfrom .tiler import TileWorker\n\n\n\nclass"
  },
  {
    "path": "diffsynth/models/sd3_text_encoder.py",
    "chars": 110544,
    "preview": "import torch\nfrom transformers import T5EncoderModel, T5Config\nfrom .sd_text_encoder import SDTextEncoder\nfrom .sdxl_tex"
  },
  {
    "path": "diffsynth/models/sd3_vae_decoder.py",
    "chars": 3006,
    "preview": "import torch\nfrom .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter\nfrom .sd_unet import ResnetBl"
  },
  {
    "path": "diffsynth/models/sd3_vae_encoder.py",
    "chars": 3599,
    "preview": "import torch\nfrom .sd_unet import ResnetBlock, DownSampler\nfrom .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderSt"
  },
  {
    "path": "diffsynth/models/sd_controlnet.py",
    "chars": 48266,
    "preview": "import torch\nfrom .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler\nfrom .tiler import Tile"
  },
  {
    "path": "diffsynth/models/sd_ipadapter.py",
    "chars": 2362,
    "preview": "from .svd_image_encoder import SVDImageEncoder\nfrom .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDX"
  },
  {
    "path": "diffsynth/models/sd_motion.py",
    "chars": 8323,
    "preview": "from .sd_unet import SDUNet, Attention, GEGLU\nimport torch\nfrom einops import rearrange, repeat\n\n\nclass TemporalTransfor"
  },
  {
    "path": "diffsynth/models/sd_text_encoder.py",
    "chars": 28750,
    "preview": "import torch\nfrom .attention import Attention\n\n\nclass CLIPEncoderLayer(torch.nn.Module):\n    def __init__(self, embed_di"
  },
  {
    "path": "diffsynth/models/sd_unet.py",
    "chars": 99140,
    "preview": "import torch, math\nfrom .attention import Attention\nfrom .tiler import TileWorker\n\n\nclass Timesteps(torch.nn.Module):\n  "
  },
  {
    "path": "diffsynth/models/sd_vae_decoder.py",
    "chars": 20845,
    "preview": "import torch\nfrom .attention import Attention\nfrom .sd_unet import ResnetBlock, UpSampler\nfrom .tiler import TileWorker\n"
  },
  {
    "path": "diffsynth/models/sd_vae_encoder.py",
    "chars": 17272,
    "preview": "import torch\nfrom .sd_unet import ResnetBlock, DownSampler\nfrom .sd_vae_decoder import VAEAttentionBlock\nfrom .tiler imp"
  },
  {
    "path": "diffsynth/models/sdxl_controlnet.py",
    "chars": 13942,
    "preview": "import torch\nfrom .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler\nfrom .sdxl_unet import "
  },
  {
    "path": "diffsynth/models/sdxl_ipadapter.py",
    "chars": 4918,
    "preview": "from .svd_image_encoder import SVDImageEncoder\nfrom transformers import CLIPImageProcessor\nimport torch\n\n\nclass IpAdapte"
  },
  {
    "path": "diffsynth/models/sdxl_motion.py",
    "chars": 4209,
    "preview": "from .sd_motion import TemporalBlock\nimport torch\n\n\n\nclass SDXLMotionModel(torch.nn.Module):\n    def __init__(self):\n   "
  },
  {
    "path": "diffsynth/models/sdxl_text_encoder.py",
    "chars": 79460,
    "preview": "import torch\nfrom .sd_text_encoder import CLIPEncoderLayer\n    \n\nclass SDXLTextEncoder(torch.nn.Module):\n    def __init_"
  },
  {
    "path": "diffsynth/models/sdxl_unet.py",
    "chars": 234786,
    "preview": "import torch\nfrom .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler\n\n\n"
  },
  {
    "path": "diffsynth/models/sdxl_vae_decoder.py",
    "chars": 762,
    "preview": "from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter\n\n\nclass SDXLVAEDecoder(SDVAEDecoder):\n    def _"
  },
  {
    "path": "diffsynth/models/sdxl_vae_encoder.py",
    "chars": 762,
    "preview": "from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder\n\n\nclass SDXLVAEEncoder(SDVAEEncoder):\n    def _"
  },
  {
    "path": "diffsynth/models/stepvideo_dit.py",
    "chars": 33689,
    "preview": "# Copyright 2025 StepFun Inc. All Rights Reserved.\n# \n# Permission is hereby granted, free of charge, to any person obta"
  },
  {
    "path": "diffsynth/models/stepvideo_text_encoder.py",
    "chars": 19326,
    "preview": "# Copyright 2025 StepFun Inc. All Rights Reserved.\n# \n# Permission is hereby granted, free of charge, to any person obta"
  },
  {
    "path": "diffsynth/models/stepvideo_vae.py",
    "chars": 42397,
    "preview": "# Copyright 2025 StepFun Inc. All Rights Reserved.\n# \n# Permission is hereby granted, free of charge, to any person obta"
  },
  {
    "path": "diffsynth/models/svd_image_encoder.py",
    "chars": 60869,
    "preview": "import torch\nfrom .sd_text_encoder import CLIPEncoderLayer\n\n\nclass CLIPVisionEmbeddings(torch.nn.Module):\n    def __init"
  },
  {
    "path": "diffsynth/models/svd_unet.py",
    "chars": 200698,
    "preview": "import torch, math\nfrom einops import rearrange, repeat\nfrom .sd_unet import Timesteps, PushBlock, PopBlock, Attention, "
  },
  {
    "path": "diffsynth/models/svd_vae_decoder.py",
    "chars": 40309,
    "preview": "import torch\nfrom .attention import Attention\nfrom .sd_unet import ResnetBlock, UpSampler\nfrom .tiler import TileWorker\n"
  },
  {
    "path": "diffsynth/models/svd_vae_encoder.py",
    "chars": 12377,
    "preview": "from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder\n\n\nclass SVDVAEEncoder(SDVAEEncoder):\n    def __"
  },
  {
    "path": "diffsynth/models/tiler.py",
    "chars": 9732,
    "preview": "import torch\nfrom einops import rearrange, repeat\n\n\nclass TileWorker:\n    def __init__(self):\n        pass\n\n\n    def mas"
  },
  {
    "path": "diffsynth/models/utils.py",
    "chars": 7093,
    "preview": "import torch, os\nfrom safetensors import safe_open\nfrom contextlib import contextmanager\nimport hashlib\n\n@contextmanager"
  },
  {
    "path": "diffsynth/models/wan_video_dit.py",
    "chars": 21173,
    "preview": "import re\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Tuple, Optio"
  },
  {
    "path": "diffsynth/models/wan_video_image_encoder.py",
    "chars": 28139,
    "preview": "\"\"\"\nConcise re-implementation of\n``https://github.com/openai/CLIP'' and\n``https://github.com/mlfoundations/open_clip''.\n"
  },
  {
    "path": "diffsynth/models/wan_video_text_encoder.py",
    "chars": 9131,
    "preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef fp16_clamp(x):\n    if x.dtype == t"
  },
  {
    "path": "diffsynth/models/wan_video_vae.py",
    "chars": 29483,
    "preview": "from einops import rearrange, repeat\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom tqdm impor"
  },
  {
    "path": "diffsynth/pipelines/__init__.py",
    "chars": 707,
    "preview": "from .sd_image import SDImagePipeline\nfrom .sd_video import SDVideoPipeline\nfrom .sdxl_image import SDXLImagePipeline\nfr"
  },
  {
    "path": "diffsynth/pipelines/base.py",
    "chars": 5770,
    "preview": "import torch\nimport numpy as np\nfrom PIL import Image\nfrom torchvision.transforms import GaussianBlur\n\n\n\nclass BasePipel"
  },
  {
    "path": "diffsynth/pipelines/cog_video.py",
    "chars": 5313,
    "preview": "from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder\nfrom ..prompters import CogPro"
  },
  {
    "path": "diffsynth/pipelines/dancer.py",
    "chars": 10179,
    "preview": "import torch\nfrom ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel\nfrom ..models.sd_unet import PushBloc"
  },
  {
    "path": "diffsynth/pipelines/flux_image.py",
    "chars": 29090,
    "preview": "from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAda"
  },
  {
    "path": "diffsynth/pipelines/hunyuan_image.py",
    "chars": 11919,
    "preview": "from ..models.hunyuan_dit import HunyuanDiT\nfrom ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hun"
  },
  {
    "path": "diffsynth/pipelines/hunyuan_video.py",
    "chars": 17733,
    "preview": "from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder\nfrom ..models.hunyuan"
  },
  {
    "path": "diffsynth/pipelines/omnigen_image.py",
    "chars": 13049,
    "preview": "from ..models.omnigen import OmniGenTransformer\nfrom ..models.sdxl_vae_encoder import SDXLVAEEncoder\nfrom ..models.sdxl_"
  },
  {
    "path": "diffsynth/pipelines/pipeline_runner.py",
    "chars": 5185,
    "preview": "import os, torch, json\nfrom .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit\nfrom ..processors.seque"
  },
  {
    "path": "diffsynth/pipelines/sd3_image.py",
    "chars": 6291,
    "preview": "from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncod"
  },
  {
    "path": "diffsynth/pipelines/sd_image.py",
    "chars": 8501,
    "preview": "from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder\nfrom ..m"
  },
  {
    "path": "diffsynth/pipelines/sd_video.py",
    "chars": 12000,
    "preview": "from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotio"
  },
  {
    "path": "diffsynth/pipelines/sdxl_image.py",
    "chars": 10335,
    "preview": "from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapt"
  },
  {
    "path": "diffsynth/pipelines/sdxl_video.py",
    "chars": 10534,
    "preview": "from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapt"
  },
  {
    "path": "diffsynth/pipelines/step_video.py",
    "chars": 8494,
    "preview": "from ..models import ModelManager\nfrom ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder\nfrom ..models."
  },
  {
    "path": "diffsynth/pipelines/svd_video.py",
    "chars": 11418,
    "preview": "from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder\nfrom ..schedulers import Conti"
  },
  {
    "path": "diffsynth/pipelines/wan_video.py",
    "chars": 16001,
    "preview": "from ..models import ModelManager\nfrom ..models.wan_video_dit import WanModel\nfrom ..models.wan_video_text_encoder impor"
  },
  {
    "path": "diffsynth/pipelines/wan_video_syncammaster.py",
    "chars": 16924,
    "preview": "from ..models import ModelManager\nfrom ..models.wan_video_dit import WanModel\nfrom ..models.wan_video_text_encoder impor"
  },
  {
    "path": "diffsynth/processors/FastBlend.py",
    "chars": 7027,
    "preview": "from PIL import Image\nimport cupy as cp\nimport numpy as np\nfrom tqdm import tqdm\nfrom ..extensions.FastBlend.patch_match"
  },
  {
    "path": "diffsynth/processors/PILEditor.py",
    "chars": 855,
    "preview": "from PIL import ImageEnhance\nfrom .base import VideoProcessor\n\n\nclass ContrastEditor(VideoProcessor):\n    def __init__(s"
  },
  {
    "path": "diffsynth/processors/RIFE.py",
    "chars": 3118,
    "preview": "import torch\nimport numpy as np\nfrom PIL import Image\nfrom .base import VideoProcessor\n\n\nclass RIFESmoother(VideoProcess"
  },
  {
    "path": "diffsynth/processors/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "diffsynth/processors/base.py",
    "chars": 118,
    "preview": "class VideoProcessor:\n    def __init__(self):\n        pass\n\n    def __call__(self):\n        raise NotImplementedError\n"
  },
  {
    "path": "diffsynth/processors/sequencial_processor.py",
    "chars": 1582,
    "preview": "from .base import VideoProcessor\n\n\nclass AutoVideoProcessor(VideoProcessor):\n    def __init__(self):\n        pass\n\n    @"
  },
  {
    "path": "diffsynth/prompters/__init__.py",
    "chars": 535,
    "preview": "from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt\nfrom .sd_prompter import SDPrompter\nfrom .sdxl_prom"
  },
  {
    "path": "diffsynth/prompters/base_prompter.py",
    "chars": 2234,
    "preview": "from ..models.model_manager import ModelManager\nimport torch\n\n\n\ndef tokenize_long_prompt(tokenizer, prompt, max_length=N"
  },
  {
    "path": "diffsynth/prompters/cog_prompter.py",
    "chars": 1474,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.flux_text_encoder import FluxTextEncoder2\nfrom transformers import"
  },
  {
    "path": "diffsynth/prompters/flux_prompter.py",
    "chars": 2703,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.flux_text_encoder import FluxTextEncoder2\nfrom ..models.sd3_text_e"
  },
  {
    "path": "diffsynth/prompters/hunyuan_dit_prompter.py",
    "chars": 2743,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.model_manager import ModelManager\nfrom ..models import HunyuanDiTC"
  },
  {
    "path": "diffsynth/prompters/hunyuan_video_prompter.py",
    "chars": 13574,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.sd3_text_encoder import SD3TextEncoder1\nfrom ..models.hunyuan_vide"
  },
  {
    "path": "diffsynth/prompters/kolors_prompter.py",
    "chars": 14110,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.model_manager import ModelManager\nimport json, os, re\nfrom typing "
  },
  {
    "path": "diffsynth/prompters/omnigen_prompter.py",
    "chars": 15398,
    "preview": "import os\nimport re\nfrom typing import Dict, List\n\nimport torch\nfrom PIL import Image\nfrom torchvision import transforms"
  },
  {
    "path": "diffsynth/prompters/omost.py",
    "chars": 15938,
    "preview": "from transformers import AutoTokenizer, TextIteratorStreamer\nimport difflib\nimport torch\nimport numpy as np\nimport re\nfr"
  },
  {
    "path": "diffsynth/prompters/prompt_refiners.py",
    "chars": 6568,
    "preview": "from transformers import AutoTokenizer\nfrom ..models.model_manager import ModelManager\nimport torch\nfrom .omost import O"
  },
  {
    "path": "diffsynth/prompters/sd3_prompter.py",
    "chars": 3901,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.model_manager import ModelManager\nfrom ..models import SD3TextEnco"
  },
  {
    "path": "diffsynth/prompters/sd_prompter.py",
    "chars": 3461,
    "preview": "from .base_prompter import BasePrompter, tokenize_long_prompt\nfrom ..models.utils import load_state_dict, search_for_emb"
  },
  {
    "path": "diffsynth/prompters/sdxl_prompter.py",
    "chars": 2451,
    "preview": "from .base_prompter import BasePrompter, tokenize_long_prompt\nfrom ..models.model_manager import ModelManager\nfrom ..mod"
  },
  {
    "path": "diffsynth/prompters/stepvideo_prompter.py",
    "chars": 2086,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder\nfrom ..m"
  },
  {
    "path": "diffsynth/prompters/wan_prompter.py",
    "chars": 3477,
    "preview": "from .base_prompter import BasePrompter\nfrom ..models.wan_video_text_encoder import WanTextEncoder\nfrom transformers imp"
  },
  {
    "path": "diffsynth/schedulers/__init__.py",
    "chars": 134,
    "preview": "from .ddim import EnhancedDDIMScheduler\nfrom .continuous_ode import ContinuousODEScheduler\nfrom .flow_match import FlowM"
  },
  {
    "path": "diffsynth/schedulers/continuous_ode.py",
    "chars": 2462,
    "preview": "import torch\n\n\nclass ContinuousODEScheduler():\n\n    def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_m"
  },
  {
    "path": "diffsynth/schedulers/ddim.py",
    "chars": 5087,
    "preview": "import torch, math\n\n\nclass EnhancedDDIMScheduler():\n\n    def __init__(self, num_train_timesteps=1000, beta_start=0.00085"
  },
  {
    "path": "diffsynth/schedulers/flow_match.py",
    "chars": 3347,
    "preview": "import torch\n\n\n\nclass FlowMatchScheduler():\n\n    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, s"
  },
  {
    "path": "diffsynth/tokenizer_configs/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json",
    "chars": 2593,
    "preview": "{\n  \"<extra_id_0>\": 32099,\n  \"<extra_id_10>\": 32089,\n  \"<extra_id_11>\": 32088,\n  \"<extra_id_12>\": 32087,\n  \"<extra_id_13"
  },
  {
    "path": "diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json",
    "chars": 2543,
    "preview": "{\n  \"additional_special_tokens\": [\n    \"<extra_id_0>\",\n    \"<extra_id_1>\",\n    \"<extra_id_2>\",\n    \"<extra_id_3>\",\n    \""
  },
  {
    "path": "diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json",
    "chars": 20617,
    "preview": "{\n  \"add_prefix_space\": true,\n  \"added_tokens_decoder\": {\n    \"0\": {\n      \"content\": \"<pad>\",\n      \"lstrip\": false,\n  "
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt",
    "chars": 515308,
    "preview": "#version: 0.2\ni n\nt h\na n\nr e\na r\ne r\nth e</w>\nin g</w>\no u\no n\ns t\no r\ne n\no n</w>\na l\na t\ne r</w>\ni t\ni n</w>\nt o</w>\n"
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_1/special_tokens_map.json",
    "chars": 588,
    "preview": "{\n  \"bos_token\": {\n    \"content\": \"<|startoftext|>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n "
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_1/tokenizer_config.json",
    "chars": 705,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"added_tokens_decoder\": {\n    \"49406\": {\n      \"content\": \"<|startoftext|>\",\n      \"lst"
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json",
    "chars": 1050327,
    "preview": "{\n  \"!\": 0,\n  \"!!\": 1443,\n  \"!!!\": 11194,\n  \"!!!!\": 4003,\n  \"!!!!!!!!\": 11281,\n  \"!!!!!!!!!!!!!!!!\": 30146,\n  \"!!!!!!!!!"
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_2/special_tokens_map.json",
    "chars": 2543,
    "preview": "{\n  \"additional_special_tokens\": [\n    \"<extra_id_0>\",\n    \"<extra_id_1>\",\n    \"<extra_id_2>\",\n    \"<extra_id_3>\",\n    \""
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json",
    "chars": 2377466,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": null,\n  \"padding\": null,\n  \"added_tokens\": [\n    {\n      \"id\": 0,\n      \"content\":"
  },
  {
    "path": "diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer_config.json",
    "chars": 20817,
    "preview": "{\n  \"add_prefix_space\": true,\n  \"added_tokens_decoder\": {\n    \"0\": {\n      \"content\": \"<pad>\",\n      \"lstrip\": false,\n  "
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json",
    "chars": 125,
    "preview": "{\n  \"cls_token\": \"[CLS]\",\n  \"mask_token\": \"[MASK]\",\n  \"pad_token\": \"[PAD]\",\n  \"sep_token\": \"[SEP]\",\n  \"unk_token\": \"[UNK"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json",
    "chars": 559,
    "preview": "{\n  \"cls_token\": \"[CLS]\",\n  \"do_basic_tokenize\": true,\n  \"do_lower_case\": true,\n  \"mask_token\": \"[MASK]\",\n  \"name_or_pat"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt",
    "chars": 281526,
    "preview": "[PAD]\n[unused1]\n[unused2]\n[unused3]\n[unused4]\n[unused5]\n[unused6]\n[unused7]\n[unused8]\n[unused9]\n[unused10]\n[unused11]\n[u"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt",
    "chars": 75770,
    "preview": "[PAD]\n[unused1]\n[unused2]\n[unused3]\n[unused4]\n[unused5]\n[unused6]\n[unused7]\n[unused8]\n[unused9]\n[unused10]\n[unused11]\n[u"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json",
    "chars": 688,
    "preview": "{\n  \"_name_or_path\": \"/home/patrick/t5/mt5-xl\",\n  \"architectures\": [\n    \"MT5ForConditionalGeneration\"\n  ],\n  \"d_ff\": 51"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json",
    "chars": 65,
    "preview": "{\"eos_token\": \"</s>\", \"unk_token\": \"<unk>\", \"pad_token\": \"<pad>\"}"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json",
    "chars": 248,
    "preview": "{\"eos_token\": \"</s>\", \"unk_token\": \"<unk>\", \"pad_token\": \"<pad>\", \"extra_ids\": 0, \"additional_special_tokens\": null, \"sp"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt",
    "chars": 515308,
    "preview": "#version: 0.2\ni n\nt h\na n\nr e\na r\ne r\nth e</w>\nin g</w>\no u\no n\ns t\no r\ne n\no n</w>\na l\na t\ne r</w>\ni t\ni n</w>\nt o</w>\n"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/special_tokens_map.json",
    "chars": 588,
    "preview": "{\n  \"bos_token\": {\n    \"content\": \"<|startoftext|>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n "
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/tokenizer_config.json",
    "chars": 705,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"added_tokens_decoder\": {\n    \"49406\": {\n      \"content\": \"<|startoftext|>\",\n      \"lst"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json",
    "chars": 1050327,
    "preview": "{\n  \"!\": 0,\n  \"!!\": 1443,\n  \"!!!\": 11194,\n  \"!!!!\": 4003,\n  \"!!!!!!!!\": 11281,\n  \"!!!!!!!!!!!!!!!!\": 30146,\n  \"!!!!!!!!!"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json",
    "chars": 819,
    "preview": "{\n  \"_valid_processor_keys\": [\n    \"images\",\n    \"do_resize\",\n    \"size\",\n    \"resample\",\n    \"do_center_crop\",\n    \"cro"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/special_tokens_map.json",
    "chars": 577,
    "preview": "{\n  \"bos_token\": {\n    \"content\": \"<|begin_of_text|>\",\n    \"lstrip\": false,\n    \"normalized\": false,\n    \"rstrip\": false"
  },
  {
    "path": "diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer_config.json",
    "chars": 51699,
    "preview": "{\n  \"add_bos_token\": true,\n  \"add_eos_token\": false,\n  \"add_prefix_space\": null,\n  \"added_tokens_decoder\": {\n    \"128000"
  },
  {
    "path": "diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json",
    "chars": 249,
    "preview": "{\n  \"name_or_path\": \"THUDM/chatglm3-6b-base\",\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"C"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt",
    "chars": 515308,
    "preview": "#version: 0.2\ni n\nt h\na n\nr e\na r\ne r\nth e</w>\nin g</w>\no u\no n\ns t\no r\ne n\no n</w>\na l\na t\ne r</w>\ni t\ni n</w>\nt o</w>\n"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json",
    "chars": 472,
    "preview": "{\n  \"bos_token\": {\n    \"content\": \"<|startoftext|>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n "
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json",
    "chars": 806,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"bos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<|startoftext|>\",\n    \"lstri"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json",
    "chars": 1050327,
    "preview": "{\n  \"!\": 0,\n  \"!!\": 1443,\n  \"!!!\": 11194,\n  \"!!!!\": 4003,\n  \"!!!!!!!!\": 11281,\n  \"!!!!!!!!!!!!!!!!\": 30146,\n  \"!!!!!!!!!"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt",
    "chars": 515308,
    "preview": "#version: 0.2\ni n\nt h\na n\nr e\na r\ne r\nth e</w>\nin g</w>\no u\no n\ns t\no r\ne n\no n</w>\na l\na t\ne r</w>\ni t\ni n</w>\nt o</w>\n"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json",
    "chars": 588,
    "preview": "{\n  \"bos_token\": {\n    \"content\": \"<|startoftext|>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n "
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json",
    "chars": 705,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"added_tokens_decoder\": {\n    \"49406\": {\n      \"content\": \"<|startoftext|>\",\n      \"lst"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json",
    "chars": 1050327,
    "preview": "{\n  \"!\": 0,\n  \"!!\": 1443,\n  \"!!!\": 11194,\n  \"!!!!\": 4003,\n  \"!!!!!!!!\": 11281,\n  \"!!!!!!!!!!!!!!!!\": 30146,\n  \"!!!!!!!!!"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt",
    "chars": 515308,
    "preview": "#version: 0.2\ni n\nt h\na n\nr e\na r\ne r\nth e</w>\nin g</w>\no u\no n\ns t\no r\ne n\no n</w>\na l\na t\ne r</w>\ni t\ni n</w>\nt o</w>\n"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json",
    "chars": 576,
    "preview": "{\n  \"bos_token\": {\n    \"content\": \"<|startoftext|>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n "
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json",
    "chars": 856,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"added_tokens_decoder\": {\n    \"0\": {\n      \"content\": \"!\",\n      \"lstrip\": false,\n     "
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json",
    "chars": 1050327,
    "preview": "{\n  \"!\": 0,\n  \"!!\": 1443,\n  \"!!!\": 11194,\n  \"!!!!\": 4003,\n  \"!!!!!!!!\": 11281,\n  \"!!!!!!!!!!!!!!!!\": 30146,\n  \"!!!!!!!!!"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json",
    "chars": 2543,
    "preview": "{\n  \"additional_special_tokens\": [\n    \"<extra_id_0>\",\n    \"<extra_id_1>\",\n    \"<extra_id_2>\",\n    \"<extra_id_3>\",\n    \""
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json",
    "chars": 2377266,
    "preview": "{\n  \"version\": \"1.0\",\n  \"truncation\": null,\n  \"padding\": null,\n  \"added_tokens\": [\n    {\n      \"id\": 0,\n      \"content\":"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json",
    "chars": 20617,
    "preview": "{\n  \"add_prefix_space\": true,\n  \"added_tokens_decoder\": {\n    \"0\": {\n      \"content\": \"<pad>\",\n      \"lstrip\": false,\n  "
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt",
    "chars": 418794,
    "preview": "#version: 0.2\ni n\nt h\na n\nr e\na r\ne r\nth e</w>\nin g</w>\no u\no n\ns t\no r\ne n\no n</w>\na l\na t\ne r</w>\ni t\ni n</w>\nt o</w>\n"
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json",
    "chars": 503,
    "preview": "{\n    \"bos_token\": {\n      \"content\": \"<|startoftext|>\",\n      \"lstrip\": false,\n      \"normalized\": true,\n      \"rstrip\""
  },
  {
    "path": "diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json",
    "chars": 854,
    "preview": "{\n  \"add_prefix_space\": false,\n  \"added_tokens_decoder\": {\n    \"0\": {\n      \"content\": \"!\",\n      \"lstrip\": false,\n     "
  }
]
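The listing above is a JSON array of objects with `path`, `chars`, and `preview` keys. As a minimal sketch (the two-entry manifest below is a made-up sample in the same schema, not the full extraction), here is how such a manifest could be parsed to total the character counts and group files by tokenizer directory:

```python
import json

# Sample manifest in the same shape as the extraction above:
# a JSON array of {"path", "chars", "preview"} objects (entries invented here).
manifest_json = """
[
  {"path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json",
   "chars": 1050327, "preview": "{ \\"!\\": 0, ..."},
  {"path": "diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt",
   "chars": 515308, "preview": "#version: 0.2 ..."}
]
"""

entries = json.loads(manifest_json)

# Total character count across the listed files
total_chars = sum(e["chars"] for e in entries)

# Group file paths by their tokenizer directory (first four path segments)
by_dir = {}
for e in entries:
    key = "/".join(e["path"].split("/")[:4])
    by_dir.setdefault(key, []).append(e["path"])

print(total_chars)   # 1565635
print(sorted(by_dir))
```

The same pattern extends to the full 222-file manifest: summing `chars` gives the extraction size, and grouping by path prefix recovers the directory structure shown at the top of the page.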

// ... and 22 more files (download for full content)

About this extraction

This page contains the full source code of the KwaiVGI/SynCamMaster GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 222 files (32.7 MB), approximately 4.3M tokens, and a symbol index with 2289 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
