Repository: ArmastusChen/inverse_painting
Branch: main
Commit: 3a27592d5eab
Files: 125
Total size: 9.2 MB

Directory structure:
gitextract_kiog4jd6/

├── README.md
├── configs/
│   ├── inference/
│   │   └── inference.yaml
│   └── train/
│       ├── train_mask_gen.yaml
│       └── train_renderer.yaml
├── data/
│   ├── sample_data/
│   │   └── train/
│   │       └── rgb/
│   │           └── example/
│   │               └── last_aligned_frame_inv.json
│   └── sample_data_processed/
│       └── train/
│           ├── llava_json.json
│           ├── rgb/
│           │   └── example/
│           │       └── last_aligned_frame_inv.json
│           └── text/
│               └── example/
│                   ├── 2_0:40.json
│                   ├── 3_0:53.json
│                   ├── 4_1:12.json
│                   ├── 5_1:28.json
│                   ├── 6_1:58.json
│                   ├── 7_2:19.json
│                   ├── 8_2:36.json
│                   ├── 9_2:43.json
│                   └── white_10_3:21.json
├── data_processing/
│   ├── run_llava/
│   │   ├── main.py
│   │   ├── make_list.py
│   │   └── utils.py
│   └── run_lpips/
│       └── main.py
├── dataset/
│   └── dataset.py
├── demo.py
├── models/
│   ├── ReferenceEncoder.py
│   ├── ReferenceNet.py
│   ├── ReferenceNet_attention.py
│   ├── ReferenceNet_attention_fp16.py
│   ├── attention.py
│   ├── clip_adapter.py
│   ├── hack_cur_image_guider.py
│   ├── hack_unet2d.py
│   ├── image_processor.py
│   ├── orig_attention.py
│   ├── positional_encoder.py
│   ├── resnet.py
│   ├── unet.py
│   └── unet_3d_blocks.py
├── pipelines/
│   ├── context.py
│   └── pipeline_stage_1.py
├── requirements.txt
├── training_scripts/
│   ├── llava/
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── conversation.py
│   │   ├── eval/
│   │   │   ├── eval_gpt_review.py
│   │   │   ├── eval_gpt_review_bench.py
│   │   │   ├── eval_gpt_review_visual.py
│   │   │   ├── eval_pope.py
│   │   │   ├── eval_science_qa.py
│   │   │   ├── eval_science_qa_gpt4.py
│   │   │   ├── eval_science_qa_gpt4_requery.py
│   │   │   ├── eval_textvqa.py
│   │   │   ├── generate_webpage_data_from_table.py
│   │   │   ├── m4c_evaluator.py
│   │   │   ├── model_qa.py
│   │   │   ├── model_vqa.py
│   │   │   ├── model_vqa_loader.py
│   │   │   ├── model_vqa_mmbench.py
│   │   │   ├── model_vqa_science.py
│   │   │   ├── qa_baseline_gpt35.py
│   │   │   ├── run_llava.py
│   │   │   ├── summarize_gpt_review.py
│   │   │   ├── table/
│   │   │   │   ├── answer/
│   │   │   │   │   ├── answer_alpaca-13b.jsonl
│   │   │   │   │   ├── answer_bard.jsonl
│   │   │   │   │   ├── answer_gpt35.jsonl
│   │   │   │   │   ├── answer_llama-13b.jsonl
│   │   │   │   │   └── answer_vicuna-13b.jsonl
│   │   │   │   ├── caps_boxes_coco2014_val_80.jsonl
│   │   │   │   ├── model.jsonl
│   │   │   │   ├── prompt.jsonl
│   │   │   │   ├── question.jsonl
│   │   │   │   ├── results/
│   │   │   │   │   ├── test_sqa_llava_13b_v0.json
│   │   │   │   │   └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│   │   │   │   ├── review/
│   │   │   │   │   ├── review_alpaca-13b_vicuna-13b.jsonl
│   │   │   │   │   ├── review_bard_vicuna-13b.jsonl
│   │   │   │   │   ├── review_gpt35_vicuna-13b.jsonl
│   │   │   │   │   └── review_llama-13b_vicuna-13b.jsonl
│   │   │   │   ├── reviewer.jsonl
│   │   │   │   └── rule.json
│   │   │   └── webpage/
│   │   │       ├── index.html
│   │   │       ├── script.js
│   │   │       └── styles.css
│   │   ├── mm_utils.py
│   │   ├── model/
│   │   │   ├── __init__.py
│   │   │   ├── apply_delta.py
│   │   │   ├── builder.py
│   │   │   ├── consolidate.py
│   │   │   ├── language_model/
│   │   │   │   ├── llava_llama.py
│   │   │   │   ├── llava_mistral.py
│   │   │   │   └── llava_mpt.py
│   │   │   ├── llava_arch.py
│   │   │   ├── make_delta.py
│   │   │   ├── multimodal_encoder/
│   │   │   │   ├── builder.py
│   │   │   │   └── clip_encoder.py
│   │   │   ├── multimodal_projector/
│   │   │   │   └── builder.py
│   │   │   └── utils.py
│   │   ├── serve/
│   │   │   ├── __init__.py
│   │   │   ├── cli.py
│   │   │   ├── controller.py
│   │   │   ├── gradio_web_server.py
│   │   │   ├── model_worker.py
│   │   │   ├── register_worker.py
│   │   │   ├── sglang_worker.py
│   │   │   └── test_message.py
│   │   ├── train/
│   │   │   ├── llama_flash_attn_monkey_patch.py
│   │   │   ├── llama_xformers_attn_monkey_patch.py
│   │   │   ├── llava_trainer.py
│   │   │   ├── train.py
│   │   │   ├── train_mem.py
│   │   │   └── train_xformers.py
│   │   └── utils.py
│   ├── merge_ckpt.sh
│   ├── scripts/
│   │   ├── merge_lora_weights.py
│   │   └── zero2.json
│   ├── train_mask_generator.py
│   ├── train_renderer.py
│   └── train_text_generator.sh
├── unet_2d/
│   ├── attention.py
│   ├── resnet.py
│   ├── unet_2d_blocks.py
│   └── unet_2d_condition.py
└── utils/
    ├── __init__.py
    ├── dist_tools.py
    ├── inference_helpers.py
    ├── llava_utils.py
    ├── text_wrapper.py
    └── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
<h1 align='Center'>Inverse Painting: Reconstructing The Painting Process</h1>

<div align='Center'>
            <a href="https://homes.cs.washington.edu/~boweiche/">Bowei Chen</a>&emsp;
            <a href="https://scholar.google.com/citations?user=R3sUe_EAAAAJ&hl=en">Yifan Wang</a>&emsp;
            <a href="https://homes.cs.washington.edu/~curless/">Brian Curless</a>&emsp;
            <a href="https://www.irakemelmacher.com">Ira Kemelmacher-Shlizerman</a>&emsp;
            <a href="https://www.smseitz.com">Steven M. Seitz</a>&emsp;
</div>
<div align='Center'>
    University of Washington
</div>
<div align='Center'>
<i><strong><a href='https://asia.siggraph.org/2024/' target='_blank'>SIGGRAPH Asia 2024</a></strong></i>
</div>

<div align='Center'>
    <a href='https://inversepainting.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
    <a href='https://arxiv.org/abs/2409.20556'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
    <a href='https://youtu.be/T89auOvTm0o'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
</div>



# Installation

The code has been tested with Python 3.10, PyTorch 2.1.2, and CUDA 11.8. It should run with other versions, but we have not tested them.

We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) to set up an environment:

    conda create --name inverse_painting python=3.10

    conda activate inverse_painting

Install the required packages:

    pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118

    pip install -r requirements.txt

Install LLaVA:

    git clone https://github.com/haotian-liu/LLaVA.git
    cd LLaVA
    pip install -e .
    cd ..

Install xformers:

    pip3 install -U xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118

# Inference
We provide demo code to run our pretrained models on any target landscape painting.


## Download Pretrained Models
Download pretrained models either from [Huggingface](https://huggingface.co/boweiche/inverse_painting) or [Google Drive](https://drive.google.com/drive/folders/1exu6Ws-NIZO-3qNO5s50b71fSQALkdvK?usp=drive_link), and then put them into the root folder.  We recommend using the following commands for downloading from Huggingface:

    git lfs install
    git clone https://huggingface.co/boweiche/inverse_painting


After downloading, the pretrained models should be organized as follows:
```text
./checkpoints/
|-- renderer
|-- RP
|-- TP_llava
|-- TP_llava_annotator    # optional, only required for training. 

./base_ckpt/
|-- clip-vit-base-patch32
|-- realisticVisionV51_v51VAE
```
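
Before running the demo, it can help to confirm that the folders listed above are in place. The snippet below is a minimal sketch, not part of this codebase: the function name `missing_checkpoints` and the constant `REQUIRED_PATHS` are hypothetical, and the optional `TP_llava_annotator` entry is deliberately left out.

```python
from pathlib import Path

# Entries taken from the layout above; TP_llava_annotator is optional and omitted.
REQUIRED_PATHS = [
    "checkpoints/renderer",
    "checkpoints/RP",
    "checkpoints/TP_llava",
    "base_ckpt/clip-vit-base-patch32",
    "base_ckpt/realisticVisionV51_v51VAE",
]


def missing_checkpoints(root: str = ".") -> list[str]:
    """Return the required entries that do not exist under `root` (hypothetical helper)."""
    root_path = Path(root)
    return [rel for rel in REQUIRED_PATHS if not (root_path / rel).exists()]


if __name__ == "__main__":
    missing = missing_checkpoints(".")
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required checkpoint folders found.")
```

Run it from the repository root; an empty result means the layout matches the listing above.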

## Run Demo

For the demo, we provide several target paintings in `./data/demo`. You can run the demo code using the following command:
```shell
python demo.py
```

The generated results will be saved in `results`. 

# Training 

The text generator, mask generator, and renderer are trained separately. You can train these models simultaneously because the mask generator and renderer are trained on GT text and mask instructions rather than predicted ones. 


## Dataset Pre-Processing

We provide sample data in `data/sample_data`. Below is the data structure before running the data pre-processing. 

```text
./data/sample_data/train
|-- rgb/                                   # folders of training samples 
   |-- example/                            # name of the sample
      |-- {ind}_{time}.jpg                 # name of each frame, {ind} is the frame index and {time} indicates the timestamp of the frame within the video.
      |-- last_aligned_frame_inv.json      # a json file that maps the target painting frame to the list of frames preceding it.
```
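
As an illustration of the naming convention, a frame name such as `4_1:12` can be split into its index and timestamp. The helper below is a hypothetical sketch (it does not exist in this codebase) and assumes timestamps are written as `M:SS`, as in the sample data.

```python
def parse_frame_name(name: str) -> tuple[int, int]:
    """Split a frame name like '4_1:12' into (frame index, seconds into the video).

    Hypothetical helper, not part of this repository. Assumes the
    '{ind}_{time}' convention with {time} formatted as 'M:SS'.
    """
    ind, time_str = name.split("_", 1)
    minutes, seconds = time_str.split(":")
    return int(ind), int(minutes) * 60 + int(seconds)


if __name__ == "__main__":
    print(parse_frame_name("4_1:12"))  # (4, 72)
```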


The following steps help you to pre-process this sample data for training, including the preparation of the GT text and mask instructions.  
You can refer to the code in `data_processing`. We also provide our processed data in `data/sample_data_processed` for your reference. 


### Prepare Text Instruction
This step prepares the ground truth (GT) text instructions by feeding both the GT current canvas and the GT next canvas into the LLaVA model.
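
The (current, next) canvas pairs can be sketched as follows. This mirrors the pairing logic in `data_processing/run_llava/main.py`, without the optional subsampling controlled by `--sample_num`; the function name `build_canvas_pairs` is hypothetical and only serves the illustration.

```python
def build_canvas_pairs(frame_map: dict[str, list[str]]) -> list[tuple[str, str]]:
    """Build (current, next) canvas name pairs for each target painting.

    `frame_map` has the shape of last_aligned_frame_inv.json: it maps the
    target frame name to the list of earlier frame names.
    """
    pairs = []
    for ref_name, frames in frame_map.items():
        # The first listed frame is skipped; a blank white canvas takes its
        # place, and the target painting closes the sequence.
        sequence = [f"white_{ref_name}"] + frames[1:] + [ref_name]
        pairs.extend(zip(sequence[:-1], sequence[1:]))
    return pairs


if __name__ == "__main__":
    frame_map = {"10_3:21": ["1_0:15", "2_0:40", "3_0:53"]}
    for cur, nxt in build_canvas_pairs(frame_map):
        print(cur, "->", nxt)
```

Each pair is then concatenated side by side and passed to LLaVA together with the prompt.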

In this codebase, rather than using the pretrained LLaVA model online, we utilize our fine-tuned version of LLaVA for more accurate GT text generation. If you have not downloaded `TP_llava_annotator` in the previous step, you can download it from [Google Drive](https://drive.google.com/drive/folders/1Lj4pSlHJTXvJdyXBWbOT6u-ZhiOyyWQG?usp=drive_link) and put it into the folder `checkpoints`. This model has been fine-tuned using the GT current image, GT next image, and GT text from our dataset, with any inaccurate GT text manually corrected.

You can now run the following commands to prepare the GT text instructions. 
```shell
cd data_processing/run_llava      
python main.py   --save_vis   --model_path  ../../checkpoints/TP_llava_annotator    # you can remove --save_vis if you don't want the visualization
python make_list.py      # prepare the data format for the training of text generator
cd ../../
```

The generated text will be saved in `data/sample_data/train/text`, and the visualizations in `data/sample_data/train/text_vis` (if you use `--save_vis`). 
The training data in LLaVA format is saved in `data/sample_data/train/llava_image` and `data/sample_data/train/llava_json.json`.
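
For reference, the relation between one per-frame text record and one entry of `llava_json.json` can be sketched as below. This is inferred from the sample files in `data/sample_data_processed` and is only an assumption about what `make_list.py` produces; the function name `to_llava_entry` is hypothetical. Note that, in the samples, the training image pairs the current canvas with the target painting (not the next frame), and the training prompt differs from the annotation prompt.

```python
import json

# Prompt as it appears in the sample llava_json.json.
TRAIN_PROMPT = (
    "<image>\nThere are two images side by side. The left image is an "
    "intermediate stage in a painting process of the right image. Please tell "
    "me what content should be painted next? The answer should be less than 2 words."
)


def to_llava_entry(record: dict, video_name: str, entry_id: int) -> dict:
    """Convert one text record into a LLaVA-style conversation entry (sketch)."""
    # Image name pattern observed in the sample data: {cur}_{ref}.png
    image = f"{video_name}/{record['cur_image_name']}_{record['ref_img_name']}.png"
    return {
        "id": entry_id,
        "image": image,
        "conversations": [
            {"from": "human", "value": TRAIN_PROMPT},
            {"from": "gpt", "value": record["next_text"]},
        ],
    }


if __name__ == "__main__":
    record = {
        "ref_img_name": "10_3:21",
        "cur_image_name": "2_0:40",
        "next_image_name": "3_0:53",
        "next_text": "Mountain",
    }
    print(json.dumps(to_llava_entry(record, "example", 1), indent=4))
```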

### Prepare Mask Instruction
This step prepares the GT mask instructions by computing the LPIPS difference between the GT current canvas and the GT next canvas. 

You can now run the following commands to prepare the GT mask instructions. 
```shell
cd data_processing/run_lpips      
python main.py   --save_vis     # you can remove --save_vis if you don't want the visualization
cd ../../
```
The generated mask will be saved in `data/sample_data/train/lpips` and `data/sample_data/train/lpips_vis` (if you use --save_vis). 
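
To make the idea of a mask instruction concrete, the sketch below turns a per-pixel distance map into a binary mask. It assumes the distance map has already been computed (the LPIPS computation itself is not shown here), and it borrows the value 0.2 from `binary_threshold` in `configs/train/train_mask_gen.yaml`. Whether the pre-processing script applies the same threshold, or uses a strict comparison, is an assumption; the function name is hypothetical.

```python
def binarize_distance_map(distance_map, threshold=0.2):
    """Mark pixels whose distance exceeds `threshold` as changed (1), others as 0.

    `distance_map` is a list of rows of floats. Hypothetical helper for
    illustration; it is not part of this repository.
    """
    return [[1 if value > threshold else 0 for value in row] for row in distance_map]


if __name__ == "__main__":
    distance_map = [
        [0.05, 0.10, 0.45],
        [0.02, 0.30, 0.60],
    ]
    for row in binarize_distance_map(distance_map):
        print(row)
```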



## Training Two-Stage Pipeline
The training code is in `training_scripts`. The following three models can be trained in any order because they are not dependent on each other. 


### Train Text Generator 
```shell
cd training_scripts
bash train_text_generator.sh     # This trains a lora of LLaVA, saved in `./checkpoints/llava-v1.5-7b-task-lora`. It will complete very fast because the sample dataset is very small
bash merge_ckpt.sh               # After training, merge the lora with the base model, saved in `./checkpoints/llava-v1.5-7b-task-lora_final`
cd ..
```

### Train Mask Generator 
```shell
cd training_scripts
torchrun --nnodes=1 --nproc_per_node=1  --master_port=25678 train_mask_generator.py    --config  ../configs/train/train_mask_gen.yaml   
cd ..
```
This trains a UNet with cross-attention layers. The output will be saved in `./outputs/mask_gen`.


### Train Next Frame Renderer
```shell
cd training_scripts
torchrun --nnodes=1 --nproc_per_node=1  --master_port=12678  train_renderer.py    --config   ../configs/train/train_renderer.yaml   
cd ..
```
The output will be saved in `./outputs/renderer`




# Acknowledgement

This codebase is adapted from [diffusers](https://github.com/huggingface/diffusers), [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone), and [LLaVA](https://github.com/haotian-liu/LLaVA).



# Disclaimer

We tested this codebase on a single NVIDIA A40 GPU. The result produced by this code might be slightly different when running on a different machine. 



# Citation

If you find our work useful for your research, please consider citing the paper:

```
@inproceedings{chen2024inverse,
  title={Inverse Painting: Reconstructing The Painting Process},
  author={Chen, Bowei and Wang, Yifan and Curless, Brian and Kemelmacher-Shlizerman, Ira and Seitz, Steven M},
  booktitle={SIGGRAPH Asia 2024 Conference Papers},
  year={2024}
}
```

================================================
FILE: configs/inference/inference.yaml
================================================
unet_additional_kwargs:
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
  - 1
  - 2
  - 4
  - 8
  motion_module_mid_block: false
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
    - Temporal_Self
    - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 24
    temporal_attention_dim_div: 1

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"


================================================
FILE: configs/train/train_mask_gen.yaml
================================================
image_finetune: true

output_dir: "outputs/mask_gen"
pretrained_model_path: "../base_ckpt/realisticVisionV51_v51VAE"

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start:          0.00085
  beta_end:            0.012
  # beta_schedule:       "scaled_linear"
  beta_schedule:       "linear"
  steps_offset:        1
  clip_sample:         false

description: "### Train Info: Mask Generator ###"

unet_additional_kwargs:
  use_motion_module              : true
  motion_module_resolutions      : [ 1,2,4,8 ]
  unet_use_cross_frame_attention : false
  unet_use_temporal_attention    : false

  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads                : 8
    num_transformer_block              : 1
    attention_block_types              : [ "Temporal_Self", "Temporal_Self" ]
    temporal_position_encoding         : true
    temporal_position_encoding_max_len : 24
    temporal_attention_dim_div         : 1
    zero_initialize                    : true


train_data:
  data_folder: "../data/sample_data"
  sample_size:  512 # for 40G 256
  sample_stride: 1
  sample_num: 100
  sample_n_frames: 8
  clip_model_path: "../base_ckpt/clip-vit-base-patch32"
  # pad_to_square+resize, resize, pad_to_8, pad_to_16
  pad_mode: "pad_to_16"



trainable_modules:
  - "."

unet_checkpoint_path: ""


learning_rate:    1.e-5
train_batch_size: 1

use_PE: True
PE_type: 'abs'
PE_time_max: 100
PE_time_interval: 5


max_train_epoch:      -1
max_train_steps:      600000
checkpointing_epochs: -1
checkpointing_steps:  20000
gradient_accumulation_steps: 1

feature_type: 'lpips'

additional_input: 'lpips_diff+text'
binary_output: True  
binary_threshold: 0.2

win_size: 5

validation_steps:       1000
validation_steps_tuple: [2, 50]

global_seed: 42
mixed_precision_training: True 
enable_xformers_memory_efficient_attention: True


is_debug: False


================================================
FILE: configs/train/train_renderer.yaml
================================================
image_finetune: true

output_dir: "outputs/renderer"
pretrained_model_path: "../base_ckpt/realisticVisionV51_v51VAE"
clip_model_path: "../base_ckpt/clip-vit-base-patch32"

cur_image_guider_checkpoint_path: ""
referencenet_checkpoint_path: ""

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start:          0.00085
  beta_end:            0.012
  beta_schedule:       "linear"
  steps_offset:        1
  clip_sample:         false

description: "### Train Info: Renderer ###"

unet_additional_kwargs:
  use_motion_module              : true
  motion_module_resolutions      : [ 1,2,4,8 ]
  unet_use_cross_frame_attention : false
  unet_use_temporal_attention    : false

  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads                : 8
    num_transformer_block              : 1
    attention_block_types              : [ "Temporal_Self", "Temporal_Self" ]
    temporal_position_encoding         : true
    temporal_position_encoding_max_len : 24
    temporal_attention_dim_div         : 1
    zero_initialize                    : true


train_data:
  data_folder: "../data/sample_data"
  sample_size:  512 # for 40G 256
  sample_stride: 1
  sample_num: 100
  sample_n_frames: 8
  clip_model_path: "../base_ckpt/clip-vit-base-patch32"

  # pad_to_square+resize, resize, pad_to_16
  pad_mode: "pad_to_16"


trainable_modules:
  # - "motion_modules."
  - "."

unet_checkpoint_path: ""

fusion_blocks: "full"

learning_rate:    1.e-5
train_batch_size: 1
use_PE: True
PE_type: 'abs'
use_RP: True
use_TP: True 
no_refnet: False 
win_size: 10
use_diff_clip: True



feature_type: 'text'
TP_ckpt_path: ''
RP_fusion_type: 'spatial'
use_mask_for_loss: False
use_binary_RP: True  
RP_threshold: 0.2

PE_time_max: 100
PE_time_interval: 5


is_train_TP: False 
is_train_text_encoder: False

max_train_epoch:      -1
max_train_steps:      600000
checkpointing_epochs: -1
checkpointing_steps:  20000
gradient_accumulation_steps: 1


global_seed: 42
mixed_precision_training: True 
enable_xformers_memory_efficient_attention: True


is_debug: False


================================================
FILE: data/sample_data/train/rgb/example/last_aligned_frame_inv.json
================================================
{
    "10_3:21": [
        "1_0:15",
        "2_0:40",
        "3_0:53",
        "4_1:12",
        "5_1:28",
        "6_1:58",
        "7_2:19",
        "8_2:36",
        "9_2:43"
    ]
}

================================================
FILE: data/sample_data_processed/train/llava_json.json
================================================
[
    {
        "id": 0,
        "image": "example/white_10_3:21_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "Mountain, sky"
            }
        ]
    },
    {
        "id": 1,
        "image": "example/2_0:40_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "Mountain"
            }
        ]
    },
    {
        "id": 2,
        "image": "example/3_0:53_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "mountain, ground"
            }
        ]
    },
    {
        "id": 3,
        "image": "example/4_1:12_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "Grass"
            }
        ]
    },
    {
        "id": 4,
        "image": "example/5_1:28_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "Grass"
            }
        ]
    },
    {
        "id": 5,
        "image": "example/6_1:58_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "Grass"
            }
        ]
    },
    {
        "id": 6,
        "image": "example/7_2:19_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "Water"
            }
        ]
    },
    {
        "id": 7,
        "image": "example/8_2:36_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "water"
            }
        ]
    },
    {
        "id": 8,
        "image": "example/9_2:43_10_3:21.png",
        "conversations": [
            {
                "from": "human",
                "value": "<image>\nThere are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words."
            },
            {
                "from": "gpt",
                "value": "reflection"
            }
        ]
    }
]

================================================
FILE: data/sample_data_processed/train/rgb/example/last_aligned_frame_inv.json
================================================
{
    "10_3:21": [
        "1_0:15",
        "2_0:40",
        "3_0:53",
        "4_1:12",
        "5_1:28",
        "6_1:58",
        "7_2:19",
        "8_2:36",
        "9_2:43"
    ]
}

================================================
FILE: data/sample_data_processed/train/text/example/2_0:40.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "2_0:40", "next_image_name": "3_0:53", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "Mountain"}

================================================
FILE: data/sample_data_processed/train/text/example/3_0:53.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "3_0:53", "next_image_name": "4_1:12", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "mountain, ground"}

================================================
FILE: data/sample_data_processed/train/text/example/4_1:12.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "4_1:12", "next_image_name": "5_1:28", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "Grass"}

================================================
FILE: data/sample_data_processed/train/text/example/5_1:28.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "5_1:28", "next_image_name": "6_1:58", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "Grass"}

================================================
FILE: data/sample_data_processed/train/text/example/6_1:58.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "6_1:58", "next_image_name": "7_2:19", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "Grass"}

================================================
FILE: data/sample_data_processed/train/text/example/7_2:19.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "7_2:19", "next_image_name": "8_2:36", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "Water"}

================================================
FILE: data/sample_data_processed/train/text/example/8_2:36.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "8_2:36", "next_image_name": "9_2:43", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "water"}

================================================
FILE: data/sample_data_processed/train/text/example/9_2:43.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "9_2:43", "next_image_name": "10_3:21", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "reflection"}

================================================
FILE: data/sample_data_processed/train/text/example/white_10_3:21.json
================================================
{"ref_img_name": "10_3:21", "cur_image_name": "white_10_3:21", "next_image_name": "2_0:40", "prompt": "There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", "next_text": "Mountain, sky"}

================================================
FILE: data_processing/run_llava/main.py
================================================
import argparse
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from utils import Predictor

import cv2
import glob
import os
import json
import numpy as np
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
import tqdm
import matplotlib.pyplot as plt


def main(args):
    data_folder = args.data_folder
    split = args.split
    model_path = args.model_path
    sample_num = args.sample_num
    save_vis = args.save_vis
    prompt = args.prompt

    # Scratch folder for the side-by-side image passed to the model.
    cache_dir = 'cache'
    os.makedirs(cache_dir, exist_ok=True)

    # One sub-folder per painting video.
    video_dirs = glob.glob(f'{data_folder}/{split}/rgb/*')
    video_dirs = sorted(video_dirs)

    dst_dir = f'{data_folder}/{split}/text'
    dst_vis_dir = f'{data_folder}/{split}/text_vis'

    args_obj = type('Args', (), {
        "model_path": model_path,
        "model_base": None,
        "model_name": get_model_name_from_path(model_path),
        "query": prompt,
        "conv_mode": None,
        "image_file": None,
        "sep": ",",
        "temperature": 0,
        "top_p": None,
        "num_beams": 1,
        "max_new_tokens": 512
    })()

    predictor = Predictor(args_obj)

    cache_path = f'{cache_dir}/cache.png'

    for video_dir in tqdm.tqdm(video_dirs[:]):
        last_aligned_frame_inv_path = f'{video_dir}/last_aligned_frame_inv.json'
        video_name = os.path.basename(video_dir)

        # Load JSON
        with open(last_aligned_frame_inv_path) as f:
            last_aligned_frame_inv_path_dict = json.load(f)

        for ref_img_name in list(last_aligned_frame_inv_path_dict.keys()):
            canvas_candidate_list_full = last_aligned_frame_inv_path_dict[ref_img_name]

            ref_image = Image.open(f'{video_dir}/{ref_img_name}.jpg')

            # Skip the first listed frame; the sequence starts from a blank white canvas instead.
            canvas_candidate_list_full = canvas_candidate_list_full[1:]
            canvas_candidate_list = ['white']
            if len(canvas_candidate_list_full) >= (sample_num - 1):
                # Long sequence: pick sample_num - 1 evenly spaced intermediate frames.
                sample_inds = np.round(np.linspace(0, len(canvas_candidate_list_full) - 1, sample_num - 1)).astype(int)
                canvas_candidate_list += [canvas_candidate_list_full[i] for i in sample_inds]
                canvas_candidate_list.append(ref_img_name)
                assert len(canvas_candidate_list) == (sample_num + 1)
            else:
                # Short sequence: keep every intermediate frame.
                canvas_candidate_list += canvas_candidate_list_full
                canvas_candidate_list.append(ref_img_name)

            for i in range(len(canvas_candidate_list[:-1])):
                cur_image_name = canvas_candidate_list[i]
                next_image_name = canvas_candidate_list[i + 1]

                if cur_image_name == 'white':
                    cur_image = Image.new('RGB', (ref_image.width, ref_image.height), (255, 255, 255))
                    cur_image_name = f'white_{ref_img_name}'
                else:
                    cur_image_path = f'{video_dir}/{cur_image_name}.jpg'
                    cur_image = Image.open(cur_image_path)

                next_image_path = f'{video_dir}/{next_image_name}.jpg'
                next_image = Image.open(next_image_path)

                # Horizontal concat
                canvas = Image.new('RGB', (ref_image.width * 2, ref_image.height))
                canvas.paste(cur_image, (0, 0))
                canvas.paste(next_image, (ref_image.width, 0))

                canvas.save(cache_path)

                args_obj.image_file = cache_path
                cur_prompt = prompt
                args_obj.query = cur_prompt

                predictor.set_args(args_obj)
                out_text = predictor.eval_model()

                os.makedirs(f"{dst_dir}/{video_name}", exist_ok=True)

                # Save JSON
                save_dict = {
                    'ref_img_name': ref_img_name,
                    'cur_image_name': cur_image_name,
                    'next_image_name': next_image_name,
                    'prompt': cur_prompt,
                    'next_text': out_text,
                }

                with open(f"{dst_dir}/{video_name}/{cur_image_name}.json", 'w') as f:
                    json.dump(save_dict, f)

                if save_vis:
                    canvas = np.array(canvas)
                    plt.imshow(canvas)
                    plt.axis('off')
                    plt.text(0, 0, out_text, fontsize=12, color='red', fontweight='bold')
                    os.makedirs(f"{dst_vis_dir}/{video_name}", exist_ok=True)
                    plt.savefig(f"{dst_vis_dir}/{video_name}/{cur_image_name}.jpg", bbox_inches='tight', pad_inches=0)
                    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process painting steps using a pretrained model.")
    parser.add_argument("--data_folder", type=str, default='../../data/sample_data', help="Path to the data folder.")
    parser.add_argument("--split", type=str, default='train', choices=["train", "val", "test"], help="Data split to process.")
    parser.add_argument("--model_path", type=str, default='../../checkpoints/TP_llava_annotator', help="Path to the pretrained model.")
    parser.add_argument("--sample_num", type=int, default=100, help="Number of samples to process.")
    parser.add_argument("--save_vis", action="store_true", help="Whether to save visualization outputs.")
    parser.add_argument("--prompt", type=str, default="There are two images side by side. The right image is the next step of the left image in the painting process of a painting. Please tell me what is added to right image? The answer should be less than 2 words.", help="Prompt for the model.")

    args = parser.parse_args()
    main(args)


================================================
FILE: data_processing/run_llava/make_list.py
================================================
import json
import os
import cv2
import numpy as np
import tqdm
from pathlib import Path
import argparse

# Function to retrieve all .json files recursively
def get_all_files(src_dir, extension="*.json"):
    src_dir = Path(src_dir)
    all_files = list(src_dir.rglob(extension))
    return [str(file) for file in all_files]  # Convert to string

# Argument parser
parser = argparse.ArgumentParser(description="Process JSON files to create concatenated images and generate a conversation dataset.")
parser.add_argument("--split", type=str, default="train", help="Dataset split to process (e.g., train, val, test).")
parser.add_argument("--root_dir", type=str, default="../../data/sample_data", help="Root directory containing the dataset.")
parser.add_argument("--prompt", type=str, default="There are two images side by side. The left image is an intermediate stage in a painting process of the right image. Please tell me what content should be painted next? The answer should be less than 2 words.", 
                    help="Prompt to include in the conversations.")
args = parser.parse_args()

# Use parsed arguments
split = args.split
root_dir = args.root_dir
prompt = args.prompt

# Get all JSON files, ordered by video name and then by step index.
# Files for the initial blank canvas are prefixed with 'white' and sort first.
def _sort_key(path):
    video_name, file_name = path.split('/')[-2:]
    step = file_name.split('_')[0]
    return (video_name, 0 if step == 'white' else int(step))

json_files = sorted(get_all_files(f'{root_dir}/{split}/text'), key=_sort_key)

# Initialize storage for the new JSON data
json_info = []
cnt = 0

# Process each JSON file
for json_file in tqdm.tqdm(json_files):
    with open(json_file, 'r') as f:
        data = json.load(f)

    json_name = json_file.split('/')[-1]
    cur_img_name = data['cur_image_name']
    ref_img_name = data['ref_img_name']
    next_text_corrected = data['next_text']

    # Image paths
    cur_img_path = json_file.replace('/text/', '/rgb/').replace(json_name, cur_img_name + '.jpg')
    ref_img_path = json_file.replace('/text/', '/rgb/').replace(json_name, ref_img_name + '.jpg')

    # Load images
    ref_img = cv2.imread(ref_img_path)
    if 'white' in cur_img_name:
        cur_img = 255 * np.ones_like(ref_img, dtype=np.uint8)
    else:
        cur_img = cv2.imread(cur_img_path)

    # Concatenate images horizontally
    query_image = cv2.hconcat([cur_img, ref_img])

    # Save concatenated image
    
    dst_dir = os.path.dirname(json_file).replace('/text/', '/llava_image/')
    os.makedirs(dst_dir, exist_ok=True)
    dst_path = os.path.join(dst_dir, f'{cur_img_name}_{ref_img_name}.png')
    cv2.imwrite(dst_path, query_image)

    json_image_path = dst_path.split('llava_image/')[-1]

    # Create conversation entry
    json_cur_dict = {
        'id': cnt,
        "image": json_image_path,
        "conversations": [
            {
                "from": "human",
                "value": "<image>\n" + prompt
            },
            {
                "from": "gpt",
                "value": next_text_corrected
            },
        ]
    }

    cnt += 1
    json_info.append(json_cur_dict)

# Write all conversation data to a single JSON file
dst_json_path = f'{root_dir}/{split}/llava_json.json'
os.makedirs(os.path.dirname(dst_json_path), exist_ok=True)
with open(dst_json_path, 'w') as f:
    json.dump(json_info, f, indent=4)

print(f"Generated JSON saved to {dst_json_path}")


================================================
FILE: data_processing/run_llava/utils.py
================================================
import argparse
import torch
from PIL import Image, ImageDraw, ImageFont
import os
from tqdm import tqdm
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

import requests
from io import BytesIO
import re


def image_parser(args):
    out = args.image_file.split(args.sep)
    return out


def load_image(image_file):
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image)
    return out


class Predictor:
    def __init__(self, args) -> None:
        disable_torch_init()
        model_name = get_model_name_from_path(args.model_path)
        # print(f"Loading model {model_name} from {args.model_path}")
        # print(f"Using model base {args.model_base}")
        # exit()
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name
        )
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.context_len = context_len
        self.model_name = model_name

    
    def set_args(self, args):
        self.args = args

    def eval_model(self):

        args = self.args
        model = self.model
        tokenizer = self.tokenizer
        image_processor = self.image_processor
        model_name = self.model_name
        
        # Model


        qs = args.query
        image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        if IMAGE_PLACEHOLDER in qs:
            if model.config.mm_use_im_start_end:
                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
            else:
                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
        else:
            if model.config.mm_use_im_start_end:
                qs = image_token_se + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "mistral" in model_name.lower():
            conv_mode = "mistral_instruct"
        elif "v1.6-34b" in model_name.lower():
            conv_mode = "chatml_direct"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        else:
            conv_mode = "llava_v0"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        image_files = image_parser(args)
        images = load_images(image_files)
        image_sizes = [x.size for x in images]
        images_tensor = process_images(
            images,
            image_processor,
            model.config
        ).to(model.device, dtype=torch.float16)

        # print(images_tensor.shape)

        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=image_sizes,
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=args.max_new_tokens,
                use_cache=True,
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return outputs






if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--sep", type=str, default=",")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    args = parser.parse_args()

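    # eval_model is a method of Predictor, so build a predictor first
    predictor = Predictor(args)
    print(predictor.eval_model())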


================================================
FILE: data_processing/run_lpips/main.py
================================================
import glob
import os
import json
import numpy as np
from PIL import Image
import tqdm
import matplotlib.pyplot as plt
import lpips
import torch
import cv2
import argparse

# Argument parser
parser = argparse.ArgumentParser(description="Compute LPIPS and visualize.")
parser.add_argument("--data_folder", "--datadata_folder", dest="datadata_folder", type=str, default="../../data/sample_data",
                    help="Path to the data folder containing videos.")
parser.add_argument("--split", type=str, default="train",
                    help="Dataset split to process (e.g., train, val, test).")
parser.add_argument("--sample_num", type=int, default=1000,
                    help="Number of frames to sample.")
parser.add_argument("--save_vis", action="store_true", help="Whether to save visualization outputs.")


args = parser.parse_args()

# Initialize LPIPS
lpips_fn_alex = lpips.LPIPS(net='alex', spatial=True).cuda()

datadata_folder = args.datadata_folder
split = args.split
sample_num = args.sample_num
save_vis = args.save_vis

cache_dir = 'cache'
os.makedirs(cache_dir, exist_ok=True)

video_dirs = glob.glob(f'{datadata_folder}/{split}/rgb/*')
video_dirs = sorted(video_dirs)
print(video_dirs)

dst_dir = f'{datadata_folder}/{split}/lpips'
dst_vis_dir = f'{datadata_folder}/{split}/lpips_vis'

for video_dir in tqdm.tqdm(video_dirs):
    last_aligned_frame_inv_path = f'{video_dir}/last_aligned_frame_inv.json'
    video_name = video_dir.split('/')[-1]

    # Load JSON
    with open(last_aligned_frame_inv_path) as f:
        last_aligned_frame_inv_path_dict = json.load(f)

    for ref_img_name in list(last_aligned_frame_inv_path_dict.keys()):
        canvas_candidate_list_full = last_aligned_frame_inv_path_dict[ref_img_name]
        ref_image = Image.open(f'{video_dir}/{ref_img_name}.jpg')
        canvas_candidate_list_full = canvas_candidate_list_full[1:]

        canvas_candidate_list = ['white']
        if len(canvas_candidate_list_full) >= (sample_num - 1):
            sample_inds = np.round(np.linspace(0, len(canvas_candidate_list_full) - 1, sample_num - 1)).astype(int)
            canvas_candidate_list = canvas_candidate_list + [canvas_candidate_list_full[i] for i in sample_inds]
            canvas_candidate_list.append(ref_img_name)
            assert len(canvas_candidate_list) == (sample_num + 1)
        else:
            canvas_candidate_list = canvas_candidate_list + canvas_candidate_list_full
            canvas_candidate_list.append(ref_img_name)

        for i in range(len(canvas_candidate_list[:-1])):
            cur_image_name = canvas_candidate_list[i]
            next_image_name = canvas_candidate_list[i + 1]

            if cur_image_name == 'white':
                cur_image = Image.new('RGB', (ref_image.width, ref_image.height), (255, 255, 255))
                cur_image_name = f'white_{ref_img_name}'
            else:
                cur_image_path = f'{video_dir}/{cur_image_name}.jpg'
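                cur_image = Image.open(cur_image_path).convert('RGB')

            next_image_path = f'{video_dir}/{next_image_name}.jpg'
            # Force three channels so permute(2, 0, 1) below also works for grayscale JPEGs
            next_image = Image.open(next_image_path).convert('RGB')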

            # Convert to torch tensors
            cur_image = torch.tensor(np.array(cur_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            next_image = torch.tensor(np.array(next_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0

            # Normalize to [-1, 1]
            cur_image = cur_image * 2 - 1
            next_image = next_image * 2 - 1

            # Move to GPU
            cur_image = cur_image.cuda()
            next_image = next_image.cuda()

            # Compute LPIPS
            lpips_out = lpips_fn_alex(cur_image, next_image)
            lpips_out = torch.clamp(lpips_out, 0, 1)
            lpips_mask = lpips_out.cpu().detach().numpy()

            # Save LPIPS mask
            dst_path = f'{dst_dir}/{video_name}'
            os.makedirs(dst_path, exist_ok=True)
            cv2.imwrite(f'{dst_path}/{cur_image_name}.jpg', (lpips_mask[0] * 255).astype(np.uint8).transpose(1, 2, 0))

            if save_vis:
                # Visualize LPIPS mask
                cur_image = (cur_image + 1) / 2
                next_image = (next_image + 1) / 2

                lpips_mask = lpips_mask[0]
                lpips_mask = (lpips_mask * 255).astype(np.uint8).transpose(1, 2, 0)
                next_image = next_image[0].cpu().detach().numpy()
                next_image = (next_image * 255).astype(np.uint8).transpose(1, 2, 0)
                cur_image = cur_image[0].cpu().detach().numpy()
                cur_image = (cur_image * 255).astype(np.uint8).transpose(1, 2, 0)

                fig, (ax1, ax2) = plt.subplots(1, 2)
                ax1.imshow(cur_image)
                ax1.axis('off')
                ax2.imshow(next_image)
                ax2.imshow(lpips_mask[:, :, 0], cmap='hot', alpha=0.5, interpolation='bilinear')
                ax2.axis('off')

                os.makedirs(f'{dst_vis_dir}/{video_name}', exist_ok=True)
                plt.savefig(f'{dst_vis_dir}/{video_name}/{cur_image_name}.jpg', bbox_inches='tight')
                plt.close()


================================================
FILE: dataset/dataset.py
================================================
import os, io, csv, math, random
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from transformers import CLIPProcessor
import glob
import json


# adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/data/dataset.py

import torch.distributed as dist
def zero_rank_print(s):
    # Print from rank 0 only, or unconditionally when torch.distributed is not initialized
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)

def load_im_as_tensor(im_paths):

    if isinstance(im_paths, list):
        im = [Image.open(im_path) for im_path in im_paths]
        im = [im.convert('RGB') for im in im]
        im = [np.array(im) for im in im]
        im = [torch.from_numpy(im).permute(2, 0, 1).contiguous().float()[None] for im in im]
        im = [im / 255.0 for im in im]
        im = torch.cat(im, dim=0)
    else:
        im = Image.open(im_paths)
        im = im.convert('RGB')
        im = np.array(im)
        im = torch.from_numpy(im).permute(2, 0, 1).contiguous().float()[None]
        im = im / 255.0
    return im
        
    

class InvPaintingDataset(Dataset):
    def __init__(
            self,
            data_folder,
            sample_size=768, sample_stride=4, sample_n_frames=24, sample_num=20, 
            is_image=False, clip_model_path="openai/clip-vit-base-patch32",
            is_train=True,
            pad_mode=False, 
            PE_type=None, 
            PE_time_interval=None, 
            PE_time_max=None,
            win_size=None,
        ):
        zero_rank_print(f"loading annotations from {data_folder} ...")


        self.is_train = is_train
        self.spilt = 'train' if self.is_train else 'val'

        self.sample_size = sample_size
        self.pad_mode= pad_mode
        self.PE_type = PE_type
        self.PE_time_max= PE_time_max
        self.PE_time_interval = PE_time_interval
        self.win_size = win_size
    

        self.video_dirs = glob.glob(f'{data_folder}/{self.spilt}/rgb/*')
        
        self.length = len(self.video_dirs)
        zero_rank_print(f"video nums: {self.length}")
        print(f"video nums: {self.length}")

        self.sample_stride   = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image        = is_image
        self.sample_num = sample_num
        
        self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path,local_files_only=True)
 
        self.pixel_transforms = transforms.Compose([
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def __len__(self):
        return self.length
    
    def get_batch(self,idx):
        video_dir = self.video_dirs[idx]
        video_name = video_dir.split('/')[-1]

        last_aligned_frame_inv_path = f'{video_dir}/last_aligned_frame_inv.json'

        # load json 
        with open(last_aligned_frame_inv_path) as f:
            last_aligned_frame_inv_path_dict = json.load(f)
            
        # first sample the ref image  (final canvas) randomly from  key of the last_aligned_frame_inv_path_dict
        ref_img_name = random.choice(list(last_aligned_frame_inv_path_dict.keys()))


        canvas_candidate_list_full = last_aligned_frame_inv_path_dict[ref_img_name]
        

        if len(canvas_candidate_list_full) == 1:
            # Not enough frames; return one None per output field so __getitem__ can unpack and resample
            return (None,) * 10


        # compute time difference 
        canvas_candidate_list_full_time = canvas_candidate_list_full + [ref_img_name]
        canvas_candidate_list_full_time = [int(name.split('_')[-1].split(':')[0]) * 60 + int(name.split('_')[-1].split(':')[1]) for name in canvas_candidate_list_full_time]
        

        # cap the first timestamp at 30 seconds
        if canvas_candidate_list_full_time[0] > 30:
            canvas_candidate_list_full_time[0] = 30



        canvas_candidate_list_full_time = [canvas_candidate_list_full_time[i+1] - canvas_candidate_list_full_time[i] for i in range(len(canvas_candidate_list_full_time) - 1)]

        
        # first frame as white image, last frame as final canvas
        canvas_candidate_list = ['white']
        # remove first frame (likely white image)
        canvas_candidate_list_full = canvas_candidate_list_full[1:]

        

        
        if len(canvas_candidate_list_full) >= (self.sample_num - 1):
            # uniformly sample (sample_num - 1) frames from canvas_candidate_list_full
            sample_inds = np.round(np.linspace(0, len(canvas_candidate_list_full) - 1, self.sample_num - 1)).astype(int)
            canvas_candidate_list = canvas_candidate_list + [canvas_candidate_list_full[i] for i in sample_inds]

            # add last frame
            canvas_candidate_list.append(ref_img_name)
        

            assert len(canvas_candidate_list) == (self.sample_num + 1)
        else:
            # add last frame
            canvas_candidate_list = canvas_candidate_list + canvas_candidate_list_full
            canvas_candidate_list.append(ref_img_name)




        ref_img_path = f'{video_dir}/{ref_img_name}.jpg'
        
        pixel_values_ref_img = load_im_as_tensor(ref_img_path)

        ref_img_pil = Image.open(ref_img_path)
        clip_ref_image = self.clip_image_processor(images=ref_img_pil, return_tensors="pt").pixel_values



        # then sample current image 
        cur_img_ind = np.random.choice(len(canvas_candidate_list[:-1]))
        cur_img_name = canvas_candidate_list[cur_img_ind]

        if self.PE_type == 'rel':
            assert False
            cur_img_pos = cur_img_ind / (len(canvas_candidate_list) - 1)
        elif self.PE_type == 'abs':
            cur_img_pos = canvas_candidate_list_full_time[cur_img_ind]

            if  self.PE_time_interval > 1:
                # round cur_img_pos to the closest multiple of PE_time_interval
                cur_img_pos = round(cur_img_pos / self.PE_time_interval) * self.PE_time_interval 


            # print(f'cur_img_pos: {cur_img_pos}')
            cur_img_pos = cur_img_pos / self.PE_time_max

            # clamp to 0 to 1 
            cur_img_pos = max(0, min(1, cur_img_pos))

            
        #to [-1, 1]
        cur_img_pos = cur_img_pos * 2 - 1
        

        if cur_img_name == 'white':
            pixel_values_cur = torch.ones_like(pixel_values_ref_img)
            cur_img_path = 'white'
            cur_img_pil = Image.new('RGB', (224, 224), (255, 255, 255))

        else:
            cur_img_path = f'{video_dir}/{cur_img_name}.jpg'
            pixel_values_cur = load_im_as_tensor(cur_img_path)
            cur_img_pil = Image.open(cur_img_path)

   




        clip_cur_image = self.clip_image_processor(images=cur_img_pil, return_tensors="pt").pixel_values

        # get next canvas
        gt_img_name = canvas_candidate_list[cur_img_ind + 1]
        gt_img_path = f'{video_dir}/{gt_img_name}.jpg'
        pixel_values = load_im_as_tensor(gt_img_path)

        # the next-frame position reuses the current position (already mapped to [-1, 1])
        next_img_pos = cur_img_pos


        # pixel_values, pixel_values_cur and pixel_values_ref_img must share one shape; otherwise skip this sample
        if not (pixel_values.shape == pixel_values_cur.shape == pixel_values_ref_img.shape):
            return (None,) * 10



        if self.pad_mode == 'pad_to_square+resize':
            # Pad to square image
            max_size = max(pixel_values.shape[2], pixel_values.shape[3])
            pixel_values = torch.nn.functional.pad(pixel_values, (0, max_size - pixel_values.shape[3], 0, max_size - pixel_values.shape[2]))
            pixel_values_cur = torch.nn.functional.pad(pixel_values_cur, (0, max_size - pixel_values_cur.shape[3], 0, max_size - pixel_values_cur.shape[2]))
            pixel_values_ref_img = torch.nn.functional.pad(pixel_values_ref_img, (0, max_size - pixel_values_ref_img.shape[3], 0, max_size - pixel_values_ref_img.shape[2]))

            # resize to sample_size
            pixel_values = torch.nn.functional.interpolate(pixel_values, size=(self.sample_size, self.sample_size), mode='bilinear')
            pixel_values_cur = torch.nn.functional.interpolate(pixel_values_cur, size=(self.sample_size, self.sample_size), mode='bilinear')
            pixel_values_ref_img = torch.nn.functional.interpolate(pixel_values_ref_img, size=(self.sample_size, self.sample_size), mode='bilinear')


        if self.pad_mode == 'pad_to_16':

            # pad the border of them to make it multiple of 16, shape is [1, 3, H, W]
            pad_size = [16 - pixel_values.shape[2] % 16, 16 - pixel_values.shape[3] % 16]

            # if pad_size[0] != 16 or pad_size[1] != 16:
            if pad_size[0] == 16:
                pad_size[0] = 0
            if pad_size[1] == 16:
                pad_size[1] = 0
            # pad the border of them to make it multiple of 16, shape is [1, 3, H, W]
            pixel_values = torch.nn.functional.pad(pixel_values, (0, pad_size[1], 0, pad_size[0]))
            pixel_values_cur = torch.nn.functional.pad(pixel_values_cur, (0, pad_size[1], 0, pad_size[0]))
            pixel_values_ref_img = torch.nn.functional.pad(pixel_values_ref_img, (0, pad_size[1], 0, pad_size[0]))

            
 
        if self.is_image:
            pixel_values = pixel_values[0]
            pixel_values_cur = pixel_values_cur[0]
        
        pixel_values_ref_img = pixel_values_ref_img[0]
        

        return pixel_values, pixel_values_cur, clip_ref_image, clip_cur_image, pixel_values_ref_img, cur_img_pos, next_img_pos, cur_img_path, gt_img_path, ref_img_path
    
    def __getitem__(self, idx):
        while True:
            idx = random.randint(0, self.length-1) 
            try:
                pixel_values, pixel_values_cur, clip_ref_image, clip_cur_image, pixel_values_ref_img, cur_img_pos, next_img_pos, cur_img_path, gt_img_path, ref_img_path = self.get_batch(idx)
                if pixel_values is not None:
                    break
            except Exception as e:
                # Report the failure and retry with another randomly chosen video
                print(f'exception while loading sample {idx}: {e}')
                
      
            

        pixel_values = self.pixel_transforms(pixel_values)
        pixel_values_cur = self.pixel_transforms(pixel_values_cur)
        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
        pixel_values_ref_img = self.pixel_transforms(pixel_values_ref_img)
        pixel_values_ref_img = pixel_values_ref_img.squeeze(0)
        cur_img_pos, next_img_pos = torch.tensor(cur_img_pos).unsqueeze(0), torch.tensor(next_img_pos).unsqueeze(0)

        
        
        # clip_ref_image = clip_ref_image.unsqueeze(1) # [bs,1,768]
        drop_image_embeds = 1 if random.random() < 0.1 else 0
        drop_time_step = 1 if random.random() < 0.1 else 0
        drop_feature = 1 if random.random() < 0.1 else 0
        drop_cur_cond = 1 if random.random() < 0.1 else 0
        drop_RP = 1 if random.random() < 0.1 else 0
        sample = dict(
            pixel_values=pixel_values, 
            pixel_values_cur=pixel_values_cur,
            clip_ref_image=clip_ref_image,
            clip_cur_image=clip_cur_image,
            pixel_values_ref_img=pixel_values_ref_img,
            drop_image_embeds=drop_image_embeds,
            drop_time_step = drop_time_step,
            drop_feature = drop_feature,
            drop_cur_cond=drop_cur_cond,
            drop_RP=drop_RP,
            cur_img_pos=cur_img_pos,
            next_img_pos=next_img_pos,
            cur_img_path=cur_img_path,
            gt_img_path=gt_img_path,
            ref_img_path=ref_img_path

            
            )
        
        return sample




# https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py#L341

def collate_fn(data):

    
    pixel_values = torch.stack([example["pixel_values"] for example in data])
    pixel_values_cur = torch.stack([example["pixel_values_cur"] for example in data])
    clip_ref_image = torch.cat([example["clip_ref_image"] for example in data])
    clip_cur_image = torch.cat([example["clip_cur_image"] for example in data])
    pixel_values_ref_img = torch.stack([example["pixel_values_ref_img"] for example in data])
    cur_img_pos = torch.stack([example["cur_img_pos"] for example in data])
    next_img_pos = torch.stack([example["next_img_pos"] for example in data])

    drop_image_embeds = [example["drop_image_embeds"] for example in data]
    drop_image_embeds = torch.Tensor(drop_image_embeds)

    drop_time_step = [example["drop_time_step"] for example in data]
    drop_time_step = torch.Tensor(drop_time_step)

    drop_feature = [example["drop_feature"] for example in data]
    drop_feature = torch.Tensor(drop_feature)

    drop_cur_cond = [example["drop_cur_cond"] for example in data]
    drop_cur_cond = torch.Tensor(drop_cur_cond)

    drop_RP = [example["drop_RP"] for example in data]
    drop_RP = torch.Tensor(drop_RP)


    
    return {
        "pixel_values": pixel_values,
        "pixel_values_cur": pixel_values_cur,
        "clip_ref_image": clip_ref_image,
        "clip_cur_image": clip_cur_image,
        "pixel_values_ref_img": pixel_values_ref_img,
        "drop_image_embeds": drop_image_embeds,
        "drop_time_step": drop_time_step,
        "drop_feature": drop_feature,
        "drop_cur_cond": drop_cur_cond,
        "drop_RP": drop_RP,
        "cur_img_pos": cur_img_pos,
        "next_img_pos": next_img_pos,
        "cur_img_path": [example["cur_img_path"] for example in data],
        "gt_img_path": [example["gt_img_path"] for example in data],
        "ref_img_path": [example["ref_img_path"] for example in data]
        
    }



================================================
FILE: demo.py
================================================
import argparse
import datetime
import os
import random
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from omegaconf import OmegaConf
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.distributed as dist
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import torch.nn.functional as F
from utils.dist_tools import distributed_init
from utils.inference_helpers import *
import lpips

def parse_args():
    parser = argparse.ArgumentParser(description="Run inference with specified configuration.")
    parser.add_argument("--ckpt_path", type=str, default='checkpoints/renderer/ckpt/checkpoint-global_step-200000.ckpt', help="Path to renderer checkpoint.")
    parser.add_argument("--RP_path", type=str, default='./checkpoints/RP/checkpoint-global_step-80000.ckpt', help="Path to RP checkpoint.")
    parser.add_argument("--output_dir", type=str, default='./results', help="Path to the output directory.")
    parser.add_argument("--llava_path", type=str, default='checkpoints/TP_llava', help="Path to LLaVA model checkpoint.")
    parser.add_argument("--test_dir", type=str, default='./data/demo', help="Path to the directory containing test images.")
    parser.add_argument("--random_seeds", type=int, nargs='+', default=[1], help="List of random seeds for inference.")
    parser.add_argument("--num_actual_inference_steps", type=int, default=50, help="Number of actual inference steps.")
    parser.add_argument("--steps", type=int, default=25, help="Number of steps.")
    parser.add_argument("--guidance_scale", type=float, default=2.0, help="Guidance scale for inference.")
    parser.add_argument("--TP_guidance_scale", type=float, default=5.0, help="TP guidance scale.")
    parser.add_argument("--cur_guidance_scale", type=float, default=1.0, help="Current image guidance scale.")
    parser.add_argument("--RP_guidance_scale", type=float, default=5.0, help="RP guidance scale.")
    parser.add_argument("--PE_guidance_scale", type=float, default=5.0, help="PE guidance scale.")
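    # argparse's type=bool treats any non-empty string (including "False") as True, so parse the text explicitly.
    parser.add_argument("--dilate_RP", type=lambda s: s.lower() in ("true", "1", "yes"), default=True, help="Dilate RP or not (true/false).")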
    parser.add_argument("--PE_sec", type=int, default=20, help="PE section.")
    parser.add_argument("--total_step", type=int, default=50, help="Maximum number of painting steps generated per reference image.")
    parser.add_argument("--binary_threshold", type=float, default=0.2, help="Binary threshold.")
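    # argparse's type=bool treats any non-empty string (including "False") as True, so parse the text explicitly.
    parser.add_argument("--combine_init", type=lambda s: s.lower() in ("true", "1", "yes"), default=True, help="Combine initial image (true/false).")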
    parser.add_argument("--combine_init_ratio", type=float, default=0.2, help="Ratio to combine initial image.")
    parser.add_argument("--split", type=str, default='test', help="Data split.")
    parser.add_argument("--cur_alpha", type=float, default=0.0, help="Current alpha value.")
    parser.add_argument("--pretrained_model_path", type=str, default="base_ckpt/realisticVisionV51_v51VAE", help="Path to pretrained model.")
    parser.add_argument("--pretrained_clip_path", type=str, default="./base_ckpt/clip-vit-base-patch32", help="Path to pretrained CLIP model.")
    parser.add_argument("--tmp_cur_img_folder", type=str, default='cache_cur_img', help="Temporary image folder.")
    parser.add_argument("--dist", action="store_true", required=False, help="Enable distributed mode.")
    parser.add_argument("--rank", type=int, default=0, required=False, help="Rank for distributed mode.")
    parser.add_argument("--world_size", type=int, default=1, required=False, help="World size for distributed mode.")
    return parser.parse_args()

def main(args):



    # Load configurations and initialize device
    device = torch.device(f"cuda:{args.rank}")
    dist_kwargs = {"rank": args.rank, "world_size": args.world_size, "dist": args.dist}
    dtype = torch.float16
    config_path = os.path.join(os.path.dirname(args.ckpt_path), '..', 'config.yaml')
    config = OmegaConf.load(config_path)

    # Update config with arguments
    config.update({
        'pretrained_model_path': args.pretrained_model_path,
        'split': args.split,
        'llava_path': args.llava_path,
        'binary': args.binary_threshold > 0,
        'binary_threshold': args.binary_threshold,
        'PE_sec': args.PE_sec,
        'RP_path': args.RP_path,
    })


    # Set up output directory and data paths
    root_dst_dir = prepare_results_dir(config, args.ckpt_path, args.output_dir)
    images_info = get_dataset_info(args.test_dir)
    full_state_dict = torch.load(args.ckpt_path, map_location='cpu')
    
    # get time 
    now = datetime.datetime.now()
    time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
    rand_num = random.randint(0, 100000)


    total_step, guidance_scale, cur_guidance_scale, cur_alpha, PE_guidance_scale = args.total_step, args.guidance_scale, args.cur_guidance_scale, args.cur_alpha, args.PE_guidance_scale
    TP_guidance_scale, RP_guidance_scale, dilate_RP, combine_init, combine_init_ratio = args.TP_guidance_scale, args.RP_guidance_scale, args.dilate_RP, args.combine_init, args.combine_init_ratio
    steps, num_actual_inference_steps = args.steps, args.num_actual_inference_steps
    
    # this is used for saving images for text generator
    tmp_cur_img_path = f'{args.tmp_cur_img_folder}/{time_str}_{rand_num}.png'
    os.makedirs(args.tmp_cur_img_folder, exist_ok=True)
    next_RP_embeddings = None
    next_prompt  = None

    # prepare text generator and mask generator
    TP = TP_wrapper(config, full_state_dict, device, dtype)
    RP = RP_wrapper(config, full_state_dict, device, dtype)
        
    # prepare time embeddings
    PE_sec = config['PE_sec']
    PE = PE_wrapper(config, full_state_dict, device, dtype)
    with torch.no_grad():
        PE_embeddings = PE.embed(PE_sec)
            
    # prepare negative text embeddings
    negative_next_TP_embeddings = TP.get_negative_embeddings()

    # Load inference pipeline and LPIPS for similarity calculations
    pipeline, pipeline_kwargs = load_pipeline(config, args.pretrained_model_path, args.pretrained_clip_path, full_state_dict, dtype, device)
    lpips_fn_alex = lpips.LPIPS(net='alex', spatial=False).to(device)

    print('Start inference')
    for random_seed in args.random_seeds:

        for ref_img_path in images_info:
            seed = 1
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            
            
            gt_next_img_paths = images_info[ref_img_path]
            

            dst_dir = f'{root_dst_dir}/' + ref_img_path.split('/')[-2] + '/' + ref_img_path.split('/')[-1].split('.')[0] + f'/seed_{random_seed}_total_step_{total_step}'
            guidance_name = f'gs+clip{guidance_scale}+cur_alpha{cur_alpha}'

            guidance_name += f'+PE{PE_guidance_scale}'

            guidance_name += f'+TP{TP_guidance_scale}'

            guidance_name += f'+RP{RP_guidance_scale}'
            if combine_init:
                guidance_name += f'_combine{combine_init_ratio}'
                if config['binary']:
                    guidance_name += f'_binary'
                if dilate_RP:
                    guidance_name += f'_dilate'


            dst_dir += f'_{guidance_name}'
            
            if not os.path.exists(dst_dir):
                os.makedirs(dst_dir, exist_ok=True)
            else:
                print(f'{dst_dir} exists, skip')
                continue

            # target image
            ref_img = np.array(Image.open(ref_img_path).convert('RGB'))
            ori_h, ori_w, c = ref_img.shape
            
            plt.imsave(f"{dst_dir}/ori_img.jpg", ref_img)
            ref_img = pad_to_16(ref_img)

            # current image: starting from white canvas
            cur_img = np.ones((ori_h, ori_w, 3)) * 255
            cur_img = pad_to_16(cur_img)
            cur_img = cur_img.astype(np.uint8)
            plt.imsave(f"{dst_dir}/sample_0.jpg", cur_img[:ori_h, :ori_w])

            # Create the generator on the selected device (cuda:{rank}) rather than a hard-coded cuda:0.
            generator = torch.Generator(device=device)
            generator.manual_seed(random_seed)
            state = generator.get_state()

            next_RP_embeddings_prev = None

            cur_next_diffs = []
            next_ref_diffs = []        
            for idx in tqdm(range( total_step)):
                generator.set_state(state)
        
                generator_RP = torch.Generator(device=device)
                generator_RP.manual_seed(random_seed + idx)


                H, W, C = ref_img.shape

                # determine whether to stop, based on the last two differences
                if len(cur_next_diffs) > 3 and cur_next_diffs[-2] < 1e-3 and cur_next_diffs[-1] < 1e-3:
                    break

                # determine whether to stop, based on difference between next and reference
                if len(next_ref_diffs) > 0 and next_ref_diffs[-1] < 1e-1:
                    break

                kwargs = {}
                # copy dist_kwargs
                kwargs.update(dist_kwargs)  
                kwargs.update(pipeline_kwargs)              
                
                kwargs['use_PE'] = config['use_PE']
                kwargs['PE_guidance_scale'] = PE_guidance_scale
                kwargs['PE_embeddings'] = PE_embeddings
                kwargs['negative_PE_embeddings'] = torch.zeros_like(PE_embeddings)
            
                        
                cur_img_path = tmp_cur_img_path 
                plt.imsave(cur_img_path, cur_img)
                
                cache_path = cur_img_path.replace('.png', '_.png')
                next_text_embeddings, next_prompt = TP(cur_img_path, ref_img_path, cache_path=cache_path)
            
                
                kwargs['TP_feature'] = next_text_embeddings
                kwargs['use_TP'] = config['use_TP']
                kwargs['TP_guidance_scale'] = TP_guidance_scale
                kwargs['negative_TP_feature'] = negative_next_TP_embeddings


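                # NOTE: the 'white' sentinel set here is overwritten two lines below, so RP always
                # reads the canvas saved at tmp_cur_img_path, including at idx == 0.
                if idx == 0:
                    cur_img_path = 'white'

                cur_img_path = tmp_cur_img_path
                plt.imsave(cur_img_path, cur_img)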


                ##### for mask generation #####

                # the predicted mask in the previous step 
                if next_RP_embeddings is not None:
                    next_RP_embeddings_prev = next_RP_embeddings.clone().to(torch.float32)

                
                if dilate_RP:
                    # This is a trick to make the RP more robust, ensuring the generated mask is not too small. 
                    # We don't use 0.5 as the threshold, but try different thresholds
                
                    threshold_list = [ 0.5, 0.4, 0.3, 0.2, 0.1]
                    for threshold in threshold_list:
                        next_RP_embeddings, input_RP_embeddings_diff = RP(cur_img_path, ref_img_path, next_prompt=next_prompt, next_RP_embeddings_prev=None, PE_sec=PE_sec, generator=generator_RP, threshold=threshold)

                        if idx == 0:
                            break

                        next_RP_embeddings_sum = next_RP_embeddings.sum()

                        if next_RP_embeddings_sum < int(H * W * 0.05):
                            print(f'Warning: next_RP_embeddings is too small: {next_RP_embeddings_sum}, change to {threshold}')
                            continue 
                        
                        # compute iou 
                        iou = (next_RP_embeddings * next_RP_embeddings_prev).sum() / ((next_RP_embeddings + next_RP_embeddings_prev) > 0).sum()
                        if iou < 0.8:
                            break
                        else:
                            sum_diff = next_RP_embeddings.float().sum() - next_RP_embeddings_prev.float().sum()
                            print(f'Warning: iou {iou} is too high, sum_diff {sum_diff}, change to {threshold}')

                else:
                    next_RP_embeddings, input_RP_embeddings_diff = RP(cur_img_path, ref_img_path, next_prompt=next_prompt, next_RP_embeddings_prev=None, PE_sec=PE_sec, generator=generator_RP, threshold=0.5)
                    

                kwargs['RP_guidance_scale'] = RP_guidance_scale
                kwargs['RP_embeddings'] = next_RP_embeddings.to(dtype)
                kwargs['negative_RP_embeddings'] = torch.zeros_like(next_RP_embeddings)

                if combine_init:
                    if idx > 0:
                        kwargs['combine_init'] = combine_init
                        kwargs['combine_init_ratio'] = combine_init_ratio
                        kwargs['img_init_latents'] = pred_next_latents


    
      
                # Reset in place; set_state is documented as returning nothing, so do not rebind the name.
                generator.set_state(state)
                outputs = pipeline(
                    num_inference_steps     = steps,
                    guidance_scale          = guidance_scale,
                    cur_guidance_scale      = cur_guidance_scale, 
                    width                   = W,
                    height                  = H,
                    generator               = generator,
                    num_actual_inference_steps = num_actual_inference_steps,
                    source_image            = ref_img,
                    cur_condition           = cur_img,
                    cur_alpha               = cur_alpha,
                    **kwargs,
                )


                pred_next_img = outputs.images
                pred_next_latents = outputs.latents


                
                # save sample torch tensor (1, H, W, 3)
                pred_next_img = pred_next_img[0]
                pred_next_img = pred_next_img.cpu().numpy()
                pred_next_img = np.clip(pred_next_img * 255, 0, 255).astype(np.uint8)

        
                # Tensors are (1, 3, H, W) after permute/unsqueeze, so crop the padding on the last two axes.
                cur_img_tensor = torch.tensor(cur_img).permute(2, 0, 1).unsqueeze(0).to(dtype).to(device)[:, :, :ori_h, :ori_w]
                pred_next_img_tensor = torch.tensor(pred_next_img).permute(2, 0, 1).unsqueeze(0).to(dtype).to(device)[:, :, :ori_h, :ori_w]
                ref_img_tensor = torch.tensor(ref_img).permute(2, 0, 1).unsqueeze(0).to(dtype).to(device)[:, :, :ori_h, :ori_w]

                cur_img_tensor = (cur_img_tensor / 255.)  * 2 - 1
                pred_next_img_tensor = (pred_next_img_tensor / 255.)  * 2 - 1
                ref_img_tensor = (ref_img_tensor / 255.)  * 2 - 1

                # difference between current and next, next and reference
                cur_next_diff = lpips_fn_alex(cur_img_tensor.cuda(), pred_next_img_tensor.cuda()).item()
                next_ref_diff = lpips_fn_alex(ref_img_tensor.cuda(), pred_next_img_tensor.cuda()).item()

                cur_next_diffs.append(cur_next_diff)
                next_ref_diffs.append(next_ref_diff)

                # Visualization
                next_RP_embeddings_vis = next_RP_embeddings.cpu().detach().numpy()
                next_RP_embeddings_vis = np.clip(next_RP_embeddings_vis * 255, 0, 255).astype(np.uint8)
                next_RP_embeddings_vis = next_RP_embeddings_vis[0,0]
                next_RP_embeddings_vis = next_RP_embeddings_vis[..., None]
                next_RP_embeddings_vis = np.concatenate([next_RP_embeddings_vis, next_RP_embeddings_vis, next_RP_embeddings_vis], axis=2)
                next_RP_embeddings_vis = next_RP_embeddings_vis[:ori_h, :ori_w, :]


                next_RP_embeddings_vis = Image.fromarray(next_RP_embeddings_vis)
                draw = ImageDraw.Draw(next_RP_embeddings_vis)
                font = ImageFont.truetype("utils/arial.ttf", 40)
                draw.text((10, 10), next_prompt, (255, 0, 0), font=font)
                next_RP_embeddings_vis = np.array(next_RP_embeddings_vis)

                next_RP_embeddings_vis = np.concatenate([pred_next_img[:ori_h, :ori_w], next_RP_embeddings_vis], axis=1)
                plt.imsave(f"{dst_dir}/vis_sample_{idx+1}.jpg", next_RP_embeddings_vis)
                plt.imsave(f"{dst_dir}/sample_{idx+1}.jpg", pred_next_img[:ori_h, :ori_w, :])



                cur_img = pred_next_img

                # post processing for cur_img, pad to multiple of 16
                cur_img = cur_img[:ori_h, :ori_w, :]
                cur_img = pad_to_16(cur_img)
                cur_img = cur_img.astype(np.uint8)

                assert cur_img.shape[0] == ref_img.shape[0] and cur_img.shape[1] == ref_img.shape[1]



if __name__ == "__main__":

    args = parse_args()
    main(args)


================================================
FILE: models/ReferenceEncoder.py
================================================
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPProcessor, CLIPVisionModel, CLIPImageProcessor
from transformers import logging
logging.set_verbosity_error()

# https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train_plus.py#L49

class ReferenceEncoder(nn.Module):
    def __init__(self, model_path="openai/clip-vit-base-patch32"):
        super(ReferenceEncoder, self).__init__()
        self.model = CLIPVisionModel.from_pretrained(model_path,local_files_only=True)
        self.freeze()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, pixel_values):
        outputs = self.model(pixel_values)
        pooled_output = outputs.pooler_output
        return pooled_output




# class ReferenceEncoder(nn.Module):
#     def __init__(self, model_path="openai/clip-vit-base-patch32"):
#         super(ReferenceEncoder, self).__init__()
#         self.model = CLIPVisionModel.from_pretrained(model_path,local_files_only=True)
#         self.processor = CLIPProcessor.from_pretrained(model_path,local_files_only=True)
#         self.freeze()

#     def freeze(self):
#         self.model = self.model.eval()
#         for param in self.model.parameters():
#             param.requires_grad = False

#     def forward(self, image):
#         inputs = self.processor(images=image, return_tensors="pt")
        
#         print(inputs['pixel_values'].size())
        
#         outputs = self.model(**inputs)
        
#         pooled_output = outputs.pooler_output

#         return pooled_output

# # example
# model = ReferenceEncoder()
# image_path = "../../000000039769.jpg"
# image_path = "/mnt/f/research/HumanVideo/AnimateAnyone-unofficial/DWPose/0001.png"
# image = Image.open(image_path).convert('RGB')
# image = [image,image]

# pooled_output = model(image)

# print(f"Pooled Output Size: {pooled_output.size()}") # Pooled Output Size: torch.Size([bs, 768])


================================================
FILE: models/ReferenceNet.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.embeddings import (
    GaussianFourierProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    PositionNet,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Identity(torch.nn.Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """
    def __init__(self, scale=None, *args, **kwargs) -> None:
        super(Identity, self).__init__()

    def forward(self, input, *args, **kwargs):
        return input



class _LoRACompatibleLinear(nn.Module):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        pass

    def _unfuse_lora(self):
        pass

    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
        return hidden_states


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None


class ReferenceNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To avoid cluttering __init__,
            # it is set to `cross_attention_dim` here, as this is exactly the dimension required by the only
            # current use case, `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To avoid cluttering __init__,
            # they are set to `cross_attention_dim` here, as this is exactly the dimension required by the only
            # current use case, `addition_embed_type == "text_image"` (Kandinsky 2.1).
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
            )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                attention_type=attention_type,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        
        # # out
        # if norm_num_groups is not None:
        #     self.conv_norm_out = nn.GroupNorm(
        #         num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        #     )

        #     self.conv_act = get_activation(act_fn)

        # else:
        #     self.conv_norm_out = None
        #     self.conv_act = None

        # conv_out_padding = (conv_out_kernel - 1) // 2
        # self.conv_out = nn.Conv2d(
        #     block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        # )
        
        # Diff vs diffusers-0.21.4/src/diffusers/models/unet_2d_condition.py:
        # skip the last cross-attention for a slight speed-up and for DDP training.
        # The parameters replaced below (cross-attention for the last layer) and conv_out
        # are not involved in the gradient calculation of the model.
        last_attn = self.up_blocks[3].attentions[2]
        last_block = last_attn.transformer_blocks[0]
        last_block.attn1.to_q = _LoRACompatibleLinear()
        last_block.attn1.to_k = _LoRACompatibleLinear()
        last_block.attn1.to_v = _LoRACompatibleLinear()
        last_block.attn1.to_out = nn.ModuleList([Identity(), Identity()])
        last_block.norm2 = Identity()
        last_block.attn2 = None
        last_block.norm3 = Identity()
        last_block.ff = Identity()
        last_attn.proj_out = Identity()

        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                positive_len = cross_attention_dim[0]

            feature_type = "text-only" if attention_type == "gated" else "text-image"
            self.position_net = PositionNet(
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # collect recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
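        # Illustrative example (hypothetical values, not taken from this repository): with
        # `sliceable_head_dims == [8, 8, 16]`, "auto" yields `[4, 4, 8]`, "max" yields `[1, 1, 1]`,
        # and an integer such as `2` is broadcast to `[2, 2, 2]`.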

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any child that exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings
                that are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default, samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers
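        # Illustrative example (assuming a configuration with four up blocks, so three of them
        # upsample): `default_overall_up_factor == 2**3 == 8`. A 64x64 sample divides evenly, while
        # a 60x60 sample does not (60 % 8 == 4) and sets `forward_upsample_size` below.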

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
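        # Illustrative example of the conversion above (hypothetical values): a keep-mask row
        # `[1, 1, 0]` becomes the additive bias `[0, 0, -10000]`, and `unsqueeze(1)` reshapes
        # `(batch, sequence_length)` to `(batch, 1, sequence_length)` so it broadcasts over queries.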

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size
                )

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
    
    @classmethod
    def load_referencenet(cls, pretrained_model_path):
        print("loading ReferenceNet's pretrained weights ...")

        config = {
            "_class_name": "UNet2DConditionModel",
            "_diffusers_version": "0.6.0",
            "act_fn": "silu",
            "attention_head_dim": 8,
            "block_out_channels": [320, 640, 1280, 1280],
            "center_input_sample": False,
            "cross_attention_dim": 768,
            "down_block_types": [
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D"
            ],
            "downsample_padding": 1,
            "flip_sin_to_cos": True,
            "freq_shift": 0,
            "in_channels": 4,
            "layers_per_block": 2,
            "mid_block_scale_factor": 1,
            "norm_eps": 1e-05,
            "norm_num_groups": 32,
            "out_channels": 4,
            "sample_size": 64,
            "up_block_types": [
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D"
            ]
        }

        # from diffusers.utils import WEIGHTS_NAME
        model = cls.from_config(config)

        # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)

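        # `pretrained_model_path` is either an in-memory state dict or a path to a checkpoint file.
        # The original check (`not pretrained_model_path is str`) was always true, so a path was
        # never loaded from disk.
        if not isinstance(pretrained_model_path, str):
            state_dict = pretrained_model_path
        else:
            import os  # imported here in case the module header does not import it

            model_file = pretrained_model_path

            if not os.path.isfile(model_file):
                raise RuntimeError(f"{model_file} does not exist")
            state_dict = torch.load(model_file, map_location="cpu")
        print("loaded ReferenceNet's pretrained weights...")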

        m, u = model.load_state_dict(state_dict, strict=True)
        if m or u:
            print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
            # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
        
        # params = [p.numel() for n, p in model.named_parameters() if "2D" in n]
        # print(f"### 2D Module Parameters: {sum(params) / 1e6} M")
        
        params = [p.numel() for n, p in model.named_parameters()]
        print(f"### Module Parameters: {sum(params) / 1e6} M")
        
        return model


================================================
FILE: models/ReferenceNet_attention.py
================================================
# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py

import torch
import torch.nn.functional as F

from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from diffusers.models.attention import BasicTransformerBlock
from .attention import BasicTransformerBlock as _BasicTransformerBlock

def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


class ReferenceNetAttention():
    
    def __init__(self, 
                 unet,
                 mode="write",
                 do_classifier_free_guidance=False,
                 attention_auto_machine_weight = float('inf'),
                 gn_auto_machine_weight = 1.0,
                 style_fidelity = 1.0,
                 reference_attn=True,
                 fusion_blocks="full",
                 batch_size=1, 
                 is_image=False,
                 dtype = torch.float32, 
                 ) -> None:
        # 10. Modify self attention and group norm
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.fusion_blocks = fusion_blocks
        self.register_reference_hooks(
            mode, 
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            fusion_blocks=fusion_blocks,
            batch_size=batch_size, 
            is_image=is_image,
            dtype=dtype, 
        )
        self.dtype = dtype

    def register_reference_hooks(
            self, 
            mode, 
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            # dtype=torch.float16,
            dtype=torch.float32,
            batch_size=1, 
            num_images_per_prompt=1, 
            device=torch.device("cpu"), 
            fusion_blocks='midup',
            is_image=False,
        ):
        # `MODE` and the remaining arguments are captured by the closure defined below.
        MODE = mode
        if do_classifier_free_guidance:
            uc_mask = (
                torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
                .to(device)
                .bool()
            )
        else:
            uc_mask = (
                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
                .to(device)
                .bool()
            )
        
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
        ):
          
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)
      
            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    # Cache the normalized features so that a reader UNet can attend to them later.
                    self.bank.append(norm_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )

                if MODE == "read":
                    # Cast the cached reference features to the reader's dtype
                    # (the writer may have produced them in a different precision).
                    self.bank = [bank.to(dtype) for bank in self.bank]

                    if not is_image:
                        # Repeat each reference feature across the video frames.
                        self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
                    # modify Reference Sec 3.2.2
                    # Concatenate the reference features along the token axis, self-attend over the
                    # joint sequence, and keep only the tokens that belong to the current sample.
                    modify_norm_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)
                    hidden_states_uc = self.attn1(modify_norm_hidden_states,
                                                encoder_hidden_states=modify_norm_hidden_states,
                                                attention_mask=attention_mask)[:,:hidden_states.shape[-2],:] + hidden_states

                    hidden_states_c = hidden_states_uc.clone()
                    _uc_mask = uc_mask.clone()

                    if do_classifier_free_guidance:
                        # Only the first batch entry is treated as unconditional: it is recomputed
                        # with plain self-attention, without the reference features.
                        _uc_mask = (
                            torch.Tensor([1] + [0] * (hidden_states.shape[0]-1))
                            .to(device)
                            .bool()
                        )
                        hidden_states_c[_uc_mask] = self.attn1(
                            norm_hidden_states[_uc_mask],
                            encoder_hidden_states=norm_hidden_states[_uc_mask],
                            attention_mask=attention_mask,
                        ) + hidden_states[_uc_mask]
                    hidden_states = hidden_states_c.clone()
                        
                    # self.bank.clear()
                    if self.attn2 is not None:
                        # Cross-Attention
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                        )
                        hidden_states = (
                            self.attn2(
                                norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                            )
                            + hidden_states
                        )

                    # Feed-forward
                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

                    # Temporal-Attention
                    if not is_image:
                        if self.unet_use_temporal_attention:
                            d = hidden_states.shape[1]
                            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
                            norm_hidden_states = (
                                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
                            )
                            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
                            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

                    return hidden_states
                
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

        if self.reference_attn:
            if self.fusion_blocks == "midup":
                attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]            
            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))
    
    # def update(self, writer, dtype=torch.float16):
    def update(self, writer, dtype=torch.float32):
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
                writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock) or isinstance(module, BasicTransformerBlock)]
                writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, _BasicTransformerBlock) or isinstance(module, BasicTransformerBlock)]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])    
            writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            
            # print('reader_attn_modules:',reader_attn_modules)
            # print('writer_attn_modules:',writer_attn_modules)
            # Fail with an explicit message if either UNet has no transformer blocks to pair up.
            if len(reader_attn_modules) == 0:
                raise AssertionError('reader_attn_modules is empty')
            if len(writer_attn_modules) == 0:
                raise AssertionError('writer_attn_modules is empty')
              
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                r.bank = [v.clone().to(dtype) for v in w.bank]
                # w.bank.clear()
    
    def clear(self):
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            for r in reader_attn_modules:
                r.bank.clear()


================================================
FILE: models/ReferenceNet_attention_fp16.py
================================================
# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py

import torch
import torch.nn.functional as F

from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from diffusers.models.attention import BasicTransformerBlock
from .attention import BasicTransformerBlock as _BasicTransformerBlock

def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


class ReferenceNetAttention():
    
    def __init__(self, 
                 unet,
                 mode="write",
                 do_classifier_free_guidance=False,
                 attention_auto_machine_weight = float('inf'),
                 gn_auto_machine_weight = 1.0,
                 style_fidelity = 1.0,
                 reference_attn=True,
                 fusion_blocks="full",
                 batch_size=1, 
                 is_image=False,
                 ) -> None:
        # 10. Modify self attention and group norm
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.fusion_blocks = fusion_blocks
        self.register_reference_hooks(
            mode, 
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            # Pass by keyword: positionally this argument would bind to `dtype`.
            fusion_blocks=fusion_blocks,
            batch_size=batch_size, 
            is_image=is_image,
        )

    def register_reference_hooks(
            self, 
            mode, 
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            dtype=torch.float16,
            batch_size=1, 
            num_images_per_prompt=1, 
            device=torch.device("cpu"), 
            fusion_blocks='midup',
            is_image=False,
        ):
        # `MODE` and the remaining arguments are captured by the closure defined below.
        MODE = mode
        if do_classifier_free_guidance:
            uc_mask = (
                torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
                .to(device)
                .bool()
            )
        else:
            uc_mask = (
                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
                .to(device)
                .bool()
            )
        
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    self.bank.append(norm_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if MODE == "read":
                    if not is_image:
                        self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
                    # modify Reference Sec 3.2.2
                    modify_norm_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)

                    hidden_states_uc = self.attn1(modify_norm_hidden_states, 
                                                encoder_hidden_states=modify_norm_hidden_states,
                                                attention_mask=attention_mask)[:,:hidden_states.shape[-2],:] + hidden_states
                    

                    hidden_states_c = hidden_states_uc.clone()
                    _uc_mask = uc_mask.clone()
                    if do_classifier_free_guidance:
                        if hidden_states.shape[0] != _uc_mask.shape[0]:
                            _uc_mask = (
                                torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
                                .to(device)
                                .bool()
                            )
                        hidden_states_c[_uc_mask] = self.attn1(
                            norm_hidden_states[_uc_mask],
                            encoder_hidden_states=norm_hidden_states[_uc_mask],
                            attention_mask=attention_mask,
                        ) + hidden_states[_uc_mask]
                    hidden_states = hidden_states_c.clone()
                        
                    # self.bank.clear()

                    
                    if self.attn2 is not None:
                        # Cross-Attention
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                        )
                        hidden_states = (
                            self.attn2(
                                norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                            )
                            + hidden_states
                        )

                    # Feed-forward
                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

                    # Temporal-Attention
                    if not is_image:
                        if self.unet_use_temporal_attention:
                            d = hidden_states.shape[1]
                            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
                            norm_hidden_states = (
                                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
                            )
                            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
                            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

                    return hidden_states
                
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

        if self.reference_attn:
            if self.fusion_blocks == "midup":
                attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]            
            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))
    
    # def update(self, writer, dtype=torch.float16):
    def update(self, writer, dtype=torch.float16):
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
                writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock) or isinstance(module, BasicTransformerBlock)]
                writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, _BasicTransformerBlock) or isinstance(module, BasicTransformerBlock)]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])    
            writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            
            if len(reader_attn_modules) == 0:
                raise AssertionError('reader_attn_modules is empty')
            if len(writer_attn_modules) == 0:
                raise AssertionError('writer_attn_modules is empty')
              
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                r.bank = [v.clone().to(dtype) for v in w.bank]
                # w.bank.clear()
    
    def clear(self):
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            for r in reader_attn_modules:
                r.bank.clear()
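

# --- Illustrative sketch (not part of the original code) ---------------------
# A minimal, torch-free illustration of how update() above appears to hand
# banks from writer blocks to reader blocks: both lists are sorted by
# descending norm1 width, then zipped. `_ToyBlock` and `_pair_and_copy_banks`
# are hypothetical names introduced only for this sketch; `width` stands in
# for `norm1.normalized_shape[0]`. Python's sort is stable, so blocks of equal
# width keep their traversal order, which means the pairing only lines up
# when reader and writer share the same block layout.
from dataclasses import dataclass, field


@dataclass
class _ToyBlock:
    width: int
    bank: list = field(default_factory=list)


def _pair_and_copy_banks(readers, writers):
    # Same ordering rule as update(): widest blocks first, ties in input order.
    readers = sorted(readers, key=lambda b: -b.width)
    writers = sorted(writers, key=lambda b: -b.width)
    for r, w in zip(readers, writers):
        # Copy the list so clearing one bank leaves the other intact.
        r.bank = list(w.bank)
    return readers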


================================================
FILE: models/attention.py
================================================
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..  
# *************************************************************************

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward, AdaLayerNorm
from diffusers.models.attention import Attention as CrossAttention

from einops import rearrange, repeat

@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        if use_linear_projection:
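            # Project back from inner_dim to in_channels (identical to the previous
            # layer shape whenever inner_dim == in_channels).
            self.proj_out = nn.Linear(inner_dim, in_channels)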
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        # JH: need not repeat when a list of prompts are given 
        if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention = None,
        unet_use_temporal_attention = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            self.attn1 = SparseCausalAttention2D(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)
        self.use_ada_layer_norm_zero = False
        
        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
            self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
            if self.attn2 is not None:
                self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
            # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states


================================================
FILE: models/clip_adapter.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

class NextImageFeaturePredictor(nn.Module):
    def __init__(self, input_feature_dim=768, output_feature_dim=768, hidden_dim=1024):
        super(NextImageFeaturePredictor, self).__init__()
        self.input_feature_dim = input_feature_dim
        self.output_feature_dim = output_feature_dim
        self.hidden_dim = hidden_dim
        
        # Since we concatenate the current and final features, the input dimension is doubled
        self.fc1 = nn.Linear(self.input_feature_dim * 2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim // 2)
        self.fc3 = nn.Linear(self.hidden_dim // 2, self.output_feature_dim)

    def forward(self, current_img_feature, final_img_feature):
        # Concatenate the current and final image features along the last dimension
        x = torch.cat((current_img_feature, final_img_feature), dim=-1)
        # print(current_img_feature.shape, final_img_feature.shape)
        # print(x.shape)
        x = x.view(-1, self.input_feature_dim * 2)  # Flatten the input for fully connected layers
        # print(x.shape)

        # Forward pass
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # exit()
        
        # Reshape the output to match the desired output shape
        x = x.view(-1, 1, self.output_feature_dim)
        return x

# # Parameters
# input_feature_dim = 768  # Feature dimension of each vector in the sequence
# output_feature_dim = 768  # Same as input for predicting next feature
# hidden_dim = 1024  # Example hidden dimension, can be adjusted

# # Initialize the model
# model = NextImageFeaturePredictor(input_feature_dim, output_feature_dim, hidden_dim)

# # Example tensors for current and final image features
# current_img_feature = torch.randn(1, 50, 768)
# final_img_feature = torch.randn(1, 50, 768)

# # Predict the next image feature
# next_img_feature = model(current_img_feature, final_img_feature)
# print(next_img_feature.shape)  # Output shape: (50, 1, 768); forward() folds the token dimension into the batch


================================================
FILE: models/hack_cur_image_guider.py
================================================
import os
import torch
import torch.nn as nn
import torch.nn.init as init
from einops import rearrange
import numpy as np

class Hack_CurImageGuider(nn.Module):
    def __init__(self, in_channels=3, noise_latent_channels=320):
        super(Hack_CurImageGuider, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=3, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),

            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.SiLU()
        )

        # Final projection layer
        self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)

        # Initialize layers
        self._initialize_weights()

        self.scale = nn.Parameter(torch.ones(1) * 2)

    # def _initialize_weights(self):
    #     # Initialize weights with Gaussian distribution and zero out the final layer
    #     for m in self.conv_layers:
    #         if isinstance(m, nn.Conv2d):
    #             init.normal_(m.weight, mean=0.0, std=0.02)
    #             if m.bias is not None:
    #                 init.zeros_(m.bias)

    #     init.zeros_(self.final_proj.weight)
    #     if self.final_proj.bias is not None:
    #         init.zeros_(self.final_proj.bias)
    
    def _initialize_weights(self):
        # Initialize weights with He initialization and zero out the biases
        for m in self.conv_layers:
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
                if m.bias is not None:
                    init.zeros_(m.bias)

        # For the final projection layer, initialize weights to zero (or you may choose to use He initialization here as well)
        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)


    def forward(self, x):
        x = self.conv_layers(x)
        x = self.final_proj(x)
        # print(self.scale)
        return x * self.scale

    @classmethod
    def from_pretrained(cls,pretrained_model_path, in_channels=3):
        # pretrained_model_path may already be a state_dict rather than a file path
        # print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...")
        if not isinstance(pretrained_model_path, str):
            state_dict = pretrained_model_path
            print("loaded Current Image Guider's pretrained weights...")

        else:
            if not os.path.exists(pretrained_model_path):
                print(f"There is no model file in {pretrained_model_path}")
            print(f"loaded Current Image Guider's pretrained weights from {pretrained_model_path} ...")

            state_dict = torch.load(pretrained_model_path, map_location="cpu")


        model = Hack_CurImageGuider(in_channels=in_channels, noise_latent_channels=320)
                
        m, u = model.load_state_dict(state_dict, strict=True)
        # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")        
        params = [p.numel() for n, p in model.named_parameters()]
        
        return model
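

# --- Illustrative sketch (not part of the original code) ---------------------
# A small worked example of the spatial arithmetic in Hack_CurImageGuider:
# using the standard Conv2d output-size formula (dilation 1), the three
# stride-2 convolutions shrink each spatial side by a factor of 8, and the
# stride-1 layers leave it unchanged. The names below are hypothetical helpers
# introduced only for this sketch. That the 320-channel result is meant to be
# added to a UNet feature map at 1/8 resolution is an assumption; this file
# does not show where the output is consumed.
def _conv2d_output_size(size, kernel_size, stride=1, padding=0):
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) for each Conv2d in conv_layers, in order,
# followed by the 1x1 final_proj.
_GUIDER_CONV_SPECS = [
    (3, 1, 1), (4, 2, 1),
    (3, 1, 1), (4, 2, 1),
    (3, 1, 1), (4, 2, 1),
    (3, 1, 1), (3, 1, 1),
    (1, 1, 0),
]


def _guider_output_size(size):
    # Apply every layer's size formula in sequence to one spatial dimension.
    for kernel_size, stride, padding in _GUIDER_CONV_SPECS:
        size = _conv2d_output_size(size, kernel_size, stride, padding)
    return size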


================================================
FILE: models/hack_unet2d.py
================================================
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
# from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionModel,UNet2DConditionOutput,logger


class Hack_UNet2DConditionModel(UNet2DConditionModel):
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        latent_cur: torch.Tensor, # new add

        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
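
        # 2. pre-process
        sample = self.conv_in(sample)

        # inject latent_cur by element-wise addition; its shape must be
        # broadcastable to the output of conv_in
        sample = sample + latent_cur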

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
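
# Illustrative sketch (not part of the repository): the forward pass above
# turns a keep/discard mask into an additive bias before attention. The
# stdlib-only example below shows why adding -10000.0 to a masked score
# drives its softmax weight to (numerically) zero. The helper names
# `mask_to_bias` and `softmax` are hypothetical and exist only for this demo.

```python
import math


def mask_to_bias(mask, masked_value=-10000.0):
    """Map a mask (1 = keep, 0 = discard) to an additive attention bias."""
    return [(1 - m) * masked_value for m in mask]


def softmax(scores):
    """Numerically stable softmax over a flat list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 3.0]  # raw attention scores for three keys
mask = [1, 1, 0]          # the third key is padding and must be ignored

bias = mask_to_bias(mask)
weights = softmax([s + b for s, b in zip(scores, bias)])

print(bias)
print(weights)
```

# The third weight underflows to zero even though its raw score was the
# largest, so the masked key contributes nothing to the attention output.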

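# Illustrative sketch (not part of the repository): step "5. up" above hands
# each up block the most recent skip tensors, one per resnet in that block,
# and trims them off the end of the tuple. The example below mimics that
# bookkeeping with strings in place of tensors. The function name
# `consume_skips` and the block layout are assumptions made for the demo.

```python
def consume_skips(skips, resnets_per_up_block):
    """Return, for each up block, the skip entries it would receive.

    `skips` plays the role of `down_block_res_samples`; each entry of
    `resnets_per_up_block` plays the role of `len(upsample_block.resnets)`
    and must be at least 1.
    """
    remaining = tuple(skips)
    assigned = []
    for count in resnets_per_up_block:
        # take the last `count` entries, then drop them from the stack
        assigned.append(remaining[-count:])
        remaining = remaining[:-count]
    return assigned, remaining


# Hypothetical layout: conv_in output plus five down-path residuals.
skips = ("conv_in", "d0_res0", "d0_res1", "d0_down", "d1_res0", "d1_res1")
assigned, leftover = consume_skips(skips, resnets_per_up_block=[3, 3])

for index, taken in enumerate(assigned):
    print(index, taken)
```

# The first up block receives the deepest residuals and the last one receives
# the shallowest, so the tuple is empty once every up block has run.
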

================================================
FILE: models/image_processor.py
================================================
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T


class LPIPS_Image_Processor():
    def __init__(self):
        print("Loading LPIPS model")

    def process(self, img):
        # input PIL
        img = T.ToTensor()(img).unsqueeze(0)

        # Normalize to [-1, 1]
        norm_img = 2 * img - 1

        return img, norm_img

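# Illustrative sketch (not part of the repository): `process` above returns
# the image twice, once in [0, 1] and once rescaled to [-1, 1] via 2 * x - 1.
# The stdlib-only example below reproduces that arithmetic on a few 8-bit
# pixel values, assuming the usual division by 255 that `ToTensor` applies to
# 8-bit images. The function names are hypothetical.

```python
def to_unit_range(pixels_uint8):
    """Scale 8-bit pixel values from [0, 255] to [0, 1]."""
    return [p / 255.0 for p in pixels_uint8]


def to_signed_range(unit_values):
    """Rescale values from [0, 1] to [-1, 1]."""
    return [2 * v - 1 for v in unit_values]


pixels = [0, 128, 255]
unit = to_unit_range(pixels)
signed = to_signed_range(unit)

print(unit)
print(signed)
```

# Black maps to -1, white maps to 1, and mid-grey lands just above 0.
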

class Seg_Image_Processor():
    def __init__(self):
        super(Seg_Image_Processor, self).__init
SYMBOL INDEX (695 symbols across 77 files)

FILE: data_processing/run_llava/main.py
  function main (line 18) | def main(args):

FILE: data_processing/run_llava/make_list.py
  function get_all_files (line 10) | def get_all_files(src_dir, extension="*.json"):

FILE: data_processing/run_llava/utils.py
  function image_parser (line 30) | def image_parser(args):
  function load_image (line 35) | def load_image(image_file):
  function load_images (line 44) | def load_images(image_files):
  class Predictor (line 52) | class Predictor:
    method __init__ (line 53) | def __init__(self, args) -> None:
    method set_args (line 70) | def set_args(self, args):
    method eval_model (line 73) | def eval_model(self):

FILE: dataset/dataset.py
  function zero_rank_print (line 16) | def zero_rank_print(s):
  function load_im_as_tensor (line 19) | def load_im_as_tensor(im_paths):
  class InvPaintingDataset (line 38) | class InvPaintingDataset(Dataset):
    method __init__ (line 39) | def __init__(
    method __len__ (line 82) | def __len__(self):
    method get_batch (line 85) | def get_batch(self,idx):
    method __getitem__ (line 254) | def __getitem__(self, idx):
  function collate_fn (line 310) | def collate_fn(data):

FILE: demo.py
  function parse_args (line 21) | def parse_args():
  function main (line 52) | def main(args):

FILE: models/ReferenceEncoder.py
  class ReferenceEncoder (line 11) | class ReferenceEncoder(nn.Module):
    method __init__ (line 12) | def __init__(self, model_path="openai/clip-vit-base-patch32"):
    method freeze (line 17) | def freeze(self):
    method forward (line 22) | def forward(self, pixel_values):

FILE: models/ReferenceNet.py
  class Identity (line 57) | class Identity(torch.nn.Module):
    method __init__ (line 77) | def __init__(self, scale=None, *args, **kwargs) -> None:
    method forward (line 80) | def forward(self, input, *args, **kwargs):
  class _LoRACompatibleLinear (line 85) | class _LoRACompatibleLinear(nn.Module):
    method __init__ (line 90) | def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None...
    method set_lora_layer (line 94) | def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
    method _fuse_lora (line 97) | def _fuse_lora(self):
    method _unfuse_lora (line 100) | def _unfuse_lora(self):
    method forward (line 103) | def forward(self, hidden_states, scale=None, lora_scale: int = 1):
  class UNet2DConditionOutput (line 108) | class UNet2DConditionOutput(BaseOutput):
  class ReferenceNet (line 120) | class ReferenceNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    method __init__ (line 211) | def __init__(
    method attn_processors (line 653) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 676) | def set_attn_processor(
    method set_default_attn_processor (line 712) | def set_default_attn_processor(self):
    method set_attention_slice (line 727) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 792) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 796) | def forward(
    method load_referencenet (line 1091) | def load_referencenet(cls, pretrained_model_path):

FILE: models/ReferenceNet_attention.py
  function torch_dfs (line 12) | def torch_dfs(model: torch.nn.Module):
  class ReferenceNetAttention (line 19) | class ReferenceNetAttention():
    method __init__ (line 21) | def __init__(self,
    method register_reference_hooks (line 54) | def register_reference_hooks(
    method update (line 270) | def update(self, writer, dtype=torch.float32):
    method clear (line 294) | def clear(self):

FILE: models/ReferenceNet_attention_fp16.py
  function torch_dfs (line 12) | def torch_dfs(model: torch.nn.Module):
  class ReferenceNetAttention (line 19) | class ReferenceNetAttention():
    method __init__ (line 21) | def __init__(self,
    method register_reference_hooks (line 51) | def register_reference_hooks(
    method update (line 232) | def update(self, writer, dtype=torch.float16):
    method clear (line 254) | def clear(self):

FILE: models/attention.py
  class Transformer3DModelOutput (line 37) | class Transformer3DModelOutput(BaseOutput):
  class Transformer3DModel (line 48) | class Transformer3DModel(ModelMixin, ConfigMixin):
    method __init__ (line 50) | def __init__(
    method forward (line 112) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
  class BasicTransformerBlock (line 164) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 165) | def __init__(
    method set_use_memory_efficient_attention_xformers (line 248) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
    method forward (line 276) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...

FILE: models/clip_adapter.py
  class NextImageFeaturePredictor (line 6) | class NextImageFeaturePredictor(nn.Module):
    method __init__ (line 7) | def __init__(self, input_feature_dim=768, output_feature_dim=768, hidd...
    method forward (line 18) | def forward(self, current_img_feature, final_img_feature):

FILE: models/hack_cur_image_guider.py
  class Hack_CurImageGuider (line 8) | class Hack_CurImageGuider(nn.Module):
    method __init__ (line 9) | def __init__(self, in_channels=3, noise_latent_channels=320):
    method _initialize_weights (line 54) | def _initialize_weights(self):
    method forward (line 69) | def forward(self, x):
    method from_pretrained (line 76) | def from_pretrained(cls,pretrained_model_path, in_channels=3):

FILE: models/hack_unet2d.py
  class Hack_UNet2DConditionModel (line 11) | class Hack_UNet2DConditionModel(UNet2DConditionModel):
    method forward (line 12) | def forward(

FILE: models/image_processor.py
  class LPIPS_Image_Processor (line 8) | class LPIPS_Image_Processor():
    method __init__ (line 9) | def __init__(self):
    method process (line 12) | def process(self, img):
  class Seg_Image_Processor (line 22) | class Seg_Image_Processor():
    method __init__ (line 23) | def __init__(self):
    method process (line 28) | def process(self, img):

FILE: models/orig_attention.py
  class Transformer2DModelOutput (line 36) | class Transformer2DModelOutput(BaseOutput):
  class Transformer2DModel (line 54) | class Transformer2DModel(ModelMixin, ConfigMixin):
    method __init__ (line 93) | def __init__(
    method forward (line 184) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
  class AttentionBlock (line 253) | class AttentionBlock(nn.Module):
    method __init__ (line 271) | def __init__(
    method set_use_memory_efficient_attention_xformers (line 296) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
    method reshape_heads_to_batch_dim (line 320) | def reshape_heads_to_batch_dim(self, tensor):
    method reshape_batch_dim_to_heads (line 327) | def reshape_batch_dim_to_heads(self, tensor):
    method forward (line 334) | def forward(self, hidden_states):
  class BasicTransformerBlock (line 388) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 405) | def __init__(
    method set_use_memory_efficient_attention_xformers (line 458) | def set_use_memory_efficient_attention_xformers(self, use_memory_effic...
    method forward (line 485) | def forward(self, hidden_states, encoder_hidden_states=None, timestep=...
  class CrossAttention (line 516) | class CrossAttention(nn.Module):
    method __init__ (line 531) | def __init__(
    method reshape_heads_to_batch_dim (line 578) | def reshape_heads_to_batch_dim(self, tensor):
    method reshape_batch_dim_to_heads (line 585) | def reshape_batch_dim_to_heads(self, tensor):
    method set_attention_slice (line 592) | def set_attention_slice(self, slice_size):
    method forward (line 598) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
    method _attention (line 655) | def _attention(self, query, key, value, attention_mask=None):
    method _sliced_attention (line 686) | def _sliced_attention(self, query, key, value, sequence_length, dim, a...
    method _memory_efficient_attention_xformers (line 729) | def _memory_efficient_attention_xformers(self, query, key, value, atte...
  class FeedForward (line 739) | class FeedForward(nn.Module):
    method __init__ (line 751) | def __init__(
    method forward (line 778) | def forward(self, hidden_states):
  class GELU (line 784) | class GELU(nn.Module):
    method __init__ (line 789) | def __init__(self, dim_in: int, dim_out: int):
    method gelu (line 793) | def gelu(self, gate):
    method forward (line 799) | def forward(self, hidden_states):
  class GEGLU (line 806) | class GEGLU(nn.Module):
    method __init__ (line 815) | def __init__(self, dim_in: int, dim_out: int):
    method gelu (line 819) | def gelu(self, gate):
    method forward (line 825) | def forward(self, hidden_states):
  class ApproximateGELU (line 830) | class ApproximateGELU(nn.Module):
    method __init__ (line 837) | def __init__(self, dim_in: int, dim_out: int):
    method forward (line 841) | def forward(self, x):
  class AdaLayerNorm (line 846) | class AdaLayerNorm(nn.Module):
    method __init__ (line 851) | def __init__(self, embedding_dim, num_embeddings):
    method forward (line 858) | def forward(self, x, timestep):
  class DualTransformer2DModel (line 865) | class DualTransformer2DModel(nn.Module):
    method __init__ (line 892) | def __init__(
    method forward (line 941) | def forward(

FILE: models/positional_encoder.py
  class Embedder (line 17) | class Embedder:
    method __init__ (line 18) | def __init__(self, **kwargs):
    method create_embedding_fn (line 22) | def create_embedding_fn(self):
    method embed (line 46) | def embed(self, inputs):
  function get_embedder (line 50) | def get_embedder(multires, i=0):
  class PositionalEncoder (line 72) | class PositionalEncoder(nn.Module):
    method __init__ (line 73) | def __init__(self,in_features=42 ):
    method forward (line 86) | def forward(self, x):

FILE: models/resnet.py
  class InflatedConv3d (line 30) | class InflatedConv3d(nn.Conv2d):
    method forward (line 31) | def forward(self, x):
  class Upsample3D (line 41) | class Upsample3D(nn.Module):
    method __init__ (line 42) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
    method forward (line 56) | def forward(self, hidden_states, output_size=None):
  class Downsample3D (line 87) | class Downsample3D(nn.Module):
    method __init__ (line 88) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
    method forward (line 102) | def forward(self, hidden_states):
  class ResnetBlock3D (line 113) | class ResnetBlock3D(nn.Module):
    method __init__ (line 114) | def __init__(
    method forward (line 177) | def forward(self, input_tensor, temb):
  class Mish (line 210) | class Mish(torch.nn.Module):
    method forward (line 211) | def forward(self, hidden_states):

FILE: models/unet.py
  class UNet3DConditionOutput (line 32) | class UNet3DConditionOutput(BaseOutput):
  class UNet3DConditionModel (line 36) | class UNet3DConditionModel(ModelMixin, ConfigMixin):
    method __init__ (line 40) | def __init__(
    method set_attention_slice (line 241) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 306) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 310) | def forward(
    method from_pretrained_2d (line 449) | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, une...

FILE: models/unet_3d_blocks.py
  function get_down_block (line 9) | def get_down_block(
  function get_up_block (line 85) | def get_up_block(
  class UNetMidBlock3DCrossAttn (line 160) | class UNetMidBlock3DCrossAttn(nn.Module):
    method __init__ (line 161) | def __init__(
    method forward (line 255) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None...
  class CrossAttnDownBlock3D (line 265) | class CrossAttnDownBlock3D(nn.Module):
    method __init__ (line 266) | def __init__(
    method forward (line 363) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None...
  class DownBlock3D (line 404) | class DownBlock3D(nn.Module):
    method __init__ (line 405) | def __init__(
    method forward (line 469) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
  class CrossAttnUpBlock3D (line 500) | class CrossAttnUpBlock3D(nn.Module):
    method __init__ (line 501) | def __init__(
    method forward (line 594) | def forward(
  class UpBlock3D (line 643) | class UpBlock3D(nn.Module):
    method __init__ (line 644) | def __init__(
    method forward (line 704) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...

FILE: pipelines/context.py
  function ordered_halving (line 12) | def ordered_halving(val):
  function uniform (line 20) | def uniform(
  function get_context_scheduler (line 45) | def get_context_scheduler(name: str) -> Callable:
  function get_total_steps (line 52) | def get_total_steps(

FILE: pipelines/pipeline_stage_1.py
  class InvPaintingPipelineOutput (line 50) | class InvPaintingPipelineOutput(BaseOutput):
  function retrieve_latents (line 56) | def retrieve_latents(
  class InvPaintingPipeline (line 71) | class InvPaintingPipeline(DiffusionPipeline):
    method __init__ (line 74) | def __init__(
    method _execution_device (line 159) | def _execution_device(self):
    method decode_latents (line 172) | def decode_latents(self, latents, rank, decoder_consistency=None):
    method prepare_extra_step_kwargs (line 183) | def prepare_extra_step_kwargs(self, generator, eta):
    method _encode_vae_image (line 202) | def _encode_vae_image(self, image: torch.Tensor, generator: torch.Gene...
    method prepare_latents (line 217) | def prepare_latents(
    method images2latents (line 278) | def images2latents(self, images, dtype):
    method get_timesteps (line 294) | def get_timesteps(self, num_inference_steps, strength, device):
    method __call__ (line 306) | def __call__(

FILE: training_scripts/llava/conversation.py
  class SeparatorStyle (line 9) | class SeparatorStyle(Enum):
  class Conversation (line 19) | class Conversation:
    method get_prompt (line 32) | def get_prompt(self):
    method append_message (line 109) | def append_message(self, role, message):
    method process_image (line 112) | def process_image(self, image, image_process_mode, return_pil=False, i...
    method get_images (line 152) | def get_images(self, return_pil=False):
    method to_gradio_chatbot (line 162) | def to_gradio_chatbot(self):
    method copy (line 180) | def copy(self):
    method dict (line 191) | def dict(self):

FILE: training_scripts/llava/eval/eval_gpt_review.py
  function get_eval (line 13) | def get_eval(content: str, max_tokens: int):
  function parse_score (line 39) | def parse_score(review):

FILE: training_scripts/llava/eval/eval_gpt_review_bench.py
  function get_eval (line 11) | def get_eval(content: str, max_tokens: int):
  function parse_score (line 36) | def parse_score(review):

FILE: training_scripts/llava/eval/eval_gpt_review_visual.py
  function get_eval (line 11) | def get_eval(content: str, max_tokens: int):
  function parse_score (line 36) | def parse_score(review):

FILE: training_scripts/llava/eval/eval_pope.py
  function eval_pope (line 5) | def eval_pope(answers, label_file):

FILE: training_scripts/llava/eval/eval_science_qa.py
  function get_args (line 8) | def get_args():
  function convert_caps (line 19) | def convert_caps(results):
  function get_pred_idx (line 28) | def get_pred_idx(prediction, choices, options):

FILE: training_scripts/llava/eval/eval_science_qa_gpt4.py
  function get_args (line 9) | def get_args():
  function convert_caps (line 19) | def convert_caps(results):
  function get_pred_idx (line 28) | def get_pred_idx(prediction, choices, options):

FILE: training_scripts/llava/eval/eval_science_qa_gpt4_requery.py
  function get_args (line 9) | def get_args():
  function convert_caps (line 21) | def convert_caps(results):
  function get_pred_idx (line 30) | def get_pred_idx(prediction, choices, options):

FILE: training_scripts/llava/eval/eval_textvqa.py
  function get_args (line 9) | def get_args():
  function prompt_processor (line 17) | def prompt_processor(prompt):
  function eval_single (line 35) | def eval_single(annotation_file, result_file):

FILE: training_scripts/llava/eval/generate_webpage_data_from_table.py
  function read_jsonl (line 10) | def read_jsonl(path: str, key: str=None):
  function trim_hanging_lines (line 23) | def trim_hanging_lines(s: str, n: int) -> str:

FILE: training_scripts/llava/eval/m4c_evaluator.py
  class EvalAIAnswerProcessor (line 7) | class EvalAIAnswerProcessor:
    method __init__ (line 178) | def __init__(self, *args, **kwargs):
    method word_tokenize (line 181) | def word_tokenize(self, word):
    method process_punctuation (line 186) | def process_punctuation(self, in_text):
    method process_digit_article (line 198) | def process_digit_article(self, in_text):
    method __call__ (line 213) | def __call__(self, item):
  class TextVQAAccuracyEvaluator (line 221) | class TextVQAAccuracyEvaluator:
    method __init__ (line 222) | def __init__(self):
    method _compute_answer_scores (line 225) | def _compute_answer_scores(self, raw_answers):
    method eval_pred_list (line 248) | def eval_pred_list(self, pred_list):
  class STVQAAccuracyEvaluator (line 260) | class STVQAAccuracyEvaluator:
    method __init__ (line 261) | def __init__(self):
    method eval_pred_list (line 264) | def eval_pred_list(self, pred_list):
  class STVQAANLSEvaluator (line 276) | class STVQAANLSEvaluator:
    method __init__ (line 277) | def __init__(self):
    method get_anls (line 282) | def get_anls(self, s1, s2):
    method eval_pred_list (line 289) | def eval_pred_list(self, pred_list):
  class TextCapsBleu4Evaluator (line 301) | class TextCapsBleu4Evaluator:
    method __init__ (line 302) | def __init__(self):
    method eval_pred_list (line 321) | def eval_pred_list(self, pred_list):

FILE: training_scripts/llava/eval/model_qa.py
  function eval_model (line 14) | def eval_model(model_name, questions_file, answers_file):

FILE: training_scripts/llava/eval/model_vqa.py
  function split_list (line 18) | def split_list(lst, n):
  function get_chunk (line 24) | def get_chunk(lst, n, k):
  function eval_model (line 29) | def eval_model(args):

FILE: training_scripts/llava/eval/model_vqa_loader.py
  function split_list (line 19) | def split_list(lst, n):
  function get_chunk (line 25) | def get_chunk(lst, n, k):
  class CustomDataset (line 31) | class CustomDataset(Dataset):
    method __init__ (line 32) | def __init__(self, questions, image_folder, tokenizer, image_processor...
    method __getitem__ (line 39) | def __getitem__(self, index):
    method __len__ (line 60) | def __len__(self):
  function collate_fn (line 64) | def collate_fn(batch):
  function create_data_loader (line 72) | def create_data_loader(questions, image_folder, tokenizer, image_process...
  function eval_model (line 79) | def eval_model(args):

FILE: training_scripts/llava/eval/model_vqa_mmbench.py
  function split_list (line 22) | def split_list(lst, n):
  function get_chunk (line 28) | def get_chunk(lst, n, k):
  function is_none (line 33) | def is_none(value):
  function get_options (line 44) | def get_options(row, options):
  function eval_model (line 54) | def eval_model(args):

FILE: training_scripts/llava/eval/model_vqa_science.py
  function split_list (line 18) | def split_list(lst, n):
  function get_chunk (line 24) | def get_chunk(lst, n, k):
  function eval_model (line 29) | def eval_model(args):

FILE: training_scripts/llava/eval/qa_baseline_gpt35.py
  function get_answer (line 16) | def get_answer(question_id: int, question: str, max_tokens: int):

FILE: training_scripts/llava/eval/run_llava.py
  function image_parser (line 28) | def image_parser(args):
  function load_image (line 33) | def load_image(image_file):
  function load_images (line 42) | def load_images(image_files):
  function eval_model (line 50) | def eval_model(args):

FILE: training_scripts/llava/eval/summarize_gpt_review.py
  function parse_args (line 9) | def parse_args():

FILE: training_scripts/llava/eval/webpage/script.js
  function text2Markdown (line 35) | function text2Markdown(text) {
  function capitalizeFirstChar (line 41) | function capitalizeFirstChar(str) {
  function updateQuestionSelect (line 48) | function updateQuestionSelect(question_id) {
  function updateModelSelect (line 64) | function updateModelSelect() {
  function populateModels (line 70) | function populateModels(models) {
  function populateQuestions (line 81) | function populateQuestions(questions) {
  function displayQuestion (line 110) | function displayQuestion(index) {
  function displayAnswers (line 116) | function displayAnswers(index) {
  function switchQuestionAndCategory (line 203) | function switchQuestionAndCategory() {
  function updateExpandButtonVisibility (line 226) | function updateExpandButtonVisibility(card) {

FILE: training_scripts/llava/mm_utils.py
  function select_best_resolution (line 12) | def select_best_resolution(original_size, possible_resolutions):
  function resize_and_pad_image (line 42) | def resize_and_pad_image(image, target_resolution):
  function divide_to_patches (line 77) | def divide_to_patches(image, patch_size):
  function get_anyres_image_grid_shape (line 99) | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
  function process_anyres_image (line 122) | def process_anyres_image(image, processor, grid_pinpoints):
  function load_image_from_base64 (line 150) | def load_image_from_base64(image):
  function expand2square (line 154) | def expand2square(pil_img, background_color):
  function process_images (line 168) | def process_images(images, image_processor, model_cfg):
  function tokenizer_image_token (line 188) | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOK...
  function get_model_name_from_path (line 210) | def get_model_name_from_path(model_path):
  class KeywordsStoppingCriteria (line 218) | class KeywordsStoppingCriteria(StoppingCriteria):
    method __init__ (line 219) | def __init__(self, keywords, tokenizer, input_ids):
    method call_for_batch (line 233) | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.F...
    method __call__ (line 246) | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTe...

FILE: training_scripts/llava/model/apply_delta.py
  function apply_delta (line 13) | def apply_delta(base_model_path, target_model_path, delta_path):

FILE: training_scripts/llava/model/builder.py
  function load_pretrained_model (line 26) | def load_pretrained_model(model_path, model_base, model_name, load_8bit=...

FILE: training_scripts/llava/model/consolidate.py
  function consolidate_ckpt (line 13) | def consolidate_ckpt(src_path, dst_path):

FILE: training_scripts/llava/model/language_model/llava_llama.py
  class LlavaConfig (line 30) | class LlavaConfig(LlamaConfig):
  class LlavaLlamaModel (line 34) | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    method __init__ (line 37) | def __init__(self, config: LlamaConfig):
  class LlavaLlamaForCausalLM (line 41) | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    method __init__ (line 44) | def __init__(self, config):
    method get_model (line 54) | def get_model(self):
    method forward (line 57) | def forward(
    method generate (line 105) | def generate(
    method prepare_inputs_for_generation (line 145) | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,

FILE: training_scripts/llava/model/language_model/llava_mistral.py
  class LlavaMistralConfig (line 31) | class LlavaMistralConfig(MistralConfig):
  class LlavaMistralModel (line 35) | class LlavaMistralModel(LlavaMetaModel, MistralModel):
    method __init__ (line 38) | def __init__(self, config: MistralConfig):
  class LlavaMistralForCausalLM (line 42) | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
    method __init__ (line 45) | def __init__(self, config):
    method get_model (line 54) | def get_model(self):
    method forward (line 57) | def forward(
    method generate (line 105) | def generate(
    method prepare_inputs_for_generation (line 144) | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,

FILE: training_scripts/llava/model/language_model/llava_mpt.py
  class LlavaMptConfig (line 25) | class LlavaMptConfig(MptConfig):
  class LlavaMptModel (line 29) | class LlavaMptModel(LlavaMetaModel, MptModel):
    method __init__ (line 32) | def __init__(self, config: MptConfig):
    method embed_tokens (line 36) | def embed_tokens(self, x):
  class LlavaMptForCausalLM (line 40) | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
    method __init__ (line 44) | def __init__(self, config):
    method get_model (line 53) | def get_model(self):
    method _set_gradient_checkpointing (line 56) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 60) | def forward(
    method prepare_inputs_for_generation (line 87) | def prepare_inputs_for_generation(self, input_ids, past_key_values=Non...

FILE: training_scripts/llava/model/llava_arch.py
  class LlavaMetaModel (line 29) | class LlavaMetaModel:
    method __init__ (line 31) | def __init__(self, config):
    method get_vision_tower (line 43) | def get_vision_tower(self):
    method initialize_vision_modules (line 49) | def initialize_vision_modules(self, model_args, fsdp=None):
  function unpad_image (line 100) | def unpad_image(tensor, original_size):
  class LlavaMetaForCausalLM (line 131) | class LlavaMetaForCausalLM(ABC):
    method get_model (line 134) | def get_model(self):
    method get_vision_tower (line 137) | def get_vision_tower(self):
    method encode_images (line 140) | def encode_images(self, images):
    method prepare_inputs_labels_for_multimodal (line 145) | def prepare_inputs_labels_for_multimodal(
    method initialize_vision_tokenizer (line 326) | def initialize_vision_tokenizer(self, model_args, tokenizer):

FILE: training_scripts/llava/model/make_delta.py
  function make_delta (line 13) | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_...

FILE: training_scripts/llava/model/multimodal_encoder/builder.py
  function build_vision_tower (line 5) | def build_vision_tower(vision_tower_cfg, **kwargs):

FILE: training_scripts/llava/model/multimodal_encoder/clip_encoder.py
  class CLIPVisionTower (line 7) | class CLIPVisionTower(nn.Module):
    method __init__ (line 8) | def __init__(self, vision_tower, args, delay_load=False):
    method load_model (line 24) | def load_model(self, device_map=None):
    method feature_select (line 35) | def feature_select(self, image_forward_outs):
    method forward (line 46) | def forward(self, images):
    method dummy_feature (line 60) | def dummy_feature(self):
    method dtype (line 64) | def dtype(self):
    method device (line 68) | def device(self):
    method config (line 72) | def config(self):
    method hidden_size (line 79) | def hidden_size(self):
    method num_patches_per_side (line 83) | def num_patches_per_side(self):
    method num_patches (line 87) | def num_patches(self):

FILE: training_scripts/llava/model/multimodal_projector/builder.py
  class IdentityMap (line 6) | class IdentityMap(nn.Module):
    method __init__ (line 7) | def __init__(self):
    method forward (line 10) | def forward(self, x, *args, **kwargs):
    method config (line 14) | def config(self):
  class SimpleResBlock (line 18) | class SimpleResBlock(nn.Module):
    method __init__ (line 19) | def __init__(self, channels):
    method forward (line 28) | def forward(self, x):
  function build_vision_projector (line 33) | def build_vision_projector(config, delay_load=False, **kwargs):

FILE: training_scripts/llava/model/utils.py
  function auto_upgrade (line 4) | def auto_upgrade(config):

FILE: training_scripts/llava/serve/cli.py
  function load_image (line 18) | def load_image(image_file):
  function main (line 27) | def main(args):

FILE: training_scripts/llava/serve/controller.py
  class DispatchMethod (line 28) | class DispatchMethod(Enum):
    method from_str (line 33) | def from_str(cls, name):
  class WorkerInfo (line 43) | class WorkerInfo:
  function heart_beat_controller (line 51) | def heart_beat_controller(controller):
  class Controller (line 57) | class Controller:
    method __init__ (line 58) | def __init__(self, dispatch_method: str):
    method register_worker (line 69) | def register_worker(self, worker_name: str, check_heart_beat: bool,
    method get_worker_status (line 88) | def get_worker_status(self, worker_name: str):
    method remove_worker (line 101) | def remove_worker(self, worker_name: str):
    method refresh_all_workers (line 104) | def refresh_all_workers(self):
    method list_models (line 112) | def list_models(self):
    method get_worker_address (line 120) | def get_worker_address(self, model_name: str):
    method receive_heart_beat (line 173) | def receive_heart_beat(self, worker_name: str, queue_length: int):
    method remove_stable_workers_by_expiration (line 183) | def remove_stable_workers_by_expiration(self):
    method worker_api_generate_stream (line 193) | def worker_api_generate_stream(self, params):
    method worker_api_get_status (line 220) | def worker_api_get_status(self):
  function register_worker (line 243) | async def register_worker(request: Request):
  function refresh_all_workers (line 251) | async def refresh_all_workers():
  function list_models (line 256) | async def list_models():
  function get_worker_address (line 262) | async def get_worker_address(request: Request):
  function receive_heart_beat (line 269) | async def receive_heart_beat(request: Request):
  function worker_api_generate_stream (line 277) | async def worker_api_generate_stream(request: Request):
  function worker_api_get_status (line 284) | async def worker_api_get_status(request: Request):

FILE: training_scripts/llava/serve/gradio_web_server.py
  function get_conv_log_filename (line 32) | def get_conv_log_filename():
  function get_model_list (line 38) | def get_model_list():
  function load_demo (line 58) | def load_demo(url_params, request: gr.Request):
  function load_demo_refresh_model_list (line 71) | def load_demo_refresh_model_list(request: gr.Request):
  function vote_last_response (line 82) | def vote_last_response(state, vote_type, model_selector, request: gr.Req...
  function upvote_last_response (line 94) | def upvote_last_response(state, model_selector, request: gr.Request):
  function downvote_last_response (line 100) | def downvote_last_response(state, model_selector, request: gr.Request):
  function flag_last_response (line 106) | def flag_last_response(state, model_selector, request: gr.Request):
  function regenerate (line 112) | def regenerate(state, image_process_mode, request: gr.Request):
  function clear_history (line 122) | def clear_history(request: gr.Request):
  function add_text (line 128) | def add_text(state, text, image, image_process_mode, request: gr.Request):
  function http_bot (line 154) | def http_bot(state, model_selector, temperature, top_p, max_new_tokens, ...
  function build_demo (line 315) | def build_demo(embed_mode, cur_dir=None, concurrency_count=10):

FILE: training_scripts/llava/serve/model_worker.py
  function heart_beat_worker (line 37) | def heart_beat_worker(controller):
  class ModelWorker (line 44) | class ModelWorker:
    method __init__ (line 45) | def __init__(self, controller_addr, worker_addr,
    method register_to_controller (line 75) | def register_to_controller(self):
    method send_heart_beat (line 87) | def send_heart_beat(self):
    method get_queue_length (line 108) | def get_queue_length(self):
    method get_status (line 115) | def get_status(self):
    method generate_stream (line 123) | def generate_stream(self, params):
    method generate_stream_gate (line 195) | def generate_stream_gate(self, params):
  function release_model_semaphore (line 225) | def release_model_semaphore(fn=None):
  function generate_stream (line 232) | async def generate_stream(request: Request):
  function get_status (line 248) | async def get_status(request: Request):

FILE: training_scripts/llava/serve/sglang_worker.py
  function heart_beat_worker (line 38) | def heart_beat_worker(controller):
  function pipeline (line 45) | def pipeline(s, prompt, max_tokens):
  class ModelWorker (line 54) | class ModelWorker:
    method __init__ (line 55) | def __init__(self, controller_addr, worker_addr, sgl_endpoint,
    method register_to_controller (line 85) | def register_to_controller(self):
    method send_heart_beat (line 97) | def send_heart_beat(self):
    method get_queue_length (line 118) | def get_queue_length(self):
    method get_status (line 125) | def get_status(self):
    method generate_stream (line 132) | async def generate_stream(self, params):
    method generate_stream_gate (line 172) | async def generate_stream_gate(self, params):
  function release_model_semaphore (line 195) | def release_model_semaphore(fn=None):
  function generate_stream (line 202) | async def generate_stream(request: Request):
  function get_status (line 218) | async def get_status(request: Request):

FILE: training_scripts/llava/serve/test_message.py
  function main (line 9) | def main():

FILE: training_scripts/llava/train/llama_flash_attn_monkey_patch.py
  function forward (line 16) | def forward(
  function _prepare_decoder_attention_mask (line 98) | def _prepare_decoder_attention_mask(
  function replace_llama_attn_with_flash_attn (line 105) | def replace_llama_attn_with_flash_attn():

FILE: training_scripts/llava/train/llama_xformers_attn_monkey_patch.py
  function replace_llama_attn_with_xformers_attn (line 19) | def replace_llama_attn_with_xformers_attn():
  function xformers_forward (line 23) | def xformers_forward(

FILE: training_scripts/llava/train/llava_trainer.py
  function maybe_zero_3 (line 18) | def maybe_zero_3(param, ignore_status=False, name=None):
  function get_mm_adapter_state_maybe_zero_3 (line 32) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
  function split_to_even_chunks (line 38) | def split_to_even_chunks(indices, lengths, num_chunks):
  function get_modality_length_grouped_indices (line 60) | def get_modality_length_grouped_indices(lengths, batch_size, world_size,...
  function get_length_grouped_indices (line 88) | def get_length_grouped_indices(lengths, batch_size, world_size, generato...
  class LengthGroupedSampler (line 99) | class LengthGroupedSampler(Sampler):
    method __init__ (line 105) | def __init__(
    method __len__ (line 122) | def __len__(self):
    method __iter__ (line 125) | def __iter__(self):
  class LLaVATrainer (line 133) | class LLaVATrainer(Trainer):
    method _get_train_sampler (line 135) | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    method create_optimizer (line 150) | def create_optimizer(self):
    method _save_checkpoint (line 230) | def _save_checkpoint(self, model, trial, metrics=None):
    method _save (line 251) | def _save(self, output_dir: Optional[str] = None, state_dict=None):

FILE: training_scripts/llava/train/train.py
  function rank0_print (line 44) | def rank0_print(*args):
  class ModelArguments (line 54) | class ModelArguments:
  class DataArguments (line 70) | class DataArguments:
  class TrainingArguments (line 80) | class TrainingArguments(transformers.TrainingArguments):
  function maybe_zero_3 (line 115) | def maybe_zero_3(param, ignore_status=False, name=None):
  function get_peft_state_maybe_zero_3 (line 132) | def get_peft_state_maybe_zero_3(named_params, bias):
  function get_peft_state_non_lora_maybe_zero_3 (line 157) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
  function get_mm_adapter_state_maybe_zero_3 (line 165) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
  function find_all_linear_names (line 171) | def find_all_linear_names(model):
  function safe_save_model_for_hf_trainer (line 187) | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
  function smart_tokenizer_and_embedding_resize (line 226) | def smart_tokenizer_and_embedding_resize(
  function _tokenize_fn (line 251) | def _tokenize_fn(strings: Sequence[str],
  function _mask_targets (line 278) | def _mask_targets(target, tokenized_lens, speakers):
  function _add_speaker_and_signal (line 289) | def _add_speaker_and_signal(header, source, get_conversation=True):
  function preprocess_multimodal (line 310) | def preprocess_multimodal(
  function preprocess_llama_2 (line 334) | def preprocess_llama_2(
  function preprocess_v1 (line 416) | def preprocess_v1(
  function preprocess_mpt (line 502) | def preprocess_mpt(
  function preprocess_plain (line 590) | def preprocess_plain(
  function preprocess (line 612) | def preprocess(
  class LazySupervisedDataset (line 660) | class LazySupervisedDataset(Dataset):
    method __init__ (line 663) | def __init__(self, data_path: str,
    method __len__ (line 674) | def __len__(self):
    method lengths (line 678) | def lengths(self):
    method modality_lengths (line 686) | def modality_lengths(self):
    method __getitem__ (line 694) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
  class DataCollatorForSupervisedDataset (line 754) | class DataCollatorForSupervisedDataset(object):
    method __call__ (line 759) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
  function make_supervised_data_module (line 787) | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokeni...
  function train (line 799) | def train(attn_implementation=None):

FILE: training_scripts/llava/utils.py
  function build_logger (line 17) | def build_logger(logger_name, logger_filename):
  class StreamToLogger (line 60) | class StreamToLogger(object):
    method __init__ (line 64) | def __init__(self, logger, log_level=logging.INFO):
    method __getattr__ (line 70) | def __getattr__(self, attr):
    method write (line 73) | def write(self, buf):
    method flush (line 87) | def flush(self):
  function disable_torch_init (line 93) | def disable_torch_init():
  function violates_moderation (line 102) | def violates_moderation(text):
  function pretty_print_semaphore (line 123) | def pretty_print_semaphore(semaphore):

FILE: training_scripts/scripts/merge_lora_weights.py
  function merge_lora (line 6) | def merge_lora(args):

FILE: training_scripts/train_mask_generator.py
  function init_dist (line 64) | def init_dist(launcher="slurm", backend='nccl', port=28888, **kwargs):
  function get_parameters_without_gradients (line 97) | def get_parameters_without_gradients(model):
  function main (line 115) | def main(

FILE: training_scripts/train_renderer.py
  function init_dist (line 55) | def init_dist(launcher="slurm", backend='nccl', port=28888, **kwargs):
  function get_parameters_without_gradients (line 87) | def get_parameters_without_gradients(model):
  function main (line 105) | def main(

FILE: unet_2d/attention.py
  class GatedSelfAttentionDense (line 28) | class GatedSelfAttentionDense(nn.Module):
    method __init__ (line 29) | def __init__(self, query_dim, context_dim, n_heads, d_head):
    method forward (line 46) | def forward(self, x, objs):
  class BasicTransformerBlock (line 60) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 81) | def __init__(
    method set_chunk_feed_forward (line 164) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
    method forward (line 169) | def forward(
  class FeedForward (line 258) | class FeedForward(nn.Module):
    method __init__ (line 271) | def __init__(
    method forward (line 304) | def forward(self, hidden_states, scale: float = 1.0):
  class GELU (line 313) | class GELU(nn.Module):
    method __init__ (line 318) | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
    method gelu (line 323) | def gelu(self, gate):
    method forward (line 329) | def forward(self, hidden_states):
  class GEGLU (line 335) | class GEGLU(nn.Module):
    method __init__ (line 344) | def __init__(self, dim_in: int, dim_out: int):
    method gelu (line 348) | def gelu(self, gate):
    method forward (line 354) | def forward(self, hidden_states, scale: float = 1.0):
  class ApproximateGELU (line 359) | class ApproximateGELU(nn.Module):
    method __init__ (line 366) | def __init__(self, dim_in: int, dim_out: int):
    method forward (line 370) | def forward(self, x):
  class AdaLayerNorm (line 375) | class AdaLayerNorm(nn.Module):
    method __init__ (line 380) | def __init__(self, embedding_dim, num_embeddings):
    method forward (line 387) | def forward(self, x, timestep):
  class AdaLayerNormZero (line 394) | class AdaLayerNormZero(nn.Module):
    method __init__ (line 399) | def __init__(self, embedding_dim, num_embeddings):
    method forward (line 408) | def forward(self, x, timestep, class_labels, hidden_dtype=None):
  class AdaGroupNorm (line 415) | class AdaGroupNorm(nn.Module):
    method __init__ (line 420) | def __init__(
    method forward (line 434) | def forward(self, x, emb):

FILE: unet_2d/resnet.py
  class Upsample1D (line 29) | class Upsample1D(nn.Module):
    method __init__ (line 43) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
    method forward (line 57) | def forward(self, inputs):
  class Downsample1D (line 70) | class Downsample1D(nn.Module):
    method __init__ (line 84) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
    method forward (line 99) | def forward(self, inputs):
  class Upsample2D (line 104) | class Upsample2D(nn.Module):
    method __init__ (line 118) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
    method forward (line 138) | def forward(self, hidden_states, output_size=None, scale: float = 1.0):
  class Downsample2D (line 182) | class Downsample2D(nn.Module):
    method __init__ (line 196) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
    method forward (line 220) | def forward(self, hidden_states, scale: float = 1.0):
  class FirUpsample2D (line 235) | class FirUpsample2D(nn.Module):
    method __init__ (line 249) | def __init__(self, channels=None, out_channels=None, use_conv=False, f...
    method _upsample_2d (line 258) | def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor...
    method forward (line 338) | def forward(self, hidden_states):
  class FirDownsample2D (line 348) | class FirDownsample2D(nn.Module):
    method __init__ (line 362) | def __init__(self, channels=None, out_channels=None, use_conv=False, f...
    method _downsample_2d (line 371) | def _downsample_2d(self, hidden_states, weight=None, kernel=None, fact...
    method forward (line 425) | def forward(self, hidden_states):
  class KDownsample2D (line 436) | class KDownsample2D(nn.Module):
    method __init__ (line 437) | def __init__(self, pad_mode="reflect"):
    method forward (line 444) | def forward(self, inputs):
  class KUpsample2D (line 453) | class KUpsample2D(nn.Module):
    method __init__ (line 454) | def __init__(self, pad_mode="reflect"):
    method forward (line 461) | def forward(self, inputs):
  class ResnetBlock2D (line 470) | class ResnetBlock2D(nn.Module):
    method __init__ (line 501) | def __init__(
    method forward (line 600) | def forward(self, input_tensor, scale: float = 1.0):
  function rearrange_dims (line 663) | def rearrange_dims(tensor):
  class Conv1dBlock (line 674) | class Conv1dBlock(nn.Module):
    method __init__ (line 679) | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
    method forward (line 686) | def forward(self, inputs):
  class ResidualTemporalBlock1D (line 696) | class ResidualTemporalBlock1D(nn.Module):
    method __init__ (line 697) | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
    method forward (line 709) | def forward(self, inputs, t):
  function upsample_2d (line 725) | def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
  function downsample_2d (line 762) | def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
  function upfirdn2d_native (line 797) | def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
  class TemporalConvLayer (line 841) | class TemporalConvLayer(nn.Module):
    method __init__ (line 847) | def __init__(self, in_dim, out_dim=None, dropout=0.0):
    method forward (line 880) | def forward(self, hidden_states, num_frames=1):

FILE: unet_2d/unet_2d_blocks.py
  function get_down_block (line 33) | def get_down_block(
  function get_up_block (line 243) | def get_up_block(
  class AutoencoderTinyBlock (line 456) | class AutoencoderTinyBlock(nn.Module):
    method __init__ (line 457) | def __init__(self, in_channels: int, out_channels: int, act_fn: str):
    method forward (line 474) | def forward(self, x):
  class UNetMidBlock2D (line 478) | class UNetMidBlock2D(nn.Module):
    method __init__ (line 479) | def __init__(
    method forward (line 559) | def forward(self, hidden_states, temb=None):
  class UNetMidBlock2DCrossAttn (line 569) | class UNetMidBlock2DCrossAttn(nn.Module):
    method __init__ (line 570) | def __init__(
    method forward (line 659) | def forward(
  class UNetMidBlock2DSimpleCrossAttn (line 710) | class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    method __init__ (line 711) | def __init__(
    method forward (line 795) | def forward(
  class AttnDownBlock2D (line 834) | class AttnDownBlock2D(nn.Module):
    method __init__ (line 835) | def __init__(
    method forward (line 926) | def forward(self, hidden_states, temb=None, upsample_size=None, cross_...
  class CrossAttnDownBlock2D (line 951) | class CrossAttnDownBlock2D(nn.Module):
    method __init__ (line 952) | def __init__(
    method forward (line 1041) | def forward(
  class DownBlock2D (line 1110) | class DownBlock2D(nn.Module):
    method __init__ (line 1111) | def __init__(
    method forward (line 1162) | def forward(self, hidden_states, temb=None, scale: float = 1.0):
  class DownEncoderBlock2D (line 1196) | class DownEncoderBlock2D(nn.Module):
    method __init__ (line 1197) | def __init__(
    method forward (line 1245) | def forward(self, hidden_states, scale: float = 1.0):
  class AttnDownEncoderBlock2D (line 1256) | class AttnDownEncoderBlock2D(nn.Module):
    method __init__ (line 1257) | def __init__(
    method forward (line 1328) | def forward(self, hidden_states, scale: float = 1.0):
  class AttnSkipDownBlock2D (line 1341) | class AttnSkipDownBlock2D(nn.Module):
    method __init__ (line 1342) | def __init__(
    method forward (line 1422) | def forward(self, hidden_states, temb=None, skip_sample=None, scale: f...
  class SkipDownBlock2D (line 1443) | class SkipDownBlock2D(nn.Module):
    method __init__ (line 1444) | def __init__(
    method forward (line 1503) | def forward(self, hidden_states, temb=None, skip_sample=None, scale: f...
  class ResnetDownsampleBlock2D (line 1522) | class ResnetDownsampleBlock2D(nn.Module):
    method __init__ (line 1523) | def __init__(
    method forward (line 1586) | def forward(self, hidden_states, temb=None, scale: float = 1.0):
  class SimpleCrossAttnDownBlock2D (line 1620) | class SimpleCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1621) | def __init__(
    method forward (line 1715) | def forward(
  class KDownBlock2D (line 1780) | class KDownBlock2D(nn.Module):
    method __init__ (line 1781) | def __init__(
    method forward (line 1826) | def forward(self, hidden_states, temb=None, scale: float = 1.0):
  class KCrossAttnDownBlock2D (line 1858) | class KCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1859) | def __init__(
    method forward (line 1923) | def forward(
  class AttnUpBlock2D (line 1985) | class AttnUpBlock2D(nn.Module):
    method __init__ (line 1986) | def __init__(
    method forward (line 2074) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class CrossAttnUpBlock2D (line 2095) | class CrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2096) | def __init__(
    method forward (line 2181) | def forward(
  class UpBlock2D (line 2243) | class UpBlock2D(nn.Module):
    method __init__ (line 2244) | def __init__(
    method forward (line 2291) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class UpDecoderBlock2D (line 2324) | class UpDecoderBlock2D(nn.Module):
    method __init__ (line 2325) | def __init__(
    method forward (line 2368) | def forward(self, hidden_states, temb=None, scale: float = 1.0):
  class AttnUpDecoderBlock2D (line 2379) | class AttnUpDecoderBlock2D(nn.Module):
    method __init__ (line 2380) | def __init__(
    method forward (line 2447) | def forward(self, hidden_states, temb=None, scale: float = 1.0):
  class AttnSkipUpBlock2D (line 2460) | class AttnSkipUpBlock2D(nn.Module):
    method __init__ (line 2461) | def __init__(
    method forward (line 2551) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, s...
  class SkipUpBlock2D (line 2580) | class SkipUpBlock2D(nn.Module):
    method __init__ (line 2581) | def __init__(
    method forward (line 2649) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, s...
  class ResnetUpsampleBlock2D (line 2675) | class ResnetUpsampleBlock2D(nn.Module):
    method __init__ (line 2676) | def __init__(
    method forward (line 2742) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class SimpleCrossAttnUpBlock2D (line 2775) | class SimpleCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2776) | def __init__(
    method forward (line 2872) | def forward(
  class KUpBlock2D (line 2939) | class KUpBlock2D(nn.Module):
    method __init__ (line 2940) | def __init__(
    method forward (line 2987) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class KCrossAttnUpBlock2D (line 3019) | class KCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 3020) | def __init__(
    method forward (line 3103) | def forward(
  class KAttentionBlock (line 3165) | class KAttentionBlock(nn.Module):
    method __init__ (line 3182) | def __init__(
    method _to_3d (line 3225) | def _to_3d(self, hidden_states, height, weight):
    method _to_4d (line 3228) | def _to_4d(self, hidden_states, height, weight):
    method forward (line 3231) | def forward(

FILE: unet_2d/unet_2d_condition.py
  class UNet2DConditionOutput (line 57) | class UNet2DConditionOutput(BaseOutput):
  class UNet2DConditionModel (line 69) | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
    method __init__ (line 161) | def __init__(
    method attn_processors (line 593) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 616) | def set_attn_processor(
    method set_default_attn_processor (line 652) | def set_default_attn_processor(self):
    method set_attention_slice (line 667) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 732) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 736) | def forward(

FILE: utils/dist_tools.py
  function distributed_init (line 18) | def distributed_init(args):
  function get_rank (line 62) | def get_rank():
  function is_master (line 72) | def is_master():
  function synchronize (line 76) | def synchronize():
  function suppress_output (line 81) | def suppress_output(is_master):

FILE: utils/inference_helpers.py
  function prepare_results_dir (line 32) | def prepare_results_dir(config, ckpt_path, root_dst_dir):
  function get_dataset_info (line 44) | def get_dataset_info(test_dir):
  function load_pipeline (line 57) | def load_pipeline(config, pretrained_model_path, pretrained_clip_path, f...
  class PE_wrapper (line 127) | class PE_wrapper():
    method __init__ (line 128) | def __init__(self, config,full_state_dict, device, dtype):
    method norm (line 153) | def norm(self, pos):
    method embed (line 160) | def embed(self, cur_pos):
  function pad_to_16 (line 186) | def pad_to_16(source_image):
  class TP_wrapper (line 228) | class TP_wrapper(nn.Module):
    method __init__ (line 229) | def __init__(self, config, full_state_dict, device, dtype):
    method forward (line 239) | def forward(self, cur_img_path, ref_img_path, cache_path=None):
    method get_negative_embeddings (line 242) | def get_negative_embeddings(self):
    method get_negative_embeddings_exclude (line 245) | def get_negative_embeddings_exclude(self, next_text):
    method encode_text_prompt (line 248) | def encode_text_prompt(self, next_prompt):
  class TP_text_wrapper (line 251) | class TP_text_wrapper(nn.Module):
    method __init__ (line 252) | def __init__(self, config,full_state_dict, device, dtype):
    method forward (line 306) | def forward(self, cur_img_path, ref_img_path, cache_path=None):
    method encode_text_prompt (line 337) | def encode_text_prompt(self, next_prompt):
    method get_negative_embeddings (line 357) | def get_negative_embeddings(self):
  class RP_wrapper (line 380) | class RP_wrapper(nn.Module):
    method __init__ (line 381) | def __init__(self, config, full_state_dict, device, dtype):
    method read_RP_mask (line 459) | def read_RP_mask(self, mask_path):
    method forward (line 467) | def forward(self, cur_img_path, ref_img_path, next_prompt=None, next_R...

FILE: utils/llava_utils.py
  function image_parser (line 28) | def image_parser(args):
  function load_image (line 33) | def load_image(image_file):
  function load_images (line 42) | def load_images(image_files):
  class Predictor (line 50) | class Predictor:
    method __init__ (line 51) | def __init__(self, args) -> None:
    method set_args (line 65) | def set_args(self, args):
    method eval_model (line 68) | def eval_model(self):

FILE: utils/text_wrapper.py
  function import_model_class_from_model_name_or_path (line 6) | def import_model_class_from_model_name_or_path(pretrained_model_name_or_...
  function tokenize_prompt (line 30) | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
  function encode_prompt (line 47) | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_...

FILE: utils/util.py
  function zero_rank_print (line 22) | def zero_rank_print(s):
  function save_videos_grid (line 25) | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_r...
  function save_images_grid (line 39) | def save_images_grid(images: torch.Tensor, path: str):
  function init_prompt (line 49) | def init_prompt(prompt, pipeline):
  function next_step (line 68) | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timest...
  function get_noise_pred_single (line 81) | def get_noise_pred_single(latents, t, context, unet):
  function ddim_loop (line 87) | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
  function ddim_inversion (line 101) | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps...
  function video2images (line 106) | def video2images(path, step=4, length=16, start=0):
  function images2video (line 115) | def images2video(video, path, fps=8):
  function get_tensor_interpolation_method (line 122) | def get_tensor_interpolation_method():
  function set_tensor_interpolation_method (line 125) | def set_tensor_interpolation_method(is_slerp):
  function linear (line 129) | def linear(v1, v2, t):
  function slerp (line 132) | def slerp(

Condensed preview: 125 files, each showing path, character count, and a content snippet (9,802K chars in the full structured content).
[
  {
    "path": "README.md",
    "chars": 7730,
    "preview": "<h1 align='Center'>Inverse Painting: Reconstructing The Painting Process</h1>\n\n<div align='Center'>\n            <a href="
  },
  {
    "path": "configs/inference/inference.yaml",
    "chars": 657,
    "preview": "unet_additional_kwargs:\r\n  unet_use_cross_frame_attention: false\r\n  unet_use_temporal_attention: false\r\n  use_motion_mod"
  },
  {
    "path": "configs/train/train_mask_gen.yaml",
    "chars": 1880,
    "preview": "image_finetune: true\n\noutput_dir: \"outputs/mask_gen\"\npretrained_model_path: \"../base_ckpt/realisticVisionV51_v51VAE\"\n\nno"
  },
  {
    "path": "configs/train/train_renderer.yaml",
    "chars": 2076,
    "preview": "image_finetune: true\n\noutput_dir: \"outputs/renderer\"\npretrained_model_path: \"../base_ckpt/realisticVisionV51_v51VAE\"\ncli"
  },
  {
    "path": "data/sample_data/train/rgb/example/last_aligned_frame_inv.json",
    "chars": 187,
    "preview": "{\n    \"10_3:21\": [\n        \"1_0:15\",\n        \"2_0:40\",\n        \"3_0:53\",\n        \"4_1:12\",\n        \"5_1:28\",\n        \"6_"
  },
  {
    "path": "data/sample_data_processed/train/llava_json.json",
    "chars": 4653,
    "preview": "[\n    {\n        \"id\": 0,\n        \"image\": \"example/white_10_3:21_10_3:21.png\",\n        \"conversations\": [\n            {\n"
  },
  {
    "path": "data/sample_data_processed/train/rgb/example/last_aligned_frame_inv.json",
    "chars": 187,
    "preview": "{\n    \"10_3:21\": [\n        \"1_0:15\",\n        \"2_0:40\",\n        \"3_0:53\",\n        \"4_1:12\",\n        \"5_1:28\",\n        \"6_"
  },
  {
    "path": "data/sample_data_processed/train/text/example/2_0:40.json",
    "chars": 332,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"2_0:40\", \"next_image_name\": \"3_0:53\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/3_0:53.json",
    "chars": 340,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"3_0:53\", \"next_image_name\": \"4_1:12\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/4_1:12.json",
    "chars": 329,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"4_1:12\", \"next_image_name\": \"5_1:28\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/5_1:28.json",
    "chars": 329,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"5_1:28\", \"next_image_name\": \"6_1:58\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/6_1:58.json",
    "chars": 329,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"6_1:58\", \"next_image_name\": \"7_2:19\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/7_2:19.json",
    "chars": 329,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"7_2:19\", \"next_image_name\": \"8_2:36\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/8_2:36.json",
    "chars": 329,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"8_2:36\", \"next_image_name\": \"9_2:43\", \"prompt\": \"There are two images sid"
  },
  {
    "path": "data/sample_data_processed/train/text/example/9_2:43.json",
    "chars": 335,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"9_2:43\", \"next_image_name\": \"10_3:21\", \"prompt\": \"There are two images si"
  },
  {
    "path": "data/sample_data_processed/train/text/example/white_10_3:21.json",
    "chars": 344,
    "preview": "{\"ref_img_name\": \"10_3:21\", \"cur_image_name\": \"white_10_3:21\", \"next_image_name\": \"2_0:40\", \"prompt\": \"There are two ima"
  },
  {
    "path": "data_processing/run_llava/main.py",
    "chars": 5705,
    "preview": "import argparse\nfrom llava.model.builder import load_pretrained_model\nfrom llava.mm_utils import get_model_name_from_pat"
  },
  {
    "path": "data_processing/run_llava/make_list.py",
    "chars": 3429,
    "preview": "import json\nimport os\nimport cv2\nimport numpy as np\nimport tqdm\nfrom pathlib import Path\nimport argparse\n\n# Function to "
  },
  {
    "path": "data_processing/run_llava/utils.py",
    "chars": 5490,
    "preview": "import argparse\nimport torch\nfrom PIL import Image, ImageDraw, ImageFont\nimport os\nfrom tqdm import tqdm\nfrom llava.cons"
  },
  {
    "path": "data_processing/run_lpips/main.py",
    "chars": 5104,
    "preview": "import glob\nimport os\nimport json\nimport numpy as np\nfrom PIL import Image\nimport tqdm\nimport matplotlib.pyplot as plt\ni"
  },
  {
    "path": "dataset/dataset.py",
    "chars": 14013,
    "preview": "import os, io, csv, math, random\r\nimport numpy as np\r\nfrom PIL import Image\r\n\r\nimport torch\r\nimport torchvision.transfor"
  },
  {
    "path": "demo.py",
    "chars": 16418,
    "preview": "import argparse\nimport datetime\nimport os\nimport random\nimport numpy as np\nfrom PIL import Image, ImageDraw, ImageFont\nf"
  },
  {
    "path": "models/ReferenceEncoder.py",
    "chars": 2128,
    "preview": "import torch\r\nimport torch.nn as nn\r\nfrom PIL import Image\r\nfrom transformers import CLIPProcessor, CLIPVisionModel, CLI"
  },
  {
    "path": "models/ReferenceNet.py",
    "chars": 58606,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"L"
  },
  {
    "path": "models/ReferenceNet_attention.py",
    "chars": 15091,
    "preview": "# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py\r\n\r"
  },
  {
    "path": "models/ReferenceNet_attention_fp16.py",
    "chars": 12529,
    "preview": "# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py\n\ni"
  },
  {
    "path": "models/attention.py",
    "chars": 13503,
    "preview": "# *************************************************************************\r\n# This file may have been modified by Byted"
  },
  {
    "path": "models/clip_adapter.py",
    "chars": 2077,
    "preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass NextImageFeaturePredictor(nn.Module):\n    def"
  },
  {
    "path": "models/hack_cur_image_guider.py",
    "chars": 3879,
    "preview": "import os\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom einops import rearrange\nimport numpy as n"
  },
  {
    "path": "models/hack_unet2d.py",
    "chars": 16645,
    "preview": "from dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch."
  },
  {
    "path": "models/image_processor.py",
    "chars": 696,
    "preview": "import torch\nimport os\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as T\n\n\nclass "
  },
  {
    "path": "models/orig_attention.py",
    "chars": 44301,
    "preview": "# *************************************************************************\r\n# This file may have been modified by Byted"
  },
  {
    "path": "models/positional_encoder.py",
    "chars": 2930,
    "preview": "\nimport torch\n# torch.autograd.set_detect_anomaly(True)\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport num"
  },
  {
    "path": "models/resnet.py",
    "chars": 8010,
    "preview": "# *************************************************************************\r\n# This file may have been modified by Byted"
  },
  {
    "path": "models/unet.py",
    "chars": 21526,
    "preview": "from dataclasses import dataclass\r\nfrom typing import List, Optional, Tuple, Union\r\n\r\nimport os\r\nimport json\r\n# import p"
  },
  {
    "path": "models/unet_3d_blocks.py",
    "chars": 28463,
    "preview": "import torch\r\nfrom torch import nn\r\n\r\nfrom .attention import Transformer3DModel\r\nfrom .resnet import Downsample3D, Resne"
  },
  {
    "path": "pipelines/context.py",
    "chars": 2270,
    "preview": "# *************************************************************************\n# This file may have been modified by Byteda"
  },
  {
    "path": "pipelines/pipeline_stage_1.py",
    "chars": 25461,
    "preview": "# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py\n\ni"
  },
  {
    "path": "requirements.txt",
    "chars": 2473,
    "preview": "absl-py==2.1.0\naccelerate==0.21.0\naiofiles==23.2.1\naltair==5.2.0\nannotated-types==0.6.0\nantlr4-python3-runtime==4.9.3\nan"
  },
  {
    "path": "training_scripts/llava/__init__.py",
    "chars": 41,
    "preview": "from .model import LlavaLlamaForCausalLM\n"
  },
  {
    "path": "training_scripts/llava/constants.py",
    "chars": 335,
    "preview": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \".\"\n\n# Model Constants\nIGNORE_INDEX = -1"
  },
  {
    "path": "training_scripts/llava/conversation.py",
    "chars": 15022,
    "preview": "import dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple\nimport base64\nfrom io import BytesIO\nfrom "
  },
  {
    "path": "training_scripts/llava/eval/eval_gpt_review.py",
    "chars": 3620,
    "preview": "import argparse\nimport json\nimport os\n\nimport openai\nimport tqdm\nimport ray\nimport time\n\nNUM_SECONDS_TO_SLEEP = 3\n\n@ray."
  },
  {
    "path": "training_scripts/llava/eval/eval_gpt_review_bench.py",
    "chars": 4172,
    "preview": "import argparse\nimport json\nimport os\n\nimport openai\nimport time\n\nNUM_SECONDS_TO_SLEEP = 0.5\n\n\ndef get_eval(content: str"
  },
  {
    "path": "training_scripts/llava/eval/eval_gpt_review_visual.py",
    "chars": 4177,
    "preview": "import argparse\nimport json\nimport os\n\nimport openai\nimport time\n\nNUM_SECONDS_TO_SLEEP = 0.5\n\n\ndef get_eval(content: str"
  },
  {
    "path": "training_scripts/llava/eval/eval_pope.py",
    "chars": 2732,
    "preview": "import os\nimport json\nimport argparse\n\ndef eval_pope(answers, label_file):\n    label_list = [json.loads(q)['label'] for "
  },
  {
    "path": "training_scripts/llava/eval/eval_science_qa.py",
    "chars": 3920,
    "preview": "import argparse\nimport json\nimport os\nimport re\nimport random\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n "
  },
  {
    "path": "training_scripts/llava/eval/eval_science_qa_gpt4.py",
    "chars": 3675,
    "preview": "import argparse\nimport json\nimport os\nimport re\nimport random\nfrom collections import defaultdict\n\n\ndef get_args():\n    "
  },
  {
    "path": "training_scripts/llava/eval/eval_science_qa_gpt4_requery.py",
    "chars": 5774,
    "preview": "import argparse\nimport json\nimport os\nimport re\nimport random\nfrom collections import defaultdict\n\n\ndef get_args():\n    "
  },
  {
    "path": "training_scripts/llava/eval/eval_textvqa.py",
    "chars": 2226,
    "preview": "import os\nimport argparse\nimport json\nimport re\n\nfrom llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator\n\n\ndef get"
  },
  {
    "path": "training_scripts/llava/eval/generate_webpage_data_from_table.py",
    "chars": 4088,
    "preview": "\"\"\"Generate json file for webpage.\"\"\"\nimport json\nimport os\nimport re\n\n# models = ['llama', 'alpaca', 'gpt35', 'bard']\nm"
  },
  {
    "path": "training_scripts/llava/eval/m4c_evaluator.py",
    "chars": 10265,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport re\n\nfrom tqdm import tqdm\n\n\nclass EvalAIAnswerProcessor:\n    \""
  },
  {
    "path": "training_scripts/llava/eval/model_qa.py",
    "chars": 2430,
    "preview": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria\nimport torch\nimport os\nim"
  },
  {
    "path": "training_scripts/llava/eval/model_vqa.py",
    "chars": 4115,
    "preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import I"
  },
  {
    "path": "training_scripts/llava/eval/model_vqa_loader.py",
    "chars": 5975,
    "preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import I"
  },
  {
    "path": "training_scripts/llava/eval/model_vqa_mmbench.py",
    "chars": 6408,
    "preview": "import argparse\nimport torch\nimport os\nimport json\nimport pandas as pd\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llav"
  },
  {
    "path": "training_scripts/llava/eval/model_vqa_science.py",
    "chars": 4592,
    "preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import I"
  },
  {
    "path": "training_scripts/llava/eval/qa_baseline_gpt35.py",
    "chars": 2345,
    "preview": "\"\"\"Generate answers with GPT-3.5\"\"\"\n# Note: you need to be using OpenAI Python v0.27.0 for the code below to work\nimport"
  },
  {
    "path": "training_scripts/llava/eval/run_llava.py",
    "chars": 4443,
    "preview": "import argparse\nimport torch\n\nfrom llava.constants import (\n    IMAGE_TOKEN_INDEX,\n    DEFAULT_IMAGE_TOKEN,\n    DEFAULT_"
  },
  {
    "path": "training_scripts/llava/eval/summarize_gpt_review.py",
    "chars": 2438,
    "preview": "import json\nimport os\nfrom collections import defaultdict\n\nimport numpy as np\n\nimport argparse\n\ndef parse_args():\n    pa"
  },
  {
    "path": "training_scripts/llava/eval/table/answer/answer_alpaca-13b.jsonl",
    "chars": 57071,
    "preview": "{\"question_id\": 1, \"text\": \"Improving time management skills involves setting priorities, breaking tasks into smaller ch"
  },
  {
    "path": "training_scripts/llava/eval/table/answer/answer_bard.jsonl",
    "chars": 112274,
    "preview": "{\"answer_id\": \"3oW4JY265ZPJGTYi2CgRYF\", \"model_id\": \"bard:20230327\", \"question_id\": 1, \"text\": \"Here are some tips on ho"
  },
  {
    "path": "training_scripts/llava/eval/table/answer/answer_gpt35.jsonl",
    "chars": 107603,
    "preview": "{\"answer_id\": \"BZGowHM7L3RvtWRktKZjLT\", \"model_id\": \"gpt-3.5-turbo:20230327\", \"question_id\": 1, \"text\": \"Here are some t"
  },
  {
    "path": "training_scripts/llava/eval/table/answer/answer_llama-13b.jsonl",
    "chars": 76353,
    "preview": "{\"answer_id\": \"J3UA6eGXGyFeUGqGpP3g34\", \"model_id\": \"llama-13b:v1\", \"question_id\": 1, \"text\": \"The following are some st"
  },
  {
    "path": "training_scripts/llava/eval/table/answer/answer_vicuna-13b.jsonl",
    "chars": 131904,
    "preview": "{\"answer_id\": \"cV4zXygaNP6CXEsgdHMEqz\", \"model_id\": \"vicuna-13b:20230322-clean-lang\", \"question_id\": 1, \"text\": \"Improvi"
  },
  {
    "path": "training_scripts/llava/eval/table/caps_boxes_coco2014_val_80.jsonl",
    "chars": 58574,
    "preview": "{\"id\": \"000000296284\", \"image\": \"000000296284.jpg\", \"captions\": [\"A donut shop is full of different flavors of donuts.\","
  },
  {
    "path": "training_scripts/llava/eval/table/model.jsonl",
    "chars": 681,
    "preview": "{\"model_id\": \"vicuna-13b:20230322-clean-lang\", \"model_name\": \"vicuna-13b\", \"model_version\": \"20230322-clean-lang\", \"mode"
  },
  {
    "path": "training_scripts/llava/eval/table/prompt.jsonl",
    "chars": 5129,
    "preview": "{\"prompt_id\": 1, \"system_prompt\": \"You are a helpful and precise assistant for checking the quality of the answer.\", \"pr"
  },
  {
    "path": "training_scripts/llava/eval/table/question.jsonl",
    "chars": 12885,
    "preview": "{\"question_id\": 1, \"text\": \"How can I improve my time management skills?\", \"category\": \"generic\"}\n{\"question_id\": 2, \"te"
  },
  {
    "path": "training_scripts/llava/eval/table/results/test_sqa_llava_13b_v0.json",
    "chars": 3950324,
    "preview": "{\n  \"acc\": 90.8983730252299,\n  \"correct\": 3855,\n  \"count\": 4241,\n  \"results\": {\n    \"4\": 1,\n    \"5\": 1,\n    \"11\": 1,\n   "
  },
  {
    "path": "training_scripts/llava/eval/table/results/test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json",
    "chars": 3830902,
    "preview": "{\n  \"acc\": 91.08700778118369,\n  \"correct\": 3863,\n  \"count\": 4241,\n  \"results\": {\n    \"4\": 1,\n    \"5\": 1,\n    \"11\": 1,\n  "
  },
  {
    "path": "training_scripts/llava/eval/table/review/review_alpaca-13b_vicuna-13b.jsonl",
    "chars": 73131,
    "preview": "{\"review_id\": \"QM5m5nnioWr8M2LFHsaQvu\", \"question_id\": 1, \"answer1_id\": \"kEL9ifUHDeYuAXzevje2se\", \"answer2_id\": \"cV4zXyg"
  },
  {
    "path": "training_scripts/llava/eval/table/review/review_bard_vicuna-13b.jsonl",
    "chars": 73145,
    "preview": "{\"review_id\": \"4CeMvEQyE6fKMJwvSLY3P4\", \"question_id\": 1, \"answer1_id\": \"3oW4JY265ZPJGTYi2CgRYF\", \"answer2_id\": \"cV4zXyg"
  },
  {
    "path": "training_scripts/llava/eval/table/review/review_gpt35_vicuna-13b.jsonl",
    "chars": 73399,
    "preview": "{\"review_id\": \"jyhS7AFj2mrFNqoRXQJDPS\", \"question_id\": 1, \"answer1_id\": \"BZGowHM7L3RvtWRktKZjLT\", \"answer2_id\": \"cV4zXyg"
  },
  {
    "path": "training_scripts/llava/eval/table/review/review_llama-13b_vicuna-13b.jsonl",
    "chars": 67249,
    "preview": "{\"review_id\": \"WFp5i5yjjFethrgugKTDmX\", \"question_id\": 1, \"answer1_id\": \"J3UA6eGXGyFeUGqGpP3g34\", \"answer2_id\": \"cV4zXyg"
  },
  {
    "path": "training_scripts/llava/eval/table/reviewer.jsonl",
    "chars": 604,
    "preview": "{\"reviewer_id\": \"gpt-4-0328-default\", \"prompt_id\": 1, \"metadata\": {\"temperature\": 0.2, \"max_tokens\": 1024}, \"description"
  },
  {
    "path": "training_scripts/llava/eval/table/rule.json",
    "chars": 9098,
    "preview": "{\n    \"coding\": {\"role\": \"Assistant\", \"prompt\": \"Your task is to evaluate the coding abilities of the above two assistan"
  },
  {
    "path": "training_scripts/llava/eval/webpage/index.html",
    "chars": 7664,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width"
  },
  {
    "path": "training_scripts/llava/eval/webpage/script.js",
    "chars": 9967,
    "preview": "// Description: Script for the evaluation webpage.\n\nlet currentQuestionIndex = 1;\n\n// Store the model name mapping for l"
  },
  {
    "path": "training_scripts/llava/eval/webpage/styles.css",
    "chars": 1822,
    "preview": "body {\n    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n    background-color: #f8f9fa;\n}\n\n.navbar-dark "
  },
  {
    "path": "training_scripts/llava/mm_utils.py",
    "chars": 9593,
    "preview": "from PIL import Image\nfrom io import BytesIO\nimport base64\nimport torch\nimport math\nimport ast\n\nfrom transformers import"
  },
  {
    "path": "training_scripts/llava/model/__init__.py",
    "chars": 269,
    "preview": "try:\n    from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig\n    from .language_model.llava_mpt i"
  },
  {
    "path": "training_scripts/llava/model/apply_delta.py",
    "chars": 1956,
    "preview": "\"\"\"\nUsage:\npython3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --de"
  },
  {
    "path": "training_scripts/llava/model/builder.py",
    "chars": 9039,
    "preview": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not "
  },
  {
    "path": "training_scripts/llava/model/consolidate.py",
    "chars": 914,
    "preview": "\"\"\"\nUsage:\npython3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate\n"
  },
  {
    "path": "training_scripts/llava/model/language_model/llava_llama.py",
    "chars": 5424,
    "preview": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not "
  },
  {
    "path": "training_scripts/llava/model/language_model/llava_mistral.py",
    "chars": 5386,
    "preview": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not "
  },
  {
    "path": "training_scripts/llava/model/language_model/llava_mpt.py",
    "chars": 3487,
    "preview": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not "
  },
  {
    "path": "training_scripts/llava/model/llava_arch.py",
    "chars": 18110,
    "preview": "#    Copyright 2023 Haotian Liu\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not "
  },
  {
    "path": "training_scripts/llava/model/make_delta.py",
    "chars": 2257,
    "preview": "\"\"\"\nUsage:\npython3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~"
  },
  {
    "path": "training_scripts/llava/model/multimodal_encoder/builder.py",
    "chars": 556,
    "preview": "import os\nfrom .clip_encoder import CLIPVisionTower\n\n\ndef build_vision_tower(vision_tower_cfg, **kwargs):\n    vision_tow"
  },
  {
    "path": "training_scripts/llava/model/multimodal_encoder/clip_encoder.py",
    "chars": 3062,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig\n\n\ncla"
  },
  {
    "path": "training_scripts/llava/model/multimodal_projector/builder.py",
    "chars": 1437,
    "preview": "import torch\nimport torch.nn as nn\nimport re\n\n\nclass IdentityMap(nn.Module):\n    def __init__(self):\n        super().__i"
  },
  {
    "path": "training_scripts/llava/model/utils.py",
    "chars": 927,
    "preview": "from transformers import AutoConfig\n\n\ndef auto_upgrade(config):\n    cfg = AutoConfig.from_pretrained(config)\n    if 'lla"
  },
  {
    "path": "training_scripts/llava/serve/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "training_scripts/llava/serve/cli.py",
    "chars": 4808,
    "preview": "import argparse\nimport torch\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN"
  },
  {
    "path": "training_scripts/llava/serve/controller.py",
    "chars": 9949,
    "preview": "\"\"\"\nA controller manages distributed workers.\nIt sends worker addresses to clients.\n\"\"\"\nimport argparse\nimport asyncio\ni"
  },
  {
    "path": "training_scripts/llava/serve/gradio_web_server.py",
    "chars": 18841,
    "preview": "import argparse\nimport datetime\nimport json\nimport os\nimport time\n\nimport gradio as gr\nimport requests\n\nfrom llava.conve"
  },
  {
    "path": "training_scripts/llava/serve/model_worker.py",
    "chars": 11176,
    "preview": "\"\"\"\nA model worker executes the model.\n\"\"\"\nimport argparse\nimport asyncio\nimport json\nimport time\nimport threading\nimpor"
  },
  {
    "path": "training_scripts/llava/serve/register_worker.py",
    "chars": 734,
    "preview": "\"\"\"\nManually register workers.\n\nUsage:\npython3 -m fastchat.serve.register_worker --controller http://localhost:21001 --w"
  },
  {
    "path": "training_scripts/llava/serve/sglang_worker.py",
    "chars": 8678,
    "preview": "\"\"\"\nA model worker executes the model.\n\"\"\"\nimport argparse\nimport asyncio\nfrom concurrent.futures import ThreadPoolExecu"
  },
  {
    "path": "training_scripts/llava/serve/test_message.py",
    "chars": 2022,
    "preview": "import argparse\nimport json\n\nimport requests\n\nfrom llava.conversation import default_conversation\n\n\ndef main():\n    if a"
  },
  {
    "path": "training_scripts/llava/train/llama_flash_attn_monkey_patch.py",
    "chars": 4404,
    "preview": "from typing import Optional, Tuple\nimport warnings\n\nimport torch\n\nimport transformers\nfrom transformers.models.llama.mod"
  },
  {
    "path": "training_scripts/llava/train/llama_xformers_attn_monkey_patch.py",
    "chars": 4916,
    "preview": "\"\"\"\nDirectly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_a"
  },
  {
    "path": "training_scripts/llava/train/llava_trainer.py",
    "chars": 11076,
    "preview": "import os\nimport torch\nimport torch.nn as nn\n\nfrom torch.utils.data import Sampler\n\nfrom transformers import Trainer\nfro"
  },
  {
    "path": "training_scripts/llava/train/train.py",
    "chars": 39810,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "training_scripts/llava/train/train_mem.py",
    "chars": 156,
    "preview": "import wandb\nwandb.init(mode='disabled')\nfrom llava.train.train import train\n\nif __name__ == \"__main__\":\n    train(attn_"
  },
  {
    "path": "training_scripts/llava/train/train_xformers.py",
    "chars": 366,
    "preview": "# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.\n\n# Need to call this before "
  },
  {
    "path": "training_scripts/llava/utils.py",
    "chars": 4003,
    "preview": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\n\nimport requests\n\nfrom llava.constants impor"
  },
  {
    "path": "training_scripts/merge_ckpt.sh",
    "chars": 205,
    "preview": " python  scripts/merge_lora_weights.py     --model-path  ./checkpoints/llava-v1.5-7b-task-lora   --model-base  liuhaotia"
  },
  {
    "path": "training_scripts/scripts/merge_lora_weights.py",
    "chars": 767,
    "preview": "import argparse\nfrom llava.model.builder import load_pretrained_model\nfrom llava.mm_utils import get_model_name_from_pat"
  },
  {
    "path": "training_scripts/scripts/zero2.json",
    "chars": 557,
    "preview": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_"
  },
  {
    "path": "training_scripts/train_mask_generator.py",
    "chars": 35340,
    "preview": "import os\nimport math\nimport sys \nsys.path.append('..')\nimport wandb\nimport random\nimport logging\nimport inspect\nimport "
  },
  {
    "path": "training_scripts/train_renderer.py",
    "chars": 32463,
    "preview": "import os\nimport sys \nsys.path.append('../')\nimport math\nimport wandb\nimport random\nimport logging\nimport inspect\nimport"
  },
  {
    "path": "training_scripts/train_text_generator.sh",
    "chars": 1256,
    "preview": "#!/bin/bash\n\ndeepspeed --master_port=12618 llava/train/train_mem.py \\\n    --lora_enable True --lora_r 128 --lora_alpha 2"
  },
  {
    "path": "unet_2d/attention.py",
    "chars": 17374,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "unet_2d/resnet.py",
    "chars": 36096,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The"
  },
  {
    "path": "unet_2d/unet_2d_blocks.py",
    "chars": 125047,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "unet_2d/unet_2d_condition.py",
    "chars": 49269,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/dist_tools.py",
    "chars": 3081,
    "preview": "# Copyright 2023 ByteDance and/or its affiliates.\r\n#\r\n# Copyright (2023) MagicAnimate Authors\r\n#\r\n# ByteDance, its affil"
  },
  {
    "path": "utils/inference_helpers.py",
    "chars": 19014,
    "preview": "\nfrom diffusers import AutoencoderKL, EulerAncestralDiscreteScheduler, DDIMScheduler\nfrom pathlib import Path\nimport tor"
  },
  {
    "path": "utils/llava_utils.py",
    "chars": 5565,
    "preview": "import argparse\nimport torch\n\nfrom llava.constants import (\n    IMAGE_TOKEN_INDEX,\n    DEFAULT_IMAGE_TOKEN,\n    DEFAULT_"
  },
  {
    "path": "utils/text_wrapper.py",
    "chars": 1872,
    "preview": "\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom transformers import AutoTokenizer, PretrainedConfig\n\n\ndef im"
  },
  {
    "path": "utils/util.py",
    "chars": 5287,
    "preview": "# *************************************************************************\r\n# This file may have been modified by Byted"
  }
]

About this extraction

This page contains the full source code of the ArmastusChen/inverse_painting GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 125 files (9.2 MB), approximately 2.4M tokens, and a symbol index with 695 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
