Repository: muzishen/IMAGHarmony
Branch: main
Commit: 46895e751ff6
Files: 20
Total size: 175.2 KB

Directory structure:
gitextract_ymnget5p/

├── LICENSE
├── README.md
├── baseline.py
├── convert_bin.py
├── demo.py
├── ip_adapter/
│   ├── __init__.py
│   ├── attention_processor.py
│   ├── custom_pipelines.py
│   ├── ip_adapter.py
│   ├── ip_adapter_origin.py
│   ├── resampler.py
│   ├── shared_models.py
│   ├── test_resampler.py
│   └── utils.py
├── requirements.txt
├── run.sh
├── sdxl-fine-tuning/
│   └── data/
│       └── train.json
├── shared_models.py
├── test.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout



<a href='https://revive234.github.io/IMAGHarmony.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/pdf/2506.01949'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/kkkkggg/IMAGHarmony'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Dataset-HarmonyBench-orange'></a>



## 🗓️ Release
- [2025/5/30] 🔥 We released the [technical report](https://arxiv.org/pdf/2506.01949) of IMAGHarmony.
- [2025/5/28] 🔥 We released the training and inference code of IMAGHarmony.
- [2025/5/17] 🎉 We launch the [project page](https://revive234.github.io/IMAGHarmony.github.io/) of IMAGHarmony.








## 💡 Introduction
IMAGHarmony tackles the challenge of controllable image editing in multi-object scenes, where existing models struggle with aligning object quantity and spatial layout.
To this end, IMAGHarmony introduces a structure-aware framework for quantity-and-layout consistent image editing (QL-Edit), enabling precise control over object count, category, and arrangement.
We propose a harmony-aware (HA) module to jointly model object structure and semantics, and a preference-guided noise selection (PNS) strategy that stabilizes generation by selecting semantically aligned initial noise.
Our method is trained and evaluated on HarmonyBench, a newly curated benchmark with diverse editing scenarios.
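
The README describes PNS only at this level of detail. To illustrate the general idea, a best-of-N search over initial-noise seeds might look like the sketch below. It assumes each candidate seed can be sampled and then scored by some preference signal. `select_initial_noise`, `generate_fn`, and `score_fn` are hypothetical names, and the scoring that IMAGHarmony actually uses is not specified here.

```python
from typing import Any, Callable, Iterable, Optional, Tuple


def select_initial_noise(
    seeds: Iterable[int],
    generate_fn: Callable[[int], Any],
    score_fn: Callable[[Any], float],
) -> Tuple[int, float]:
    """Best-of-N selection over initial-noise seeds (hypothetical sketch).

    generate_fn runs one sampling pass starting from the noise of a given seed;
    score_fn rates the result (higher = better aligned with the target).
    Neither is an IMAGHarmony function. Ties keep the earliest seed.
    """
    best_seed: Optional[int] = None
    best_score = float("-inf")
    for seed in seeds:
        sample = generate_fn(seed)   # one candidate generation for this seed
        score = score_fn(sample)     # preference score for that candidate
        if best_seed is None or score > best_score:
            best_seed, best_score = seed, score
    if best_seed is None:
        raise ValueError("seeds must contain at least one value")
    return best_seed, best_score


# Toy usage: the "sample" is the seed itself and the score peaks at 7.
print(select_initial_noise(range(10), lambda s: s, lambda x: -abs(x - 7)))  # (7, 0)
```

Because every candidate costs one sampling run in this sketch, the number of seeds trades compute directly for selection quality.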

![architecture](./assets/1.png)

## 🚀 HarmonyBench Dataset Demo


![dataset_demo](./assets/harmonybench.jpg)
## 🚀 Examples

![results_1](./assets/sotacomp.jpg)

![results_2](./assets/multi.jpg)


### Dual-Category Editing
![results_5](./assets/3edit.jpg)





## 🔧 Requirements

- Python>=3.8
- [PyTorch>=2.0.0](https://pytorch.org/)
- CUDA>=11.8
```bash
conda create --name IMAGHarmony python=3.8.18
conda activate IMAGHarmony

# Install requirements
pip install -r requirements.txt
```
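
To confirm that an environment meets the minimum versions listed above, you can use a small check along these lines. This is a sketch, not a script shipped with the repository. It reads the installed PyTorch version through the standard-library `importlib.metadata` and, as a simplifying assumption, ignores local build tags such as `+cu118` and pre-release suffixes.

```python
import re
import sys
from importlib import metadata


def parse_version(text: str) -> tuple:
    """Parse the leading numeric part of a version string.

    '2.1.0+cu118' -> (2, 1, 0); '2.0.0rc1' -> (2, 0, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", text.strip())
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: tuple) -> bool:
    """Compare versions after zero-padding, so '2.0' counts as (2, 0, 0)."""
    parsed = parse_version(installed)
    width = max(len(parsed), len(minimum))
    parsed += (0,) * (width - len(parsed))
    padded_min = tuple(minimum) + (0,) * (width - len(minimum))
    return parsed >= padded_min


if __name__ == "__main__":
    print("Python >= 3.8:", sys.version_info >= (3, 8))
    try:
        torch_version = metadata.version("torch")
        print(f"PyTorch {torch_version} >= 2.0.0:", meets_minimum(torch_version, (2, 0, 0)))
    except metadata.PackageNotFoundError:
        print("PyTorch is not installed")
```

Once PyTorch is installed, `torch.version.cuda` reports the CUDA version it was built against.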
## 🌐 Download Models

You can download our models from [Hugging Face](https://huggingface.co/kkkkggg/IMAGHarmony). The other component models can be downloaded from their original repositories:
- [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
- [stable-diffusion-XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)

## 🚀 How to train
```bash
# Download the HarmonyBench data first, or prepare your own images,
# and update the paths in run.sh.
# Write the caption of each image in your train.json file.
# Start training:

sh run.sh
```
## 🚀 How to test
```bash
# Convert your checkpoints first
python convert_bin.py

# Fill in your paths in test.py, then run:

python test.py
```
Alternatively, you can test it with the Gradio demo:
```bash
python demo.py
```
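
The conversion step regroups the flat training checkpoint by key prefix into the three sections that inference reads (`image_proj`, `ip_adapter`, and `composed_adapter`), as `convert_bin.py` does. You can sketch the regrouping with plain dictionaries. `split_by_prefix` is a hypothetical helper, and the toy keys and integer values stand in for real parameter names and tensors.

```python
from typing import Dict, Mapping

# Prefix in the training checkpoint -> section name in the converted file
# (the same mapping convert_bin.py applies).
PREFIX_TO_SECTION = {
    "image_proj_model.": "image_proj",
    "adapter_modules.": "ip_adapter",
    "composed_modules.": "composed_adapter",
}


def split_by_prefix(state_dict: Mapping[str, object]) -> Dict[str, Dict[str, object]]:
    """Group keys by prefix and strip the prefix; keys with no known prefix are dropped."""
    sections: Dict[str, Dict[str, object]] = {name: {} for name in PREFIX_TO_SECTION.values()}
    for key, value in state_dict.items():
        for prefix, section in PREFIX_TO_SECTION.items():
            if key.startswith(prefix):
                # Stripping the prefix makes the keys match the submodule's own state_dict
                sections[section][key[len(prefix):]] = value
                break
    return sections


toy_checkpoint = {  # toy keys, not real parameter names
    "image_proj_model.proj.weight": 1,
    "adapter_modules.0.to_k_ip.weight": 2,
    "composed_modules.norm.bias": 3,
    "some_other_module.weight": 4,  # not one of the three prefixes, so it is dropped
}
print(split_by_prefix(toy_checkpoint))
# {'image_proj': {'proj.weight': 1}, 'ip_adapter': {'0.to_k_ip.weight': 2}, 'composed_adapter': {'norm.bias': 3}}
```

This grouping is why `demo.py` can read `state_dict["composed_adapter"]` directly from the converted weights file.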


## Acknowledgement
We would like to thank the contributors to the [InstantStyle](https://github.com/instantX-research/InstantStyle) and [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) repositories for their open research and exploration.

The IMAGHarmony code is available for both academic and commercial use. Users are permitted to generate images using this tool, provided they comply with local laws and exercise responsible use. The developers disclaim all liability for any misuse or unlawful activity by users.
## Citation
If you find IMAGHarmony useful for your research and applications, please cite using this BibTeX:

```bibtex
@misc{shen2025imagharmonycontrollableimageediting,
      title={IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout}, 
      author={Fei Shen and Yutong Gao and Jian Yu and Xiaoyu Du and Jinhui Tang},
      year={2025},
      eprint={2506.01949},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.01949}, 
}
```

## 🕒 TODO List
- [x] Paper
- [x] Train Code
- [x] Inference Code
- [ ] HarmonyBench Dataset
- [ ] Model Weights

## 👉 **Our other projects:**  
- [IMAGEdit](https://github.com/XWH-A/IMAGEdit): Training-Free Controllable Video Editing with Consistent Object Layout. [Controllable multi-object video editing]
- [IMAGDressing](https://github.com/muzishen/IMAGDressing): Controllable dressing generation. [Controllable dressing generation]
- [IMAGGarment](https://github.com/muzishen/IMAGGarment): Fine-grained controllable garment generation. [Controllable garment generation]
- [IMAGHarmony](https://github.com/muzishen/IMAGHarmony): Controllable image editing with consistent object layout. [Controllable multi-object image editing]
- [IMAGPose](https://github.com/muzishen/IMAGPose): Pose-guided person generation with high fidelity. [Controllable multi-mode person generation]
- [RCDMs](https://github.com/muzishen/RCDMs): Rich-contextual conditional diffusion for story visualization. [Controllable story generation]
- [PCDMs](https://github.com/tencent-ailab/PCDMs): Progressive conditional diffusion for pose-guided image synthesis. [Controllable person generation]
- [V-Express](https://github.com/tencent-ailab/V-Express/): Explores strong and weak conditional relationships for portrait video generation. [Controllable digital human generation]
- [FaceShot](https://github.com/open-mmlab/FaceShot/): Talkingface plugin for any character. [Controllable anime digital human generation]
- [CharacterShot](https://github.com/Jeoyal/CharacterShot): Controllable and consistent 4D character animation framework. [Controllable 4D character generation]
- [StyleTailor](https://github.com/mahb-THU/StyleTailor): An Agent for personalized fashion styling. [Personalized fashion agent]
- [SignVip](https://github.com/umnooob/signvip/): Controllable sign language video generation. [Controllable sign language generation]

## 📨 Contact
If you have any questions, please feel free to contact us at shenfei140721@126.com and yutonggaokkk@njust.edu.cn.


================================================
FILE: baseline.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class QFormer(nn.Module):
    def __init__(self, 
                 hidden_dim=768,
                 num_queries=16,
                 num_layers=6,
                 num_heads=12,
                 image_feat_dim=320,
                 text_feat_dim=2048,
                 add_modality_embedding=True):
        super(QFormer, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.add_modality_embedding = add_modality_embedding

        # Learnable query tokens: [1, num_queries, D]
        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim))

        # Modality type embeddings (0=image, 1=text)
        if add_modality_embedding:
            self.modality_embed = nn.Embedding(2, hidden_dim)
        self.image_proj = nn.Linear(image_feat_dim, hidden_dim)
        self.text_proj = nn.Linear(text_feat_dim, hidden_dim)
        # Transformer encoder: Q interacts with K/V (image + text)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, image_feat, text_feat):
        """
        image_feat: [B, T_img, image_feat_dim]
        text_feat:  [B, T_txt, text_feat_dim]
        Returns:    [B, num_queries, hidden_dim]
        """

        B = image_feat.size(0)
        image_feat = self.image_proj(image_feat)
        text_feat = self.text_proj(text_feat)
        # Concatenate image + text features as K/V
        kv = torch.cat([image_feat, text_feat], dim=1)  # [B, T_img + T_txt, D]

        # Add modality type embedding
        if self.add_modality_embedding:
            T_img = image_feat.size(1)
            T_txt = text_feat.size(1)
            modality_ids = torch.cat([
                torch.zeros(T_img, dtype=torch.long),
                torch.ones(T_txt, dtype=torch.long)
            ], dim=0).to(image_feat.device)  # [T_img + T_txt]
            modality_embed = self.modality_embed(modality_ids)  # [T_img + T_txt, D]
            kv = kv + modality_embed.unsqueeze(0)  # broadcast to [B, T, D]

        # Expand learnable query tokens to batch size
        queries = self.query_tokens.expand(B, -1, -1)  # [B, N_query, D]

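        # Q-Former-style fusion with an encoder-only transformer: there is no
        # separate cross-attention, so the queries are prepended to the K/V
        # tokens and all tokens attend to each other via self-attention.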
        input_seq = torch.cat([queries, kv], dim=1)  # [B, N_query + T, D]
        output = self.transformer(input_seq)  # [B, N_query + T, D]

        # Return only the updated query tokens
        return output[:, :self.num_queries, :]  # [B, N_query, D]
    

class MLP(nn.Module):
    def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header =16):
        super().__init__()
        self.image_proj = nn.Linear(image_dim, fused_dim)
        self.text_proj = nn.Linear(text_dim, fused_dim)
        self.num_header = num_header
        self.mlp = nn.Sequential(
            nn.Linear(2 * fused_dim, fused_dim),
            nn.ReLU(),
            nn.Linear(fused_dim, fused_dim),
            nn.ReLU(),
            nn.Linear(fused_dim, fused_dim * num_header)  # one fused_dim-sized token per header
        )
        self.fused_dim=fused_dim
    def forward(self, image_feat, text_feat):
        """
        image_feat: [B, T_img, image_dim]
        text_feat: [B, T_txt, text_dim]
        """
 
        image_repr = image_feat.mean(dim=1)  # [B, image_dim]
        text_repr = text_feat.mean(dim=1)    # [B, text_dim]


        image_proj = self.image_proj(image_repr)  # [B, fused_dim]
        text_proj = self.text_proj(text_repr)     # [B, fused_dim]


        fused = torch.cat([image_proj, text_proj], dim=-1)  # [B, 2*fused_dim]
        
        output = self.mlp(fused).reshape(-1, self.num_header, self.fused_dim)  # [B, num_header, fused_dim]
        return output
    



class GatedAttentionFusion(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=512):
        super().__init__()

        self.gate_mlp = nn.Sequential(
            nn.Linear(2 * input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  #  alpha ∈ [0, 1]
        )

    def forward(self, img_feat, txt_feat):
        """
        img_feat: [B, D]
        txt_feat: [B, D]
        """
        fused_input = torch.cat([img_feat, txt_feat], dim=-1)  # [B, 2D]
        alpha = self.gate_mlp(fused_input)  # [B, 1]


        fused = alpha * img_feat + (1 - alpha) * txt_feat  # [B, D]
        return fused

class AttentionFusionWrapper(nn.Module):
    def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header=16):
        super().__init__()
        self.img_proj = nn.Linear(image_dim, fused_dim)
        self.txt_proj = nn.Linear(text_dim, fused_dim)
        self.fusion = GatedAttentionFusion(input_dim=fused_dim)
        self.num_header = num_header
        self.fused_dim = fused_dim
        self.dim_transfer = nn.Linear(fused_dim,fused_dim*self.num_header)
    def forward(self, image_feat, text_feat):
        """
        image_feat: [B, T_img, 320]
        text_feat:  [B, T_txt, 2048]
        """
        # Mean pooling
        img_global = image_feat.mean(dim=1)  # [B, 320]
        txt_global = text_feat.mean(dim=1)  # [B, 2048]

        # Linear projection
        img_proj = self.img_proj(img_global)  # [B, 768]
        txt_proj = self.txt_proj(txt_global)  # [B, 768]

        # Gated fusion
        fused = self.fusion(img_proj, txt_proj)  # [B, 768]
        fused = self.dim_transfer(fused).reshape(-1,self.num_header,self.fused_dim)
        return fused


================================================
FILE: convert_bin.py
================================================
import os
import torch
from collections import OrderedDict

def convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path):
    
    if not os.path.exists(pytorch_model_path):
        print(f"  [Warning] Source file not found, skipping: {pytorch_model_path}")
        return False

    print(f"  Converting: {pytorch_model_path}")
    try:

        sd = torch.load(pytorch_model_path, map_location="cpu")

        image_proj_sd = OrderedDict()
        ip_sd = OrderedDict()
        composed_sd = OrderedDict()
        

        for k in sd:
            if k.startswith("image_proj_model."):
                image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
            elif k.startswith("adapter_modules."):

                ip_sd[k.replace("adapter_modules.", "")] = sd[k] 
            elif k.startswith("composed_modules."):
                composed_sd[k.replace("composed_modules.", "")] = sd[k]


        if not image_proj_sd and not ip_sd and not composed_sd:
            print(f"  [Warning] No expected keys (image_proj_model, adapter_modules, composed_modules) found in {pytorch_model_path}. Skipping save.")
            return False


        final_sd = {
            "image_proj": image_proj_sd, 
            "ip_adapter": ip_sd, 
            'composed_adapter': composed_sd
        }
        

        torch.save(final_sd, output_ip_adapter_path)
        print(f"  Successfully saved: {output_ip_adapter_path}")
        return True

    except Exception as e:
        print(f"  [Error] Failed to convert {pytorch_model_path}: {e}")
        return False








if __name__ == "__main__":
    base_log_dir = "your fine_tuned model path"  # TODO: directory holding one folder per training run, each with checkpoint-* subfolders
    total_converted = 0
    total_skipped = 0
    total_errors = 0

    print(f"Starting conversion process in base directory: {base_log_dir}")


    for training_run_dir_name in os.listdir(base_log_dir):
        training_run_dir_path = os.path.join(base_log_dir, training_run_dir_name)
        
        # Check if it's actually a directory
        if os.path.isdir(training_run_dir_path):
            print(f"\nProcessing training run: {training_run_dir_name}")
            
            # Iterate through items inside the training run directory
            for checkpoint_dir_name in os.listdir(training_run_dir_path):
     
                if checkpoint_dir_name.startswith("checkpoint-") and \
                   os.path.isdir(os.path.join(training_run_dir_path, checkpoint_dir_name)):
                    
                    checkpoint_dir_path = os.path.join(training_run_dir_path, checkpoint_dir_name)
                    print(f"- Found checkpoint directory: {checkpoint_dir_name}")
                    

                    pytorch_model_path = os.path.join(checkpoint_dir_path, "pytorch_model.bin")
                    output_ip_adapter_path = os.path.join(checkpoint_dir_path, "ip_adapter.bin")
                    

                    if os.path.exists(output_ip_adapter_path):
                        print(f"  Output file already exists, skipping: {output_ip_adapter_path}")
                        total_skipped += 1
                        continue

                    success = convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path)
                    if success:
                        total_converted += 1
                    elif not os.path.exists(pytorch_model_path):
                        # A missing source file counts as skipped rather than as an error
                        total_skipped += 1
                    else:
                        total_errors += 1

    print("\n--- Conversion Summary ---")
    print(f"Total checkpoints converted: {total_converted}")
    print(f"Total checkpoints skipped (e.g., source missing): {total_skipped}")
    print(f"Total errors during conversion: {total_errors}")


================================================
FILE: demo.py
================================================
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from ip_adapter import IPAdapterXL
from huggingface_hub import hf_hub_download
import os
import time

try:
    from tutorial_train_sdxl_ori import ComposedAttention
except ImportError:
    print("Error: Could not import ComposedAttention.")
    print("Please ensure 'tutorial_train_sdxl_ori.py' is in the same directory as this script.")
    raise SystemExit(1)

print("Loading models, please wait...")

CKPT_INTER_DIM = 2560
CKPT_CROSS_HEADS = 8
CKPT_RESHAPE_BLOCKS = 8
CKPT_CROSS_VALUE_DIM = 64

BASE_MODEL_PATH = "/aigc_data_hdd/checkpoints/stable-diffusion-xl-base-1.0"  # change to your local SDXL path
IMAGE_ENCODER_PATH = os.path.join(BASE_MODEL_PATH, "image_encoder")  # image encoder is expected inside the SDXL folder

if not os.path.exists(BASE_MODEL_PATH) or not os.path.exists(IMAGE_ENCODER_PATH):
    print(f"Error: Model or image encoder path not found: {BASE_MODEL_PATH}, {IMAGE_ENCODER_PATH}")
    raise SystemExit(1)

IP_ADAPTER_REPO_ID = "kkkkggg/IMAGHarmony"
IP_ADAPTER_FILENAME = "IMAGHarmony_variant1.bin"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Downloading weights file {IP_ADAPTER_FILENAME} from repository {IP_ADAPTER_REPO_ID}...")
try:
    fine_tuned_ckpt_path = hf_hub_download(
        repo_id=IP_ADAPTER_REPO_ID,
        filename=IP_ADAPTER_FILENAME,
    )
    print(f"Weights file downloaded to: {fine_tuned_ckpt_path}")
except Exception as e:
    print(f"Failed to download weights: {e}")
    raise SystemExit(1)

print(f"Loading Stable Diffusion XL pipeline from local path: {BASE_MODEL_PATH}")
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE_MODEL_PATH,
    torch_dtype=torch.float16,
    add_watermarker=False,
).to(DEVICE)
pipe.enable_vae_tiling()

print("Instantiating custom ComposedAttention module...")
number_class_crossattention = ComposedAttention(
    image_hidden_size=1280,
    text_context_dim=2048,
    inter_dim=CKPT_INTER_DIM,
    cross_heads=CKPT_CROSS_HEADS,
    reshape_blocks=CKPT_RESHAPE_BLOCKS,
    cross_value_dim=CKPT_CROSS_VALUE_DIM,
    scale=1.0
).to(DEVICE).half()

print("Extracting and loading ComposedAttention weights from the main checkpoint...")
try:
    state_dict = torch.load(fine_tuned_ckpt_path, map_location="cpu")
    composed_attention_weights = state_dict["composed_adapter"]
    number_class_crossattention.load_state_dict(composed_attention_weights)
    print("Successfully loaded fine-tuned ComposedAttention weights.")
except KeyError:
    print(f"Error: Key 'composed_adapter' not found in weights file {IP_ADAPTER_FILENAME}.")
    raise SystemExit(1)
except Exception as e:
    print(f"An unknown error occurred while loading ComposedAttention weights: {e}")
    raise SystemExit(1)

print("Initializing IP-Adapter...")
ip_model = IPAdapterXL(
    pipe,
    IMAGE_ENCODER_PATH,
    fine_tuned_ckpt_path,
    DEVICE,
    target_blocks=["down_blocks.1.attentions.1"],
    num_tokens=4,
    inference=True,
    number_class_crossattention=number_class_crossattention
)

print("Models loaded. Gradio application is ready!")


def generate_image(uploaded_image: Image.Image, local_path: str, save_path: str,
                   prompt: str, extra_text: str, negative_prompt: str,
                   guidance_scale: float, num_inference_steps: int, seed: int, progress=gr.Progress()):
    
    pil_image = None
    if uploaded_image is not None:
        pil_image = uploaded_image
    elif local_path and local_path.strip():
        try:
            pil_image = Image.open(local_path.strip())
        except FileNotFoundError:
            raise gr.Error(f"File not found. Please check the path: {local_path.strip()}")
        except Exception as e:
            raise gr.Error(f"Cannot open image file. Error: {e}")
    else:
        raise gr.Error("Please upload a reference image or provide a valid local file path!")

    input_image = pil_image.resize((512, 512))
    progress(0, desc="Generating image...")

    images = ip_model.generate(
        pil_image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        scale=1.0,
        guidance_scale=guidance_scale,
        num_samples=1,
        num_inference_steps=int(num_inference_steps),
        seed=int(seed),
        extra_text=extra_text,
        number_class_crossattention=number_class_crossattention
    )
    generated_image = images[0]
    progress(1, desc="Generation complete!")

    if save_path and save_path.strip():
        try:
            save_dir = save_path.strip()
            os.makedirs(save_dir, exist_ok=True)
            
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            filename = f"output_{timestamp}_seed{int(seed)}.png"
            full_path = os.path.join(save_dir, filename)
            
            generated_image.save(full_path)
            gr.Info(f"Image successfully saved to: {full_path}")
        except Exception as e:
            gr.Warning(f"Could not save the image! Error: {e}")
            print(f"Error saving image: {e}")

    return generated_image

with gr.Blocks() as demo:
    gr.Markdown("# IMAGHarmony: Image Generation Demo")
    gr.Markdown(
        "**Upload a reference image from your computer, or enter the full local path in the text box below.**\n"
        "Then, enter a **Target Prompt** and a **Reference Content** description to generate a new image."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Your Reference Image")
            local_path_input = gr.Textbox(
                label="Or Enter Local Image Path",
                placeholder="/home/user/images/photo.jpg",
                info="If an image is uploaded, it will be prioritized over the path."
            )
            prompt = gr.Textbox(label="Target Prompt", value="four cats")
            extra_text = gr.Textbox(
                label="Reference Content",
                info="Enter text that describes the reference image, typically the caption used during training.",
                value="four dogs"
            )
            neg_prompt = gr.Textbox(
                label="Negative Prompt",
                value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry"
            )
            save_path_input = gr.Textbox(
                label="Save to Local Directory (Optional)",
                placeholder="/your/path",
                info="If left empty, the image will not be saved."
            )
            run_button = gr.Button("Generate Image", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image")

    with gr.Accordion("Advanced Settings", open=False):
        guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=10.0, label="Guidance Scale")
        num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
        seed = gr.Slider(minimum=0, maximum=999999, step=1, value=8, label="Seed", randomize=True)

    run_button.click(
        fn=generate_image,
        inputs=[input_image, local_path_input, save_path_input, prompt, extra_text, neg_prompt, guidance_scale, num_inference_steps, seed],
        outputs=output_image
    )

demo.launch(share=True)
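
The input-selection and save logic in `generate_image` can be sketched without Gradio or PIL. This is a minimal sketch, not code from the repository: `pick_source` and `output_path` are hypothetical helpers that mirror the demo's precedence (an uploaded image wins over a typed local path) and its `output_<timestamp>_seed<seed>.png` naming scheme.

```python
import os
import time
from typing import Optional


def pick_source(uploaded: Optional[object], local_path: Optional[str]) -> str:
    """Hypothetical helper: report which input the demo would use."""
    # An uploaded image takes priority over any typed path.
    if uploaded is not None:
        return "upload"
    # A non-blank path is used only when nothing was uploaded.
    if local_path and local_path.strip():
        return "path"
    raise ValueError("Please upload a reference image or provide a valid local file path!")


def output_path(save_dir: str, seed: float, when: time.struct_time) -> str:
    """Hypothetical helper: build output_<YYYYmmdd-HHMMSS>_seed<seed>.png inside save_dir."""
    timestamp = time.strftime("%Y%m%d-%H%M%S", when)
    # int() guards against the slider handing back a float such as 8.0.
    return os.path.join(save_dir.strip(), f"output_{timestamp}_seed{int(seed)}.png")


# Usage with a fixed timestamp (1 May 2024, 13:04:05) so the result is reproducible.
when = time.struct_time((2024, 5, 1, 13, 4, 5, 2, 122, -1))
print(pick_source(None, "  /home/user/images/photo.jpg "))  # -> path
print(output_path("/your/path ", 8.0, when))
```

Keeping the timestamp as a parameter rather than reading the clock inside the helper is a design choice that makes the naming testable; the demo itself calls `time.strftime` on the current time.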

================================================
FILE: ip_adapter/__init__.py
================================================
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull


__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "IPAdapterFull",
]


================================================
FILE: ip_adapter/attention_processor.py
================================================
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Cross_Attention(nn.Module):
    def __init__(self,
                 query_dim,         # Input dimension for Q projection
                 context_dim,       # Input dimension for K/V projection
                 heads=8,
                 value_dim=None,    # Per-head dimension of V after projection (defaults to head_dim)
                 out_dim=None):     # Output dimension (defaults to heads * value_dim)
        super().__init__()
        self.query_dim = query_dim
        self.heads = heads
        self.head_dim = self.query_dim // self.heads
        self.scale = math.sqrt(self.head_dim)
        self.value_dim = value_dim if value_dim is not None else self.head_dim
        self.out_dim = out_dim if out_dim is not None else heads * self.value_dim

        # Linear projection layers
        self.to_q = nn.Linear(query_dim, self.heads * self.head_dim)
        self.to_k = nn.Linear(context_dim, self.heads * self.head_dim)
        self.to_v = nn.Linear(context_dim, self.heads * self.value_dim)

        # Output projection (always applied)
        self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)

    def forward(self, query_input, context_input):

        B = query_input.size(0)

        # Project Q, K, V
        q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, Q_len, head_dim]
        k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, K_len, head_dim]
        v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2)  # [B, heads, K_len, v_dim]

        # Attention scores (weights)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [B, heads, Q_len, K_len]
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Weighted sum
        attn_output = torch.matmul(attn_probs, v)  # [B, heads, Q_len, v_dim]

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim)  # [B, Q_len, heads * v_dim]

        # Output projection
        output = self.out_proj(attn_output)  # [B, Q_len, out_dim]
        return output
    
    
    
class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; use 16 for IP-Adapter Plus):
            The context length of the image features.
        skip (`bool`, defaults to False):
            If True, the image-prompt branch is skipped and only text cross-attention is computed.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if not self.skip:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; use 16 for IP-Adapter Plus):
            The context length of the image features.
        skip (`bool`, defaults to False):
            If True, the image-prompt branch is skipped and only text cross-attention is computed.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if not self.skip:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
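            with torch.no_grad():
                # Store the image-token attention map: softmax over the scaled scores.
                # The parentheses matter: softmax must follow the matmul, not apply to ip_key alone.
                self.attn_map = (query @ ip_key.transpose(-2, -1) * attn.scale).softmax(dim=-1)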

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


## for controlnet
class CNAttnProcessor:
    r"""
    Attention processor for ControlNet: drops the trailing `num_tokens` image-prompt tokens from
    `encoder_hidden_states` and attends only to the text tokens.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    PyTorch 2.0 (scaled dot-product attention) version of `CNAttnProcessor`: drops the trailing
    `num_tokens` image-prompt tokens and attends only to the text tokens.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CNAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
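
The decoupled cross-attention that `IPAttnProcessor` and `IPAttnProcessor2_0` implement can be illustrated without PyTorch. This is a minimal sketch under stated assumptions: a single head, a batch of one, a query that is already projected, and no output projection, dropout, or residual; nested Python lists stand in for tensors, and every function name here is hypothetical. It shows the two ideas the processors rely on: the last `num_tokens` rows of `encoder_hidden_states` are split off as image tokens, and the result is `text_out + scale * ip_out` with a shared query.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (n x k) @ (k x m) matrix product on nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]
    scores = [[s / math.sqrt(d) for s in row] for row in matmul(q, k_t)]
    return matmul(softmax_rows(scores), v)


def decoupled_cross_attention(query: Matrix, context: Matrix, num_tokens: int, scale: float,
                              w_k: Matrix, w_v: Matrix, w_k_ip: Matrix, w_v_ip: Matrix) -> Matrix:
    """Text-branch attention plus `scale` times image-branch attention, sharing one query."""
    # As in the processors, the last `num_tokens` context rows are the image-prompt tokens.
    end_pos = len(context) - num_tokens
    text_ctx, ip_ctx = context[:end_pos], context[end_pos:]
    # Text branch: stands in for the UNet's own key/value projections (attn.to_k / attn.to_v).
    text_out = attention(query, matmul(text_ctx, w_k), matmul(text_ctx, w_v))
    # Image branch: separate projections (stand-ins for to_k_ip / to_v_ip), same query.
    ip_out = attention(query, matmul(ip_ctx, w_k_ip), matmul(ip_ctx, w_v_ip))
    # hidden_states = text_out + scale * ip_out
    return [[t + scale * i for t, i in zip(t_row, i_row)]
            for t_row, i_row in zip(text_out, ip_out)]


# Usage: two text tokens followed by one image token, with identity projections.
identity = [[1.0, 0.0], [0.0, 1.0]]
query = [[0.5, -0.2]]
context = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
out = decoupled_cross_attention(query, context, num_tokens=1, scale=0.5,
                                w_k=identity, w_v=identity, w_k_ip=identity, w_v_ip=identity)
print(out)
```

With a single image token the softmax over the image keys is exactly 1, so the image branch contributes `scale * [2.0, 3.0]`. Setting `scale=0.0` recovers text-only attention, which corresponds to calling `set_scale(0.0)` on the custom pipeline so that every IP attention processor adds nothing from the image prompt.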


================================================
FILE: ip_adapter/custom_pipelines.py
================================================
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg

from .utils import is_torch2_available

if is_torch2_available():
    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from .attention_processor import IPAttnProcessor


class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
    def set_scale(self, scale):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @torch.no_grad()
    def __call__(  # noqa: C901
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        control_guidance_start: float = 0.0,
        control_guidance_end: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text encoders.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally terminated early. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than or equal to `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.7):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
                that paper. The guidance rescale factor should fix overexposure when using zero terminal SNR.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                For most cases, `target_size` should be set to the desired height and width of the generated image. If
                not specified, it defaults to `(height, width)`. Part of SDXL's micro-conditioning as explained in
                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a target image resolution. It should be the same
                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
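            control_guidance_start (`float`, *optional*, defaults to 0.0):
                The fraction of total steps (0.0 to 1.0) at which the IP-Adapter conditioning starts applying. On
                earlier steps the IP-Adapter scale is set to 0.
            control_guidance_end (`float`, *optional*, defaults to 1.0):
                The fraction of total steps (0.0 to 1.0) at which the IP-Adapter conditioning stops applying. On
                later steps the IP-Adapter scale is set to 0.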

        Examples:

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

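        # Record the IP-Adapter scale configured before this call. It is restored on steps inside the
        # [control_guidance_start, control_guidance_end] window and replaced by 0.0 outside it.
        # Without any IPAttnProcessor there is nothing to toggle (avoids an unbound `conditioning_scale`).
        conditioning_scale = None
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                conditioning_scale = attn_processor.scale
                break

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if conditioning_scale is not None:
                    if (i / len(timesteps) < control_guidance_start) or (
                        (i + 1) / len(timesteps) > control_guidance_end
                    ):
                        self.set_scale(0.0)
                    else:
                        self.set_scale(conditioning_scale)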

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if output_type != "latent":
            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

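# Illustrative sketch, not called by the pipeline above. It restates two per-step
# decisions of the denoising loop in pure Python, so they can be checked without
# loading a model. The helper names are hypothetical. The sketch assumes timesteps
# arrive in descending order, as diffusers schedulers return them.
#   1. `denoising_end` keeps only timesteps >= round(T - denoising_end * T),
#      where T is `scheduler.config.num_train_timesteps`.
#   2. The IP-Adapter scale is zeroed on step i when i / n < control_guidance_start
#      or (i + 1) / n > control_guidance_end. Here n is len(timesteps) after step 1.


def _truncate_timesteps(timesteps, denoising_end, num_train_timesteps=1000):
    """Return the prefix of `timesteps` that survives `denoising_end` (hypothetical helper)."""
    if denoising_end is None or not (0 < denoising_end < 1):
        return list(timesteps)
    cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
    kept = len([ts for ts in timesteps if ts >= cutoff])
    return list(timesteps)[:kept]


def _ip_adapter_active_steps(num_steps, control_guidance_start=0.0, control_guidance_end=1.0):
    """Indices of steps on which the IP-Adapter keeps its configured scale (hypothetical helper)."""
    return [
        i
        for i in range(num_steps)
        if not (i / num_steps < control_guidance_start or (i + 1) / num_steps > control_guidance_end)
    ]


# Example with ten timesteps 999, 899, ..., 99:
#   _truncate_timesteps(list(range(999, 0, -100)), 0.8) keeps timesteps >= 200,
#   i.e. [999, 899, 799, 699, 599, 499, 399, 299]
#   _ip_adapter_active_steps(10, 0.0, 0.5) -> [0, 1, 2, 3, 4]
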

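# Illustrative sketch, not called by the pipeline above. It shows classifier-free
# guidance followed by guidance rescale over flat Python lists, so the arithmetic
# is easy to follow. It mirrors
#   noise_pred = uncond + guidance_scale * (text - uncond)
# and, assuming `rescale_noise_cfg` follows the upstream diffusers helper,
#   rescaled = noise_pred * std(text) / std(noise_pred)
#   noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
# The real helper takes the (sample) standard deviation per batch item over all
# non-batch dimensions. Here, one list stands for one batch item. The function name
# is hypothetical.


def _cfg_with_rescale(noise_uncond, noise_text, guidance_scale, guidance_rescale=0.0):
    import statistics

    # Classifier-free guidance: move from the unconditional prediction toward the text one.
    cfg = [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]
    if guidance_rescale <= 0.0:
        return cfg
    # Match the spread of the guided prediction to that of the text prediction, then blend.
    ratio = statistics.stdev(noise_text) / statistics.stdev(cfg)
    return [guidance_rescale * (c * ratio) + (1 - guidance_rescale) * c for c in cfg]


# Example: uncond [0, 0, 0, 0], text [1, -1, 1, -1], guidance_scale 5 gives [5, -5, 5, -5].
# With guidance_rescale=1.0 the spread is pulled back to that of `text`, giving about [1, -1, 1, -1].
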
================================================
FILE: ip_adapter/ip_adapter.py
================================================
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from tutorial_train_sdxl_ori import HarmonyAttention
from .utils import is_torch2_available, get_generator

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from .attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from .attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler


class ImageProjModel(torch.nn.Module):
    """Projects a pooled CLIP image embedding into `clip_extra_context_tokens` cross-attention tokens."""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # A single linear layer produces all extra tokens at once; forward() splits them apart.
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        # (batch, clip_embeddings_dim) -> (batch, clip_extra_context_tokens, cross_attention_dim)
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class MLPProjModel(torch.nn.Module):
    """Token-wise two-layer MLP mapping CLIP image features to `cross_attention_dim` (token count unchanged)."""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim)
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=None, number_class_crossattention=None):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        # self.target_blocks = target_blocks or ["down_blocks.2.attentions.1"] 

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
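        self.clip_image_processor = CLIPImageProcessor()
        # Optional module that fuses extra text embeddings into the CLIP image embedding (see get_image_embeds).
        # The default is None, so only move it to the device when one is given.
        self.number_class_crossattention = (
            number_class_crossattention.to(self.device, dtype=torch.float16)
            if number_class_crossattention is not None
            else None
        )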
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            
            if cross_attention_dim is None:
                # Self-attention (attn1): plain processor, no image tokens involved.
                attn_procs[name] = AttnProcessor()
            else:
                # Cross-attention (attn2): only processors under down_blocks.2.attentions.1 are created with
                # skip=False; all other IP processors get skip=True.
                if 'down_blocks.2.attentions.1' in name:
                    attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                    num_tokens=self.num_tokens, skip=False).to(self.device, dtype=torch.float16)
                else:
                    attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                    num_tokens=self.num_tokens, skip=True).to(self.device, dtype=torch.float16)

        unet.set_attn_processor(attn_procs)
        

        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
    
    def load_ip_adapter(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            # Use the same group names as the torch.load() branch below, because those are the keys read afterwards.
            state_dict = {"image_proj": {}, "ip_adapter": {}, "composed_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "", 1)] = f.get_tensor(key)
                    elif key.startswith(("ip_adapter.", "adapter_modules.")):
                        # Strip whichever prefix the checkpoint uses for the attention-processor weights.
                        state_dict["ip_adapter"][key.split(".", 1)[1]] = f.get_tensor(key)
                    elif key.startswith("composed_modules."):
                        state_dict["composed_adapter"][key.replace("composed_modules.", "", 1)] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        print(state_dict.keys())
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        if self.number_class_crossattention is not None:
            self.number_class_crossattention.load_state_dict(state_dict["composed_adapter"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, extra_prompt_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        # Optionally fuse the extra text embeddings into the CLIP image embedding as a residual update.
        if extra_prompt_embeds is not None:
            if self.number_class_crossattention is None:
                raise ValueError("`extra_prompt_embeds` requires `number_class_crossattention` to be provided.")
            extra_prompt_embeds = extra_prompt_embeds.to(self.device, dtype=torch.float16)
            output = self.number_class_crossattention(extra_prompt_embeds, clip_image_embeds)
            clip_image_embeds = clip_image_embeds + output

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

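# Illustrative sketch, not called by IPAdapter. It shows how a flat checkpoint, such
# as the key/tensor pairs read from a .safetensors file, can be grouped into the
# per-module sub-dicts that `load_ip_adapter` passes to `load_state_dict`. The
# prefix-to-group table is an assumption inferred from the key names used in
# `load_ip_adapter`; adjust it to the prefixes the checkpoint actually contains.
# Keys matching no prefix are dropped. The helper name is hypothetical.
def _split_flat_state_dict(flat, prefix_to_group=None):
    if prefix_to_group is None:
        prefix_to_group = {
            "image_proj.": "image_proj",
            "ip_adapter.": "ip_adapter",
            "adapter_modules.": "ip_adapter",
            "composed_modules.": "composed_adapter",
        }
    grouped = {group: {} for group in prefix_to_group.values()}
    for key, value in flat.items():
        for prefix, group in prefix_to_group.items():
            if key.startswith(prefix):
                # Strip the prefix so the remaining key matches the target module's own parameter names.
                grouped[group][key[len(prefix):]] = value
                break
    return grouped


# Example: {"image_proj.proj.weight": w, "adapter_modules.1.to_k_ip.weight": k} is grouped as
# {"image_proj": {"proj.weight": w}, "ip_adapter": {"1.to_k_ip.weight": k}, "composed_adapter": {}}
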

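# Illustrative sketch, not called by IPAdapter. It reproduces the name parsing in
# `set_ip_adapter` that picks each attention processor's `hidden_size`, and the
# check that decides which cross-attention processors get skip=False. The example
# `block_out_channels=(320, 640, 1280)` is an assumption matching the SDXL UNet
# config. The helper names are hypothetical.
def _hidden_size_for(name, block_out_channels):
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        # Reads a single digit, exactly like set_ip_adapter (fine for fewer than 10 blocks).
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected processor name: {name}")


def _is_active_ip_block(name):
    # Cross-attention ("attn2") processors under down_blocks.2.attentions.1 get skip=False.
    return not name.endswith("attn1.processor") and "down_blocks.2.attentions.1" in name


# Example: "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor" -> hidden_size 640.
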

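# Illustrative sketch, not called by IPAdapter. `generate` appends the image tokens
# after the text tokens along the sequence axis (torch.cat(..., dim=1)). This sketch
# assumes the IP attention processor follows upstream IP-Adapter, which recovers the
# image tokens by slicing off the last `num_tokens` positions. Under that assumption,
# `num_tokens` must equal the number of tokens the image projection emits, or text
# and image tokens get mixed up. Lists stand in for one sample's sequence axis. The
# helper names are hypothetical.
def _append_image_tokens(text_tokens, image_tokens):
    return list(text_tokens) + list(image_tokens)


def _split_image_tokens(sequence, num_tokens):
    end_pos = len(sequence) - num_tokens
    return sequence[:end_pos], sequence[end_pos:]


# Example: 77 text tokens + 4 image tokens -> a sequence of 81; splitting with num_tokens=4 recovers both parts.
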
class IPAdapterXL(IPAdapter):
    """SDXL"""
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, 
                 num_tokens=4, target_blocks=None, inference=False, number_class_crossattention=None):
        self.inference = inference  
        super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, 
                        num_tokens=num_tokens, target_blocks=target_blocks, number_class_crossattention=number_class_crossattention)

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        extra_text=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):

        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts


        extra_prompt_embeds = None
        if extra_text is not None:
            # Encode the extra text with the SDXL text encoders. Only the positive prompt embeddings are
            # passed on to the composer in get_image_embeds.
            with torch.inference_mode():
                (
                    extra_prompt_embeds,
                    extra_negative_prompt_embeds,
                    extra_pooled_prompt_embeds,
                    extra_negative_pooled_prompt_embeds,
                ) = self.pipe.encode_prompt(
                    extra_text,
                    num_images_per_prompt=num_samples,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, extra_prompt_embeds=extra_prompt_embeds
        )

        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
 
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images


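# Illustrative sketch, not called by IPAdapterXL. Both `generate` methods call
# `repeat(1, num_samples, 1)` and then `view(bs_embed * num_samples, seq_len, -1)`.
# Together these turn a batch [a, b] of per-image token sequences into
# [a, a, ..., b, b, ...], so each image's copies are contiguous. In diffusers,
# `encode_prompt(..., num_images_per_prompt=num_samples)` duplicates prompt
# embeddings the same way, so row k of the text and image embeddings refers to the
# same input. Nested lists stand in for the (batch, seq_len, dim) tensor. The helper
# name is hypothetical.
def _repeat_then_view(batch, num_samples):
    seq_len = len(batch[0])
    out = []
    for row in batch:
        tiled = list(row) * num_samples  # repeat(1, num_samples, 1): tile along the sequence axis
        # view(bs * num_samples, seq_len, -1): cut the tiled row back into num_samples sequences
        out.extend(tiled[k * seq_len:(k + 1) * seq_len] for k in range(num_samples))
    return out


# Example: _repeat_then_view([["a0", "a1"], ["b0", "b1"]], 3)
#   -> three copies of ["a0", "a1"] followed by three copies of ["b0", "b1"]
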

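class IPAdapterPlus(IPAdapter):
    """IP-Adapter Plus: a Resampler maps CLIP penultimate-layer hidden states to `num_tokens` tokens."""

    def init_proj(self):
        image_proj_model = Resampler(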
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

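In the `generate` methods above, the image embeddings are duplicated with `repeat(1, num_samples, 1)` followed by `view(bs_embed * num_samples, seq_len, -1)` so that they line up row-for-row with the text embeddings returned by `encode_prompt(..., num_images_per_prompt=num_samples)`. The pure-Python sketch below (no torch; the helper `repeat_then_view` is hypothetical) models only the row ordering this produces, treating each token as a label and leaving out the feature dimension.

```python
from typing import List

Tokens = List[str]


def repeat_then_view(embeds: List[Tokens], num_samples: int) -> List[Tokens]:
    """Model x.repeat(1, n, 1).view(bs * n, seq_len, -1) on lists of token labels.

    `embeds` has one entry per batch item, each a list of seq_len token labels.
    Only the row ordering is modelled; the feature dimension is omitted.
    """
    rows: List[Tokens] = []
    for tokens in embeds:
        seq_len = len(tokens)
        # repeat(1, n, 1): the token sequence is tiled n times along dim 1
        tiled = tokens * num_samples
        # view(bs * n, seq_len, -1): the tiled sequence is cut back into n rows,
        # so all copies of one batch item end up next to each other
        for i in range(num_samples):
            rows.append(tiled[i * seq_len:(i + 1) * seq_len])
    return rows


if __name__ == "__main__":
    image_embeds = [["img0_a", "img0_b"], ["img1_a", "img1_b"]]  # bs_embed=2, seq_len=2
    for row in repeat_then_view(image_embeds, num_samples=3):
        print(row)
```

Rows 0-2 are copies of image 0 and rows 3-5 are copies of image 1 (item-major order), the same ordering `repeat_interleave(num_samples, dim=0)` would give. diffusers' `encode_prompt` duplicates the text embeddings with the same repeat/view pattern, which is why the following `torch.cat(..., dim=1)` pairs each text row with the image row for the same reference image.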

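The `generate` methods above fill in default prompts and broadcast a single prompt string to one prompt per reference image before calling `encode_prompt`. A stand-alone sketch of that normalization, including the resulting number of generated images (`len(prompt) * num_samples`), might look like this; `normalize_prompts` is a hypothetical helper, and the reference images are represented only by their count `num_prompts`.

```python
from typing import List, Optional, Tuple, Union

# Defaults copied from the generate() methods above
DEFAULT_PROMPT = "best quality, high quality"
DEFAULT_NEGATIVE_PROMPT = "monochrome, lowres, bad anatomy, worst quality, low quality"

PromptArg = Optional[Union[str, List[str]]]


def normalize_prompts(
    num_prompts: int,
    prompt: PromptArg = None,
    negative_prompt: PromptArg = None,
    num_samples: int = 4,
) -> Tuple[List[str], List[str], int]:
    """Return (prompts, negative_prompts, total_images) as generate() would see them."""
    if prompt is None:
        prompt = DEFAULT_PROMPT
    if negative_prompt is None:
        negative_prompt = DEFAULT_NEGATIVE_PROMPT
    # A single string is repeated once per reference image; lists are used as given
    if not isinstance(prompt, list):
        prompt = [prompt] * num_prompts
    if not isinstance(negative_prompt, list):
        negative_prompt = [negative_prompt] * num_prompts
    # encode_prompt(num_images_per_prompt=num_samples) expands each prompt num_samples times
    total_images = len(prompt) * num_samples
    return prompt, negative_prompt, total_images


if __name__ == "__main__":
    prompts, negatives, total = normalize_prompts(2, prompt="a cat", num_samples=4)
    print(prompts, total)
```

A list `prompt` is not checked against the number of reference images. If the lengths differ, the text and image embeddings end up with different batch sizes and the `torch.cat(..., dim=1)` in `generate` raises a size-mismatch error.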
================================================
FILE: ip_adapter/ip_adapter_origin.py
================================================
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from .utils import is_torch2_available, get_generator

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from .attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from .attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler


class ImageProjModel(torch.nn.Module):
    """Projection model - converts CLIP image features into a format suitable for UNet cross-attention"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim  # UNet cross-attention dimension
        self.clip_extra_context_tokens = clip_extra_context_tokens  # Number of extra context tokens
        # Linear projection layer: converts the CLIP embedding into multiple extra context tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)  # Normalization layer

    def forward(self, image_embeds):
        # Project the CLIP image embedding into multiple context tokens
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class MLPProjModel(torch.nn.Module):
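    """Projection model using a multi-layer perceptron - used by the IPAdapterFull variant"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        # MLP: two linear layers with a GELU activation in between, followed by layer normalization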
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim)
        )
        
    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
    
    def load_ip_adapter(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images



class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """SDXL-specific generation method that also handles SDXL's pooled text embeddings"""
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        self.generator = get_generator(seed, self.device)
        
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images


class IPAdapterPlus(IPAdapter):
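    """Enhanced IP-Adapter variant that uses fine-grained image features"""

    def init_proj(self):
        """Use a Resampler instead of a simple linear projection"""
        # The Resampler is a more powerful feature-resampling model that better captures image details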
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
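        """Use the penultimate hidden state of the CLIP image encoder as a richer feature"""
        # Using CLIP's internal features instead of the final projection gives a finer-grained image representation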
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

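`IPAdapter.set_ip_adapter` above decides, from each attention-processor name alone, whether a layer gets a plain `AttnProcessor` (self-attention, `attn1`) or an `IPAttnProcessor`, and which `hidden_size` it uses. The pure-Python sketch below reproduces only that name-parsing logic so it can be checked without loading a UNet. `plan_attn_processors` is a hypothetical helper, and `block_out_channels=(320, 640, 1280, 1280)` with `cross_attention_dim=768` are assumed Stable Diffusion 1.5-style config values used for illustration.

```python
from typing import Dict, Iterable, Optional, Sequence, Tuple

Plan = Tuple[str, int, Optional[int]]  # (processor kind, hidden_size, cross_attention_dim)


def plan_attn_processors(
    names: Iterable[str],
    block_out_channels: Sequence[int],
    cross_attention_dim: int,
) -> Dict[str, Plan]:
    plan: Dict[str, Plan] = {}
    for name in names:
        # attn1 is self-attention: no cross-attention dim and no image-prompt weights
        cad = None if name.endswith("attn1.processor") else cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Like the original, this reads one character, so block ids must be < 10
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = block_out_channels[block_id]
        else:
            # The original loop would reuse a stale (or unbound) hidden_size here
            raise ValueError(f"unexpected processor name: {name}")
        kind = "AttnProcessor" if cad is None else "IPAttnProcessor"
        plan[name] = (kind, hidden_size, cad)
    return plan


if __name__ == "__main__":
    names = [
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
        "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
        "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor",
    ]
    for name, entry in plan_attn_processors(names, (320, 640, 1280, 1280), 768).items():
        print(name, entry)
```

In the real method the resulting dict goes to `unet.set_attn_processor(attn_procs)`. The processor order matters afterwards, because `load_ip_adapter` loads weights into `torch.nn.ModuleList(unet.attn_processors.values())` by position.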

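For `.safetensors` checkpoints, `load_ip_adapter` above rebuilds the nested `{"image_proj": ..., "ip_adapter": ...}` layout (which a checkpoint loaded with `torch.load` is expected to have already) by splitting the flat key space on its prefix. A pure-Python sketch of that step follows; `split_ip_adapter_state` is a hypothetical helper, and the key names and string values in the example are made up for illustration (real values would be tensors).

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

PREFIXES = ("image_proj.", "ip_adapter.")


def split_ip_adapter_state(flat: Mapping[str, V]) -> Dict[str, Dict[str, V]]:
    """Group flat checkpoint keys into image_proj / ip_adapter sub-dicts."""
    state: Dict[str, Dict[str, V]] = {"image_proj": {}, "ip_adapter": {}}
    for key, value in flat.items():
        for prefix in PREFIXES:
            if key.startswith(prefix):
                # Strip only the leading prefix; the original uses str.replace,
                # which would also remove later occurrences of the same substring
                state[prefix[:-1]][key[len(prefix):]] = value
                break
        # Keys with any other prefix are ignored, as in the original loop
    return state


if __name__ == "__main__":
    flat = {
        "image_proj.proj.weight": "W_proj",
        "image_proj.norm.bias": "b_norm",
        "ip_adapter.1.to_k_ip.weight": "W_k",
        "ip_adapter.1.to_v_ip.weight": "W_v",
        "text_encoder.extra": "ignored",
    }
    print(split_ip_adapter_state(flat))
```

In keys of the form `ip_adapter.<n>.…`, the number after the prefix is a position in the `torch.nn.ModuleList` that `load_ip_adapter` builds from `unet.attn_processors.values()`. This is why the checkpoint only loads correctly when the UNet's processors are enumerated in the same order as at save time.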
================================================
FILE: ip_adapter/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # shape is unchanged: (bs, n_heads, length, dim_per_head); heads are not merged into the batch dim
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

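`PerceiverAttention.forward` above multiplies both `q` and `k` by `dim_head ** -0.25` instead of dividing `q @ k^T` by `sqrt(dim_head)` afterwards. Both orderings give the same scores mathematically. The code comment says the split form is more stable in fp16, which is plausible: with the divide-afterwards form, the unscaled dot product is the largest intermediate value. The scalar sketch below uses made-up 64-dimensional vectors (matching `dim_head=64`) and compares each ordering's raw dot product with the largest finite float16 value, 65504. Both helper names are hypothetical.

```python
import math
from typing import Sequence, Tuple

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def score_divide_after(q: Sequence[float], k: Sequence[float], dim_head: int) -> Tuple[float, float]:
    """Return (raw dot product, final score) for q.k / sqrt(dim_head)."""
    raw = sum(a * b for a, b in zip(q, k))
    return raw, raw / math.sqrt(dim_head)


def score_split_scale(q: Sequence[float], k: Sequence[float], dim_head: int) -> Tuple[float, float]:
    """Return (raw dot product, final score) when q and k are each pre-scaled by dim_head**-0.25."""
    scale = 1 / math.sqrt(math.sqrt(dim_head))  # same expression as in PerceiverAttention
    raw = sum((a * scale) * (b * scale) for a, b in zip(q, k))
    return raw, raw  # no further division needed


if __name__ == "__main__":
    dim_head = 64
    q = [40.0] * dim_head
    k = [40.0] * dim_head
    raw_a, score_a = score_divide_after(q, k, dim_head)
    raw_b, score_b = score_split_scale(q, k, dim_head)
    print(raw_a > FP16_MAX, raw_b > FP16_MAX)  # does each raw product exceed the fp16 range?
    print(math.isclose(score_a, score_b))
```

In the module itself these are tensors and the sum is a matmul. The softmax is then computed in float32 (`weight.float()`) and cast back to the input dtype.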

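`masked_mean` above averages only the positions where `mask` is true and clamps the denominator, so an all-false mask yields zeros instead of a division by zero. `Resampler.forward` calls it with an all-ones mask, so there it reduces to a plain mean over the token axis. A pure-Python sketch for a single batch item (hypothetical name `masked_mean_1d`; token vectors are plain lists):

```python
from typing import List, Sequence


def masked_mean_1d(tokens: Sequence[Sequence[float]], mask: Sequence[bool]) -> List[float]:
    """Mean over the tokens whose mask entry is True, for one batch item.

    Mirrors masked_mean(t, dim=1, mask=mask): masked positions are zeroed,
    summed, and divided by the number of kept positions clamped to >= 1e-5.
    """
    width = len(tokens[0])
    totals = [0.0] * width
    for vec, keep in zip(tokens, mask):
        if keep:
            totals = [t + v for t, v in zip(totals, vec)]
    denom = max(sum(1 for keep in mask if keep), 1e-5)
    return [t / denom for t in totals]


if __name__ == "__main__":
    tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
    print(masked_mean_1d(tokens, [True, True, False]))  # third token is excluded
    print(masked_mean_1d(tokens, [False, False, False]))  # zeros, no ZeroDivisionError
```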
================================================
FILE: ip_adapter/shared_models.py
================================================
import argparse
import itertools
import json
import math
import os
import random
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention


class Cross_Attention(nn.Module):
    def __init__(self, 
                 query_dim,         # Q projection input dimension
                 context_dim,       # K/V projection input dimension
                 heads=8, 
                 head_dim=64, 
                 value_dim=None,    # V dimension after dimensionality reduction (defaults to head_dim)
                 out_dim=None):     # Output dimension
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.scale = math.sqrt(head_dim)
        self.value_dim = value_dim if value_dim is not None else head_dim
        self.out_dim = out_dim if out_dim is not None else heads * self.value_dim

        # Linear projection layers
        self.to_q = nn.Linear(query_dim, heads * head_dim)
        self.to_k = nn.Linear(context_dim, heads * head_dim)
        self.to_v = nn.Linear(context_dim, heads * self.value_dim)

        # Output projection: heads * value_dim -> out_dim
        self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)

    def forward(self, query_input, context_input):
        """
        query_input: [B, Q_len, query_dim]
        context_input: [B, K_len, context_dim]
        """
        B = query_input.size(0)

        # Project Q, K, V
        q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, Q_len, head_dim]
        k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, K_len, head_dim]
        v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2)  # [B, heads, K_len, v_dim]

        # Attention weights
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [B, heads, Q_len, K_len]
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Weighted sum
        attn_output = torch.matmul(attn_probs, v)  # [B, heads, Q_len, v_dim]

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim)  # [B, Q_len, heads * v_dim]

        # Output projection
        output = self.out_proj(attn_output)  # [B, Q_len, out_dim]
        return output


class ImageProjModel(torch.nn.Module):
    """Projection model - converts CLIP image features into a format suitable for UNet cross-attention"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim  # UNet cross-attention dimension
        self.clip_extra_context_tokens = clip_extra_context_tokens  # Number of extra context tokens
        # Linear projection layer, converts CLIP embeddings into multiple extra context tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)  # Normalization layer

    def forward(self, image_embeds):
        # Project CLIP image embeddings into multiple context tokens
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens



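class Composed_Attention(torch.nn.Module):  # Number_Class_crossAttention
    def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):
        super().__init__()
        # Note: cross_attention_dim is currently unused, and forward() hard-codes
        # shapes that assume hidden_size=1280 (4 tokens x 640 = 2 * 1280).

        # Cross-attention layer: 640-d image queries attend to 2048-d text tokens;
        # each output token has heads * value_dim = 10 * 32 = 320 features
        self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)
        self.scale = scale
        # self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)

        # Image projection from 1280 to 2560
        self.fc1 = nn.Linear(hidden_size, hidden_size * 2)

        # Layer normalization
        self.ln = nn.LayerNorm(hidden_size)

        # FC layer 1280 to 1280
        self.fc2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, text_embeds, image_embeds):
        # Image projection: [B, 1280] -> [B, 2560]
        image_embeds = self.fc1(image_embeds)

        # Split into 4 query tokens: [B, 2560] -> [B, 4, 640]
        # (-1 instead of a hard-coded 1, so batches larger than one also work)
        image_embeds = image_embeds.reshape(-1, 4, 640)
        # Attend to the text embeddings (same batch size B): [B, 4, 640] -> [B, 4, 320]
        output = self.cross_attention(image_embeds, text_embeds)
        # Merge the 4 tokens: [B, 4, 320] -> [B, 1280]
        output = output.reshape(-1, 1280)

        # Normalize the output using Layer Normalization
        output = self.ln(output)

        # FC layer: [B, 1280] -> [B, 1280]
        output = self.fc2(output)

        return output * self.scale
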
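    def load_from_checkpoint(self, ckpt_path: str):
        """Load matching weights from a safetensors checkpoint into this module.

        Keys prefixed with "image_proj_model." or "image_proj." are routed to
        self.image_proj_model, and keys containing "down_blocks.2.attentions.1" are
        routed to submodules whose names contain that path. A key is kept only if the
        matching attribute exists on self. Composed_Attention itself defines neither
        image_proj_model nor down_blocks, so unless a subclass adds them, this method
        assigns no weights.
        """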
        from safetensors.torch import load_file
        from collections import OrderedDict
        
        # Load weights file
        weights = load_file(ckpt_path)
        
        # Initialize two dictionaries to store weights for different modules separately
        image_proj_weights = OrderedDict()
        attn_weights = OrderedDict()
        
        # Separate weights into different modules
        for k, v in weights.items():
            # Process image_proj_model weights (match two possible key name prefixes)
            if k.startswith("image_proj_model.") or k.startswith("image_proj."):
                new_key = k.replace("image_proj_model.", "").replace("image_proj.", "")
                if hasattr(self, "image_proj_model") and hasattr(self.image_proj_model, new_key.split('.')[0]):
                    image_proj_weights[new_key] = v
            
            # Process target attention layer weights (match two possible key name formats)
            elif "down_blocks.2.attentions.1" in k:
                # Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1
                new_key = k.replace("composed_modules.", "").replace("ip_adapter.", "")
                if hasattr(self, new_key.split('.')[0]):
                    attn_weights[new_key] = v
        
        # Load image_proj_model weights (strict mode)
        if image_proj_weights:
            self.image_proj_model.load_state_dict(image_proj_weights, strict=True)
            print(f"Loaded image_proj_model weights: {len(image_proj_weights)} tensors")
        
        # Load attention layer weights (non-strict mode)
        if attn_weights:
            # Create temporary ModuleDict to load weights
            temp_dict = {k: v for k, v in self.named_modules() 
                        if "down_blocks.2.attentions.1" in k}
            temp_model = torch.nn.ModuleDict(temp_dict)
            
            missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False)
            if missing:
                print(f"Missing keys in attention blocks: {missing}")
            if unexpected:
                print(f"Unexpected keys in attention blocks: {unexpected}")
            
            print(f"Loaded attention weights: {len(attn_weights)} tensors")
        
        print(f"Successfully loaded target modules from {ckpt_path}")
        return self
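
The shape bookkeeping in Composed_Attention (fc1 to 2560, four 640-d query tokens, 10 heads with 64-d queries/keys but 32-d values, flattened back to 1280) can be checked with a small NumPy mirror of Cross_Attention.forward. This is a sketch, not the repository's code: it omits the projection biases and out_proj, uses random weights, and assumes the text context is 77 tokens of 2048 dims (the SDXL text-encoder layout implied by context_dim=2048).

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention_np(query, context, wq, wk, wv, heads, head_dim, value_dim):
    """NumPy mirror of Cross_Attention.forward, without biases or out_proj.

    query:   [B, Q_len, query_dim]
    context: [B, K_len, context_dim]
    Returns (output [B, Q_len, heads * value_dim], probs [B, heads, Q_len, K_len]).
    """
    B, q_len, _ = query.shape
    k_len = context.shape[1]
    # Project, then split the last axis into heads: [B, heads, len, dim]
    q = (query @ wq).reshape(B, q_len, heads, head_dim).transpose(0, 2, 1, 3)
    k = (context @ wk).reshape(B, k_len, heads, head_dim).transpose(0, 2, 1, 3)
    v = (context @ wv).reshape(B, k_len, heads, value_dim).transpose(0, 2, 1, 3)
    # Scaled dot-product scores over the text tokens, divided by sqrt(head_dim)
    probs = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim), axis=-1)
    # Weighted sum of values, then concatenate the heads
    out = (probs @ v).transpose(0, 2, 1, 3).reshape(B, q_len, heads * value_dim)
    return out, probs


# Composed_Attention sizes: 4 query tokens of 640, 10 heads, head_dim 64, value_dim 32
rng = np.random.default_rng(0)
B, heads, head_dim, value_dim = 2, 10, 64, 32
image_tokens = rng.standard_normal((B, 4, 640))
text_tokens = rng.standard_normal((B, 77, 2048))  # assumed 77 x 2048 text context
wq = rng.standard_normal((640, heads * head_dim)) * 0.02
wk = rng.standard_normal((2048, heads * head_dim)) * 0.02
wv = rng.standard_normal((2048, heads * value_dim)) * 0.02

out, probs = cross_attention_np(image_tokens, text_tokens, wq, wk, wv, heads, head_dim, value_dim)
flat = out.reshape(B, -1)  # the [B, 1280] vector that Composed_Attention normalizes
print(out.shape, flat.shape)  # (2, 4, 320) (2, 1280)
```

Because the value dimension (32) is half the query/key head size (64), the flattened output lands exactly on hidden_size=1280, which is why fc2 and the LayerNorm can be 1280-wide.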


================================================
FILE: ip_adapter/test_resampler.py
================================================
import torch
from resampler import Resampler
from transformers import CLIPVisionModel

BATCH_SIZE = 2
OUTPUT_DIM = 1280
NUM_QUERIES = 8
NUM_LATENTS_MEAN_POOLED = 4  # 0 for no mean pooling (previous behavior)
APPLY_POS_EMB = True  # False for no positional embeddings (previous behavior)
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"


def main():
    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
    embedding_dim = image_encoder.config.hidden_size
    print("image_encoder hidden size:", embedding_dim)

    image_proj_model = Resampler(
        dim=1024,
        depth=2,
        dim_head=64,
        heads=16,
        num_queries=NUM_QUERIES,
        embedding_dim=embedding_dim,
        output_dim=OUTPUT_DIM,
        ff_mult=2,
        max_seq_len=257,
        apply_pos_emb=APPLY_POS_EMB,
        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
    )

    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
    with torch.no_grad():
        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
    print("image_embeds shape:", image_embeds.shape)

    with torch.no_grad():
        ip_tokens = image_proj_model(image_embeds)
    print("ip_tokens shape:", ip_tokens.shape)
    assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)


if __name__ == "__main__":
    main()
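
Two pieces of arithmetic sit behind the constants in test_resampler.py: max_seq_len=257 is the token count of the penultimate CLIP hidden state, and the final assertion expects the learned queries plus the mean-pooled latents. The sketch below spells both out; it assumes a 224x224 input, 14-pixel patches (the "14" in ViT-H-14) and one CLS token. The helper names are hypothetical.

```python
def vit_seq_len(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens a ViT produces: one per non-overlapping patch, plus CLS."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + (1 if cls_token else 0)


def expected_ip_tokens_shape(batch_size: int, num_queries: int,
                             num_latents_mean_pooled: int, output_dim: int) -> tuple:
    """Shape asserted by test_resampler.py: learned queries plus mean-pooled latents."""
    return (batch_size, num_queries + num_latents_mean_pooled, output_dim)


# ViT-H/14 at 224x224: 16 x 16 patches + 1 CLS token = 257, matching max_seq_len=257
print(vit_seq_len(224, 14))                     # 257
print(expected_ip_tokens_shape(2, 8, 4, 1280))  # (2, 12, 1280)
```

Setting NUM_LATENTS_MEAN_POOLED = 0 restores the previous behavior, in which the output has exactly NUM_QUERIES tokens.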


================================================
FILE: ip_adapter/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

attn_maps = {}
def hook_fn(name):
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet

def upscale(attn_map, target_size):
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1,0)
    temp_size = None

    for i in range(0,5):
        scale = 2 ** i
        if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
            temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
            break

    assert temp_size is not None, "could not infer the attention map's spatial size from target_size"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):

    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size) 
        net_attn_maps.append(attn_map) 

    net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)

    return net_attn_maps

def attnmaps2images(net_attn_maps):
    """Convert a stack of attention maps into 8-bit grayscale PIL images."""
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()

        # Min-max normalize each map to [0, 255]
        attn_min, attn_max = np.min(attn_map), np.max(attn_map)
        span = attn_max - attn_min
        if span > 0:
            normalized_attn_map = (attn_map - attn_min) / span * 255
        else:
            # A constant map has no contrast; avoid dividing by zero
            normalized_attn_map = np.zeros_like(attn_map)
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        image = Image.fromarray(normalized_attn_map)
        images.append(image)

    return images


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")

def get_generator(seed, device):

    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator
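
The search loop in upscale() recovers the latent grid of an attention map from its token count: it accepts the first factor s in {1, 2, 4, 8, 16} for which (H // s) * (W // s) equals tokens * 64, i.e. a grid downsampled by 8s. The pure-Python sketch below mirrors that loop and the min-max normalization in attnmaps2images; the function names are hypothetical and the example sizes assume a 1024x1024 SDXL image (a 128x128 VAE latent).

```python
import numpy as np


def infer_attn_hw(num_tokens, target_size):
    """Mirror of the loop in upscale(): find the grid (h, w) with h * w == num_tokens,
    trying overall downsampling factors 8, 16, 32, 64 and 128."""
    for i in range(5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == num_tokens * 64:
            return (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
    return None


def to_uint8(attn_map):
    """Min-max normalize a 2-D map to 0..255, returning zeros for a constant map."""
    attn_map = np.asarray(attn_map, dtype=np.float64)
    span = attn_map.max() - attn_map.min()
    if span == 0:
        return np.zeros(attn_map.shape, dtype=np.uint8)
    return ((attn_map - attn_map.min()) / span * 255).astype(np.uint8)


# A block at 1/4 of the 128x128 latent resolution has 32 x 32 = 1024 tokens
print(infer_attn_hw(1024, (1024, 1024)))  # (32, 32)
print(to_uint8([[0.0, 1.0], [2.0, 4.0]]))
```

Note that astype(np.uint8) truncates rather than rounds, so intermediate values such as 63.75 become 63.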

================================================
FILE: requirements.txt
================================================
absl-py==2.1.0
accelerate==1.0.1
aiofiles==23.2.1
aiohappyeyeballs==2.4.4
aiohttp==3.10.11
aiosignal==1.3.1
annotated-types==0.7.0
anyio==3.7.1
async-timeout==5.0.1
asyncer==0.0.2
attrs==25.3.0
bidict==0.23.1
bitsandbytes==0.45.5
cachetools==5.5.0
certifi==2024.8.30
chainlit==1.1.402
charset-normalizer==3.4.0
chevron==0.14.0
click==8.1.8
contourpy==1.1.1
cycler==0.12.1
dataclasses-json==0.5.14
datasets==3.1.0
Deprecated==1.2.18
diffusers==0.30.0
dill==0.3.8
distro==1.9.0
einops==0.8.0
exceptiongroup==1.2.2
fastapi==0.110.3
ffmpy==0.5.0
filelock==3.16.1
filetype==1.2.0
fonttools==4.57.0
frozenlist==1.5.0
fsspec==2024.9.0
ftfy==6.2.3
google-auth==2.36.0
google-auth-oauthlib==1.0.0
googleapis-common-protos==1.69.2
gradio==4.44.1
gradio_client==1.3.0
grpcio==1.67.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.25.2
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
Jinja2==3.1.4
jiter==0.9.0
kiwisolver==1.4.7
Lazify==0.4.0
literalai==0.0.607
loguru==0.7.3
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.22.0
matplotlib==3.7.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
openai==1.71.0
opentelemetry-api==1.31.1
opentelemetry-exporter-otlp==1.31.1
opentelemetry-exporter-otlp-proto-common==1.31.1
opentelemetry-exporter-otlp-proto-grpc==1.31.1
opentelemetry-exporter-otlp-proto-http==1.31.1
opentelemetry-instrumentation==0.52b1
opentelemetry-proto==1.31.1
opentelemetry-sdk==1.31.1
opentelemetry-semantic-conventions==0.52b1
orjson==3.10.15
packaging==23.2
pandas==2.0.3
peft==0.13.2
pillow==10.4.0
propcache==0.2.0
protobuf==5.28.3
psutil==6.1.0
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pydantic==2.10.6
pydantic_core==2.27.2
pydub==0.25.1
Pygments==2.19.1
PyJWT==2.9.0
pyparsing==3.1.4
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-engineio==4.11.2
python-multipart==0.0.9
python-socketio==5.12.1
pytz==2024.2
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
requests-oauthlib==2.0.0
rich==14.0.0
rsa==4.9
ruff==0.11.8
safetensors==0.4.5
scipy==1.10.1
semantic-version==2.10.0
shellingham==1.5.4
simple-websocket==1.1.0
six==1.17.0
sniffio==1.3.1
sse-starlette==2.1.3
starlette==0.37.2
sympy==1.13.3
syncer==2.0.3
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
timm==1.0.13
tokenizers==0.20.3
tomli==2.2.1
tomlkit==0.12.0
torch==2.4.1
torchaudio==2.4.1
torchvision==0.19.1
tqdm==4.66.5
transformers==4.45.0
triton==3.0.0
typer==0.15.3
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
uptrace==1.31.0
urllib3==2.2.3
uvicorn==0.25.0
watchfiles==0.20.0
wcwidth==0.2.13
websockets==12.0
Werkzeug==3.0.6
wrapt==1.17.2
wsproto==1.2.0
xformers==0.0.28.post1
xxhash==3.5.0
yarl==1.15.2
zipp==3.20.2
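
Because every dependency above is pinned with `==`, a quick way to spot drift in an existing environment is to compare the pins with what importlib.metadata reports. The sketch below is a hypothetical helper, not part of the repository; it compares version strings exactly, so a local tag such as `2.4.1+cu121` is reported as a mismatch against `2.4.1`.

```python
from importlib import metadata


def parse_pins(lines):
    """Return {package: version} for 'package==version' lines, skipping blanks and comments."""
    pins = {}
    for raw in lines:
        line = raw.split("#", 1)[0].strip()
        if "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins, get_version=metadata.version):
    """Map each pinned package whose installed version differs to (pinned, installed).
    Packages that are not installed are reported with installed=None."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:  # exact string comparison
            mismatches[name] = (pinned, installed)
    return mismatches
```

To check a real environment, open requirements.txt and pass its lines to parse_pins, then hand the result to find_mismatches; the get_version parameter exists only so the lookup can be replaced in tests.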


================================================
FILE: run.sh
================================================
accelerate launch --gpu_ids 0 --num_processes 1 --mixed_precision "fp16" \
  train.py \
  --pretrained_model_name_or_path="your path" \
  --pretrained_ip_adapter_path="your path" \
  --image_encoder_path="your path" \
  --data_root_path='your path' \
  --mixed_precision="fp16" \
  --resolution=512 \
  --train_batch_size=1 \
  --dataloader_num_workers=4 \
  --learning_rate=2.5e-04 \
  --data_json_file="your path" \
  --weight_decay=0.01 \
  --output_dir="your path" \
  --save_steps=100 \
  --num_train_epochs 2100 \
  --composed_inter_dim=2560 \
  --composed_cross_heads=8 \
  --composed_reshape_blocks=8 \
  --composed_cross_value_dim=64
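
The four --composed_* flags determine the internal sizes of HarmonyAttention in train.py, and test.py repeats the same values in its ckpt_* constants. Assuming train.py forwards these flags to HarmonyAttention's inter_dim, cross_heads, reshape_blocks and cross_value_dim (its argument handling is not shown here), the derived sizes for the cross_attention fusion can be computed up front. The helper name is hypothetical.

```python
def harmony_dims(inter_dim, cross_heads, reshape_blocks, cross_value_dim):
    """Sizes HarmonyAttention (cross_attention fusion) derives from its arguments."""
    if inter_dim % reshape_blocks != 0:
        raise ValueError("inter_dim must be divisible by reshape_blocks")
    return {
        # fc1 output is split into reshape_blocks query tokens of this size
        "cross_query_dim": inter_dim // reshape_blocks,
        # Cross_Attention returns cross_heads * cross_value_dim per block; blocks are flattened
        "flattened_dim": cross_value_dim * cross_heads * reshape_blocks,
    }


# Values from run.sh (and the ckpt_* constants in test.py)
print(harmony_dims(inter_dim=2560, cross_heads=8, reshape_blocks=8, cross_value_dim=64))
# {'cross_query_dim': 320, 'flattened_dim': 4096}
```

If any of these flags changes at training time, the same values must be set in test.py, otherwise the saved fc1, LayerNorm and fc2 weights will not match the module built for inference.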


================================================
FILE: sdxl-fine-tuning/data/train.json
================================================
[
  {
    "image_file": "your image",
    "text": " ",
    "extra_text": "your caption",
    "comments": {
      "image_file": "five_cats.png",
      "text": "",
      "extra_text": "Five cats"
    }
  }
]
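
MyDataset in train.py reads only the "image_file", "text" and "extra_text" fields of each entry; the "comments" object in the template is ignored by the loader. The sketch below builds and checks entries in that shape. The helper names are hypothetical, and base_text=" " simply copies the template's single-space "text" value.

```python
import json

REQUIRED_KEYS = ("image_file", "text", "extra_text")  # fields MyDataset.__getitem__ reads


def make_entries(samples, base_text=" "):
    """Build train.json entries from (image_file, extra_text) pairs."""
    return [
        {"image_file": image_file, "text": base_text, "extra_text": extra_text}
        for image_file, extra_text in samples
    ]


def validate_entries(entries):
    """Raise KeyError if an entry lacks a field that the dataset indexes directly."""
    for i, entry in enumerate(entries):
        missing = [key for key in REQUIRED_KEYS if key not in entry]
        if missing:
            raise KeyError(f"entry {i} is missing {missing}")
    return entries


entries = validate_entries(make_entries([("five_cats.png", "Five cats")]))
print(json.dumps(entries, indent=2))
```

Writing the list with json.dump(entries, f, indent=2) produces a file that json.load in MyDataset can read; image paths are resolved relative to its image_root_path.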


================================================
FILE: shared_models.py
================================================
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Cross_Attention(nn.Module):
    def __init__(self, 
                 query_dim,         # Q projection input dimension
                 context_dim,       # K/V projection input dimension
                 heads=8, 
                 head_dim=64, 
                 value_dim=None,    # V dimension after dimensionality reduction (defaults to head_dim)
                 out_dim=None):     # Output dimension
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        self.scale = math.sqrt(head_dim)  # attention scores are divided by sqrt(head_dim)
        self.value_dim = value_dim if value_dim is not None else head_dim
        self.out_dim = out_dim if out_dim is not None else heads * self.value_dim

        # Linear projection layers
        self.to_q = nn.Linear(query_dim, heads * head_dim)
        self.to_k = nn.Linear(context_dim, heads * head_dim)
        self.to_v = nn.Linear(context_dim, heads * self.value_dim)

        # Output projection (always applied)
        self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)

    def forward(self, query_input, context_input):
        """
        query_input: [B, Q_len, query_dim]
        context_input: [B, K_len, context_dim]
        """
        B = query_input.size(0)

        # Project Q, K, V
        q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, Q_len, head_dim]
        k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, K_len, head_dim]
        v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2)  # [B, heads, K_len, v_dim]

        # Attention weights
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [B, heads, Q_len, K_len]
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Weighted sum
        attn_output = torch.matmul(attn_probs, v)  # [B, heads, Q_len, v_dim]

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim)  # [B, Q_len, heads * v_dim]

        # Output projection
        output = self.out_proj(attn_output)  # [B, Q_len, out_dim]
        return output


class ImageProjModel(torch.nn.Module):
    """Projection model - converts CLIP image features into a format suitable for UNet cross-attention"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim  # UNet cross-attention dimension
        self.clip_extra_context_tokens = clip_extra_context_tokens  # Number of extra context tokens
        # Linear projection layer, converts CLIP embeddings into multiple extra context tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)  # Normalization layer

    def forward(self, image_embeds):
        # Project CLIP image embeddings into multiple context tokens
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens



class Composed_Attention(torch.nn.Module):  # Number_Class_crossAttention
    def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):
        # Note: cross_attention_dim is currently unused. The attention sizes below are fixed
        # for hidden_size=1280 (2560 = 4 tokens x 640, and 4 x (10 heads x 32) = 1280).
        super().__init__()

        # Cross-attention: 640-d image tokens (queries) attend to 2048-d text embeddings
        self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)
        self.scale = scale
        # self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)

        # Image projection from 1280 to 2560
        self.fc1 = nn.Linear(hidden_size, hidden_size * 2)

        # Layer normalization over the flattened attention output
        self.ln = nn.LayerNorm(hidden_size)

        # FC layer 1280 to 1280
        self.fc2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, text_embeds, image_embeds):
        B = image_embeds.size(0)

        image_embeds = self.fc1(image_embeds)  # [B, 2560]
        image_embeds = image_embeds.reshape(B, 4, -1)  # [B, 4, 640]
        output = self.cross_attention(image_embeds, text_embeds)  # [B, 4, 320]
        output = output.reshape(B, -1)  # [B, 1280]

        # Normalize the output using Layer Normalization
        output = self.ln(output)

        # FC layer [B, 1280] -> [B, 1280]
        output = self.fc2(output)

        return output * self.scale
    
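    def load_from_checkpoint(self, ckpt_path: str):
        """Load matching weights from a safetensors checkpoint into this module.

        Keys prefixed with "image_proj_model." or "image_proj." are routed to
        self.image_proj_model, and keys containing "down_blocks.2.attentions.1" are
        routed to submodules whose names contain that path. A key is kept only if the
        matching attribute exists on self. Composed_Attention itself defines neither
        image_proj_model nor down_blocks, so unless a subclass adds them, this method
        assigns no weights.
        """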
        from safetensors.torch import load_file
        from collections import OrderedDict
        
        # Load weights file
        weights = load_file(ckpt_path)
        
        # Initialize two dictionaries to store weights for different modules separately
        image_proj_weights = OrderedDict()
        attn_weights = OrderedDict()
        
        # Separate weights into different modules
        for k, v in weights.items():
            # Process image_proj_model weights (match two possible key name prefixes)
            if k.startswith("image_proj_model.") or k.startswith("image_proj."):
                new_key = k.replace("image_proj_model.", "").replace("image_proj.", "")
                if hasattr(self, "image_proj_model") and hasattr(self.image_proj_model, new_key.split('.')[0]):
                    image_proj_weights[new_key] = v
            
            # Process target attention layer weights (match two possible key name formats)
            elif "down_blocks.2.attentions.1" in k:
                # Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1
                new_key = k.replace("composed_modules.", "").replace("ip_adapter.", "")
                if hasattr(self, new_key.split('.')[0]):
                    attn_weights[new_key] = v
        
        # Load image_proj_model weights (strict mode)
        if image_proj_weights:
            self.image_proj_model.load_state_dict(image_proj_weights, strict=True)
            print(f"Loaded image_proj_model weights: {len(image_proj_weights)} tensors")
        
        # Load attention layer weights (non-strict mode)
        if attn_weights:
            # Create temporary ModuleDict to load weights
            temp_dict = {k: v for k, v in self.named_modules() 
                        if "down_blocks.2.attentions.1" in k}
            temp_model = torch.nn.ModuleDict(temp_dict)
            
            missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False)
            if missing:
                print(f"Missing keys in attention blocks: {missing}")
            if unexpected:
                print(f"Unexpected keys in attention blocks: {unexpected}")
            
            print(f"Loaded attention weights: {len(attn_weights)} tensors")
        
        print(f"Successfully loaded target modules from {ckpt_path}")
        return self
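
ImageProjModel turns one pooled CLIP image embedding into a few context tokens: a single Linear layer produces clip_extra_context_tokens * cross_attention_dim values, which are reshaped into tokens and layer-normalized per token. The NumPy sketch below mirrors that forward pass at initialization (LayerNorm weight 1, bias 0). The sizes are assumptions: 4 tokens as in test.py's num_tokens=4, 2048 for the SDXL cross-attention width, and 1280 for the image embedding to match HarmonyAttention's image_hidden_size default.

```python
import numpy as np


def image_proj_np(image_embeds, weight, bias, num_tokens, cross_dim, eps=1e-5):
    """NumPy mirror of ImageProjModel.forward with an identity LayerNorm affine."""
    # Linear projection, then split the output into num_tokens tokens
    tokens = (image_embeds @ weight + bias).reshape(-1, num_tokens, cross_dim)
    # Per-token LayerNorm over the last axis (biased variance, as in torch.nn.LayerNorm)
    mean = tokens.mean(axis=-1, keepdims=True)
    var = tokens.var(axis=-1, keepdims=True)
    return (tokens - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
clip_dim, num_tokens, cross_dim = 1280, 4, 2048  # assumed sizes, see above
embeds = rng.standard_normal((2, clip_dim), dtype=np.float32)
weight = rng.standard_normal((clip_dim, num_tokens * cross_dim), dtype=np.float32) * 0.02
bias = np.zeros(num_tokens * cross_dim, dtype=np.float32)

tokens = image_proj_np(embeds, weight, bias, num_tokens, cross_dim)
print(tokens.shape)  # (2, 4, 2048)
```

Normalizing each token separately, rather than the whole flattened vector, keeps every image token on the same scale as the text tokens it is concatenated with in cross-attention.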


================================================
FILE: test.py
================================================
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from ip_adapter import IPAdapterXL
from train import HarmonyAttention  # defined in train.py
import os


'''
The following parameters must match the --composed_* arguments used for training (see run.sh).
'''
ckpt_inter_dim = 2560
ckpt_cross_heads = 8
ckpt_reshape_blocks = 8
ckpt_cross_value_dim = 64




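# Image generation function
def generate_image(input_path, prompt, extra_text, output_path="output.png"):
    """Generate one image using the module-level `ip_model` and
    `number_class_crossattention`, which are created in the __main__ block below."""
    print(f"\nGenerating image for: {prompt} + {extra_text}")

    # Prepare input image (convert to RGB in case the file has an alpha channel)
    input_image = Image.open(input_path).convert("RGB").resize((512, 512))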
    
    # Generate image
    images = ip_model.generate(
        pil_image=input_image,
        prompt=prompt,
        negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
        scale=1.0,
        guidance_scale=5.0,
        num_samples=1,
        num_inference_steps=30,
        seed=42,
        extra_text=extra_text,
        number_class_crossattention=number_class_crossattention # This should be the HarmonyAttention instance
    )
    
    # Save result
    images[0].save(output_path)
    print(f"Saved generated image to: {output_path}")
    return images[0]


if __name__ == "__main__":
    input_image = "your path to input image"  # Replace with your actual input image path
    device = "cuda:2"  # Set to the GPU you want to use, e.g. "cuda:0"

    # Model path configuration
    base_model_path = "your path" # Replace with your actual base model path
    image_encoder_path = "your path" # Replace with your actual image encoder path

    fine_tuned_ckpt = "fine_tuned model path"  # Path to the fine-tuned weights (ip_adapter.bin or similar)

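    # Output sub-folders are taken from indices 3 and 4 of the '/'-split checkpoint path,
    # falling back to default names when the path is shorter
    save_root = os.path.join(
        'your path',  # Replace with your desired save directory
        fine_tuned_ckpt.split('/')[3] if len(fine_tuned_ckpt.split('/')) > 3 else "default_folder1",
        fine_tuned_ckpt.split('/')[4] if len(fine_tuned_ckpt.split('/')) > 4 else "default_folder2",
    )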
    # Create path (if it doesn't exist)
    os.makedirs(save_root, exist_ok=True)
    print(f"Save directory created at: {save_root}")
    
    # Load SDXL base model
    print("Loading base SDXL model...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )
    pipe.enable_vae_tiling()
    pipe.to(device)


    # Define the fusion method to be used; it must match the method used for training
    fusion_method = "cross_attention"  # Options in train.py: "cross_attention", "qformer", "mlp", "gated-attention"

    # Initialize HarmonyAttention (this is the `number_class_crossattention` module)
    print("Initializing HarmonyAttention module...")
    number_class_crossattention = HarmonyAttention(
        image_hidden_size=1280,     # Keep unchanged or match your training
        text_context_dim=2048,      # Keep unchanged or match your training
        inter_dim=ckpt_inter_dim,
        cross_heads=ckpt_cross_heads,
        reshape_blocks=ckpt_reshape_blocks,
        cross_value_dim=ckpt_cross_value_dim,
        scale=1.0,                  # Keep unchanged or adjust as needed
        fusion_method=fusion_method  # Add fusion method selection
    ).to(device).half()

    # Initialize IP-Adapter
    print("Initializing IP-Adapter with target blocks...")
    ip_model = IPAdapterXL(
        pipe, 
        image_encoder_path, 
        fine_tuned_ckpt, 
        device,
        target_blocks=["down_blocks.2.attentions.1"],  # Or your specific target blocks
        num_tokens=4, # Or your specific number of tokens
        inference=True,
        number_class_crossattention=number_class_crossattention  # Pass the initialized HarmonyAttention module here
    )


    print("HarmonyAttention weights are expected to be loaded as part of the IP-Adapter checkpoint.")


    generate_image(
        input_path=input_image,
        prompt="lions",
        extra_text="eight sheep", # Use the caption from training
        output_path=os.path.join(save_root, os.path.basename(input_image))  # Save under save_root with the input file's name
    )
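
The save_root expression in test.py indexes the checkpoint path by position, which is easy to misread: for an absolute path, index 0 is the empty string before the leading '/'. The sketch below restates the same rule as a standalone function (the function name and example path are hypothetical) so the resulting folder can be checked before running generation.

```python
import os


def derive_save_root(base_dir, ckpt_path,
                     defaults=("default_folder1", "default_folder2")):
    """Same rule as test.py: use the '/'-separated components at indices 3 and 4
    of the checkpoint path as sub-folders, or fall back to default names."""
    parts = ckpt_path.split("/")
    sub1 = parts[3] if len(parts) > 3 else defaults[0]
    sub2 = parts[4] if len(parts) > 4 else defaults[1]
    return os.path.join(base_dir, sub1, sub2)


# parts = ['', 'data', 'exp', 'run1', 'checkpoint-100', 'ip_adapter.bin']
print(derive_save_root("outputs", "/data/exp/run1/checkpoint-100/ip_adapter.bin"))
```

Because the rule depends on path depth, moving checkpoints to a shallower or deeper directory changes which components are used; adjust the indices to your own directory layout.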


================================================
FILE: train.py
================================================
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import torch.nn as nn
from diffusers.models.attention_processor import Attention
import torch
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
from safetensors import safe_open
from shared_models import ImageProjModel
from baseline import QFormer,MLP,AttentionFusionWrapper
from safetensors.torch import save_file, load_file
from ip_adapter.utils import is_torch2_available
if is_torch2_available():
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor

else:
    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
from ip_adapter.attention_processor import Cross_Attention


def count_model_params(model):
    return sum([p.numel() for p in model.parameters()]) / 1e6

# Dataset
class MyDataset(torch.utils.data.Dataset):

    def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
        super().__init__()

        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.size = size
        self.center_crop = center_crop
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path
    
        # list of dict: [{"image_file": "1.png", "text": "A dog", "extra_text": "Five cats"}]
        with open(json_file) as f:
            self.data = json.load(f)

        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        self.clip_image_processor = CLIPImageProcessor()
        
    def __getitem__(self, idx):
        item = self.data[idx] 
        text = item["text"]
        text_extra = item['extra_text']
        image_file = item["image_file"]
        
        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        
        # original size
        original_width, original_height = raw_image.size
        original_size = torch.tensor([original_height, original_width])
        
        image_tensor = self.transform(raw_image.convert("RGB"))
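        # Crop to a size x size square (center crop or random crop). Resize() has already
        # scaled the shorter side to self.size, so at most one of delta_h/delta_w is non-zero.
        delta_h = image_tensor.shape[1] - self.size
        delta_w = image_tensor.shape[2] - self.size
        assert not all([delta_h, delta_w])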
        
        if self.center_crop:
            top = delta_h // 2
            left = delta_w // 2
        else:
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        image = transforms.functional.crop(
            image_tensor, top=top, left=left, height=self.size, width=self.size
        )
        crop_coords_top_left = torch.tensor([top, left]) 

        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
        
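        # Randomly drop the image embedding, the text, or both (conditioning dropout for classifier-free guidance)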
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1

        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        
        text_input_ids_2 = self.tokenizer_2(
            text,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        
        # Tokenize the extra caption (e.g. "Five cats" in train.json) with both tokenizers
        text_extra_input_ids=self.tokenizer(
            text_extra,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        
        text_extra_input_ids_2 = self.tokenizer_2(
            text_extra,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        
        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "text_input_ids_2": text_input_ids_2,
            "text_extra_input_ids":text_extra_input_ids,
            "text_extra_input_ids_2":text_extra_input_ids_2,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
            "original_size": original_size,
            "crop_coords_top_left": crop_coords_top_left,
            "target_size": torch.tensor([self.size, self.size]),

        }
        
    
    def __len__(self):
        return len(self.data)
    

def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
    clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_image_embeds = [example["drop_image_embed"] for example in data]
    original_size = torch.stack([example["original_size"] for example in data])
    crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
    target_size = torch.stack([example["target_size"] for example in data])

    # Extra text: torch.stack (not torch.cat) keeps each example's leading dim, giving [B, 1, max_length]
    text_extra_input_ids=torch.stack([example["text_extra_input_ids"] for example in data])
    text_extra_input_ids_2=torch.stack([example["text_extra_input_ids_2"] for example in data])
    
    return {
        "images": images,
        "text_input_ids": text_input_ids,
        "text_input_ids_2": text_input_ids_2,
        "clip_images": clip_images,
        "drop_image_embeds": drop_image_embeds,
        "original_size": original_size,
        "crop_coords_top_left": crop_coords_top_left,
        "target_size": target_size,
        "text_extra_input_ids":text_extra_input_ids,
        "text_extra_input_ids_2":text_extra_input_ids_2,
    }



class HarmonyAttention(nn.Module):
    def __init__(self,
                 image_hidden_size=1280,     # Input image feature dimension
                 text_context_dim=2048,      # Input text context feature dimension
                 inter_dim=2560,             # Intermediate projection dimension
                 cross_heads=10,             # Number of cross-attention heads
                 reshape_blocks=8,           # Number of image feature blocks
                 cross_value_dim=64,         # Value dimension per head after dimensionality reduction
                 scale=1.0,                  # Output scaling factor
                 fusion_method="qformer"):   # Fusion method: "cross_attention", "qformer", "mlp", or "gated-attention"
        super().__init__()

        if inter_dim % reshape_blocks != 0:
            raise ValueError(f"inter_dim ({inter_dim}) must be divisible by reshape_blocks ({reshape_blocks})")

        self.scale = scale
        self.reshape_blocks = reshape_blocks
        self.cross_query_dim = inter_dim // reshape_blocks
        self.fusion_method = fusion_method
        self.image_hidden_size = image_hidden_size
        self.text_context_dim = text_context_dim

        # Image projection required by all methods
        self.fc1 = nn.Linear(image_hidden_size, inter_dim)
        if fusion_method == "cross_attention":
            # 1. Cross-attention
            self.fusion_text_image = Cross_Attention(
                query_dim=self.cross_query_dim,
                context_dim=text_context_dim,
                heads=cross_heads,
                value_dim=cross_value_dim
            )
        elif fusion_method == "qformer":
            # 2. Q-Former TODO
            self.fusion_text_image = QFormer(
                num_queries=16,
                hidden_dim=self.cross_query_dim,
                num_layers=1,
                num_heads=cross_heads
            )
        elif fusion_method == "mlp":
            # 3. MLP TODO
            self.fusion_text_image = MLP(
                fused_dim=self.cross_query_dim
            )
        elif fusion_method == "gated-attention":
            # 4. Gated attention
            self.fusion_text_image = AttentionFusionWrapper(
                fused_dim=self.cross_query_dim
            )
        else:
            raise ValueError(f"Unknown fusion_method: {fusion_method!r}")

        flattened_dim = cross_value_dim * cross_heads * reshape_blocks
        self.ln = nn.LayerNorm(flattened_dim)
        self.fc2 = nn.Linear(flattened_dim, image_hidden_size)



    def forward(self, text_embeds, image_embeds):
        """
        Args:
            text_embeds: [B, T, text_context_dim]
            image_embeds: [B, image_hidden_size]
        Returns:
            out: [B, image_hidden_size], image features fused with text conditions
        """
        B = image_embeds.size(0)

        # Map image features to the intermediate dimension and split them into blocks
        x = self.fc1(image_embeds)  # [B, inter_dim]
        x = x.view(B, self.reshape_blocks, self.cross_query_dim)  # [B, N_blocks, query_dim]

        # Fuse the image blocks with the text tokens using the selected fusion module
        attended = self.fusion_text_image(x, text_embeds)  # [B, N_blocks, value_dim * heads]

        # Flatten, normalize, and project back to the image feature dimension
        attended = attended.reshape(B, -1)  # [B, flattened_dim]; reshape also handles non-contiguous outputs
        out = self.ln(attended)
        out = self.fc2(out) * self.scale

        return out
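The forward pass splits one pooled CLIP image embedding into `reshape_blocks` query tokens, lets them attend to the extra-text tokens, then flattens and projects back to `image_hidden_size`. The sketch below reproduces that shape flow with the class defaults (1280 -> 2560 -> 8 blocks of 320, 10 heads x 64 value dims). `Cross_Attention` itself is not shown in this file, so `ToyCrossAttention` is a hypothetical stand-in built on `torch.nn.MultiheadAttention`; only the tensor shapes are meant to match.

```python
import torch
import torch.nn as nn


class ToyCrossAttention(nn.Module):
    """Hypothetical stand-in for Cross_Attention: image-block queries attend to text tokens.

    Output width is heads * value_dim, matching the shape comment in HarmonyAttention.forward.
    """

    def __init__(self, query_dim, context_dim, heads, value_dim):
        super().__init__()
        inner_dim = heads * value_dim
        self.to_q = nn.Linear(query_dim, inner_dim)  # lift queries to heads * value_dim
        self.attn = nn.MultiheadAttention(inner_dim, heads, kdim=context_dim,
                                          vdim=context_dim, batch_first=True)

    def forward(self, query_input, context_input):
        q = self.to_q(query_input)
        out, _ = self.attn(q, context_input, context_input)
        return out  # [B, N_blocks, heads * value_dim]


class ToyHarmonyAttention(nn.Module):
    def __init__(self, image_hidden_size=1280, text_context_dim=2048, inter_dim=2560,
                 cross_heads=10, reshape_blocks=8, cross_value_dim=64, scale=1.0):
        super().__init__()
        self.reshape_blocks = reshape_blocks
        self.cross_query_dim = inter_dim // reshape_blocks  # 2560 // 8 = 320
        self.scale = scale
        self.fc1 = nn.Linear(image_hidden_size, inter_dim)
        self.fusion = ToyCrossAttention(self.cross_query_dim, text_context_dim,
                                        cross_heads, cross_value_dim)
        flattened_dim = cross_value_dim * cross_heads * reshape_blocks  # 64 * 10 * 8 = 5120
        self.ln = nn.LayerNorm(flattened_dim)
        self.fc2 = nn.Linear(flattened_dim, image_hidden_size)

    def forward(self, text_embeds, image_embeds):
        B = image_embeds.size(0)
        x = self.fc1(image_embeds).view(B, self.reshape_blocks, self.cross_query_dim)
        attended = self.fusion(x, text_embeds).reshape(B, -1)
        return self.fc2(self.ln(attended)) * self.scale


model = ToyHarmonyAttention()
text = torch.randn(2, 77, 2048)  # concatenated CLIP-L + bigG hidden states
image = torch.randn(2, 1280)     # pooled CLIP image embedding
out = model(text, image)
print(out.shape)  # torch.Size([2, 1280])
```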


    

  

           
           
class IPAdapter(torch.nn.Module):
    """IP-Adapter training wrapper: UNet + image projection + adapter attention processors + HarmonyAttention."""
    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None,
                 inter_dim=2560,
                 cross_heads=10,
                 reshape_blocks=8,
                 cross_value_dim=64,
                 fusion_method="cross_attention"):  # Fusion method parameter
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules
        
        # Text-image fusion module (HarmonyAttention); fusion_method selects its fusion architecture
        self.composed_modules = HarmonyAttention(
            image_hidden_size=1280,     # Image feature dimension fixed
            text_context_dim=2048,      # Text context dimension fixed
            inter_dim=inter_dim,
            cross_heads=cross_heads,
            reshape_blocks=reshape_blocks,
            cross_value_dim=cross_value_dim,
            scale=1.0,                  # Scaling factor fixed
            fusion_method=fusion_method  # Directly pass the fusion method parameter
        )
        
        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)
    
    def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds, text_extra_embeds):
        composed_embeds = self.composed_modules(text_extra_embeds, image_embeds)
        image_embeds = image_embeds + composed_embeds
        
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
        return noise_pred
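`IPAdapter.forward` adds the HarmonyAttention output to the CLIP image embedding as a residual, projects the result into `num_tokens` extra context tokens, and appends them after the text tokens before calling the UNet. A minimal shape sketch, assuming an `ImageProjModel`-style projection (one linear layer reshaped into tokens, then `LayerNorm`, as in the original IP-Adapter). `ToyImageProj` is a hypothetical name, and 2048 is assumed as SDXL's `cross_attention_dim`:

```python
import torch
import torch.nn as nn


class ToyImageProj(nn.Module):
    """Hypothetical ImageProjModel-like head: one embedding -> num_tokens context tokens."""

    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, num_tokens=4):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.proj = nn.Linear(clip_embeddings_dim, num_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        tokens = self.proj(image_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)


B = 2
image_embeds = torch.randn(B, 1280)
composed_embeds = torch.randn(B, 1280)             # stand-in for the HarmonyAttention output
encoder_hidden_states = torch.randn(B, 77, 2048)   # SDXL text tokens

ip_tokens = ToyImageProj()(image_embeds + composed_embeds)  # residual fusion, then projection
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
print(encoder_hidden_states.shape)  # torch.Size([2, 81, 2048])
```

In the original IP-Adapter design, the adapter attention processors split these trailing `num_tokens` tokens back off and attend to them with separate key/value projections.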

    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        if os.path.splitext(ckpt_path)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(ckpt_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
        
        # Sanity check: loading a checkpoint must change the weights
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
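For `.safetensors` files the method rebuilds the nested `{"image_proj": ..., "ip_adapter": ...}` layout from flat, prefixed keys, and it compares parameter sums before and after loading to catch a load that silently did nothing. A small sketch of both ideas on an in-memory state dict. `split_by_prefix` and `param_checksum` are hypothetical helpers, and a sum is only a heuristic change detector (two different weight sets can share a sum):

```python
import torch
import torch.nn as nn


def split_by_prefix(flat_state, prefixes=("image_proj", "ip_adapter")):
    """Group flat 'prefix.param' keys into {prefix: {param: tensor}}."""
    nested = {p: {} for p in prefixes}
    for key, tensor in flat_state.items():
        for p in prefixes:
            if key.startswith(p + "."):
                nested[p][key[len(p) + 1:]] = tensor
    return nested


def param_checksum(module):
    """Sum of all parameters; a cheap change detector, not a real hash."""
    return sum(p.detach().sum().item() for p in module.parameters())


proj = nn.Linear(4, 2)
flat = {
    "image_proj.weight": torch.ones(2, 4),
    "image_proj.bias": torch.zeros(2),
    "ip_adapter.0.to_k_ip.weight": torch.ones(3, 3),
}
nested = split_by_prefix(flat)

before = param_checksum(proj)
proj.load_state_dict(nested["image_proj"], strict=True)
after = param_checksum(proj)
print(sorted(nested["image_proj"]), after)  # ['bias', 'weight'] 8.0
```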
    

def parse_args():
    parser = argparse.ArgumentParser(description="Training script for IMAGHarmony (SDXL IP-Adapter with HarmonyAttention).")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to a pretrained IP-Adapter checkpoint. If not specified, the adapter weights are initialized from scratch.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Path to the training data JSON file.",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Root directory of the training images (passed to MyDataset as image_root_path).",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory, resolved relative to"
            " --output_dir (defaults to output_dir/logs)."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "Target resolution of the training images (passed to MyDataset as size)."
        ),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=10000, help="Number of training epochs.")
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--noise_offset", type=float, default=None, help="Scale of the noise offset added to the training noise; disabled if not set.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help=(
            "Save a checkpoint of the training state every X steps."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--composed_inter_dim",
        type=int,
        default=2560,
        help="HarmonyAttention's intermediate projection dimension [e.g., 1280, 2560]. Must be divisible by --composed_reshape_blocks."
    )
    parser.add_argument(
        "--composed_cross_heads",
        type=int,
        default=10,
        help="Number of cross-attention heads in HarmonyAttention [e.g., 8, 10]."
    )
    parser.add_argument(
        "--composed_reshape_blocks",
        type=int,
        default=8,
        help="Number of image feature blocks in HarmonyAttention [e.g., 4, 8]."
    )
    parser.add_argument(
        "--composed_cross_value_dim",
        type=int,
        default=64,
        help="Value dimension per head after dimensionality reduction in HarmonyAttention [e.g., 32, 64]."
    )
    
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
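The four `--composed_*` flags must agree with each other: `HarmonyAttention` splits `inter_dim` into `reshape_blocks` equal chunks with `view`, and its final `LayerNorm`/`fc2` expect `cross_value_dim * cross_heads * reshape_blocks` features. A hypothetical pre-flight check (not part of the script) that derives both sizes and fails early on an invalid combination:

```python
def harmony_dims(inter_dim, cross_heads, reshape_blocks, cross_value_dim):
    """Return (cross_query_dim, flattened_dim) as HarmonyAttention computes them."""
    if inter_dim % reshape_blocks != 0:
        # x.view(B, reshape_blocks, inter_dim // reshape_blocks) would fail at runtime
        raise ValueError(
            f"inter_dim={inter_dim} is not divisible by reshape_blocks={reshape_blocks}"
        )
    cross_query_dim = inter_dim // reshape_blocks
    flattened_dim = cross_value_dim * cross_heads * reshape_blocks
    return cross_query_dim, flattened_dim


print(harmony_dims(2560, 10, 8, 64))  # (320, 5120)
print(harmony_dims(1280, 8, 4, 32))   # (320, 1024)
```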
    

def main():
    # Parse command-line arguments
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    # Configure accelerate for mixed-precision and distributed training
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    
    # Create output directory
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the pretrained model components:
    # noise scheduler, tokenizers, text encoders, VAE, UNet, and CLIP image encoder
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    
    # Freeze base model parameters, only train the adapter part
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    image_encoder.requires_grad_(False)
    
    # Initialize IP-Adapter model components
    # Image projection model maps CLIP image embeddings to UNet's cross-attention dimension
    num_tokens = 4  # Number of extra context tokens
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=num_tokens,
    )
    
    # Initialize adapter attention processors: one processor per attention block in the UNet
    attn_procs = {}
    unet_sd = unet.state_dict()
    # Self-attention (attn1) layers keep AttnProcessor; cross-attention layers get IPAttnProcessor
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
            
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
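            # Only down_blocks.2.attentions.1 is created with skip=False, with its to_k_ip/to_v_ip weights
            # initialized from the UNet's own to_k/to_v; all other cross-attention layers use skip=True
            if 'down_blocks.2.attentions.1' in name:
                weights = {
                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                }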
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                   num_tokens=num_tokens, skip=False)
                
                attn_procs[name].load_state_dict(weights, strict=False)
            else:
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                                   num_tokens=num_tokens, skip=True)

    # Load all attention processing into UNet
    unet.set_attn_processor(attn_procs)
    # Extract all attention layers into a list
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
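The loop above infers each attention layer's `hidden_size` from its processor name and creates the processor that matches the layer's role. A pure-Python sketch of that name parsing, assuming SDXL's `block_out_channels = [320, 640, 1280]`; `layer_plan` is a hypothetical helper that returns a description instead of building processors, and it raises on names the original loop would not recognize:

```python
def layer_plan(name, block_out_channels=(320, 640, 1280)):
    """Return (hidden_size, kind) for a UNet attention-processor name."""
    if name.startswith("mid_block"):
        hidden_size = block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])  # single-digit block index
        hidden_size = list(reversed(block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = block_out_channels[block_id]
    else:
        raise ValueError(f"unexpected processor name: {name}")

    if name.endswith("attn1.processor"):
        kind = "self-attention (AttnProcessor)"
    elif "down_blocks.2.attentions.1" in name:
        kind = "IP cross-attention (skip=False)"
    else:
        kind = "IP cross-attention (skip=True)"
    return hidden_size, kind


for n in [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
    "down_blocks.2.attentions.1.transformer_blocks.3.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.0.attentions.2.transformer_blocks.9.attn2.processor",
]:
    print(n, "->", layer_plan(n))
```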


    # Create the complete IP-Adapter model
    ip_adapter = IPAdapter(
        unet,
        image_proj_model,
        adapter_modules,
        args.pretrained_ip_adapter_path,
        # HarmonyAttention sizes come from the --composed_* command-line arguments
        inter_dim=args.composed_inter_dim,
        cross_heads=args.composed_cross_heads,
        reshape_blocks=args.composed_reshape_blocks,
        cross_value_dim=args.composed_cross_value_dim,
        fusion_method='cross_attention'  # fixed here; not exposed as a command-line argument
    )

    # ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
    # Set data type for mixed-precision training
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
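The if/elif chain maps Accelerate's `mixed_precision` string to the dtype used for the frozen encoders. An equivalent table-driven form, shown only as a sketch (`resolve_weight_dtype` is a hypothetical name); any other value, such as `"no"` or `None`, falls back to `float32` just as in the script, which is also why the VAE stays in fp32 for stability:

```python
import torch

_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16}


def resolve_weight_dtype(mixed_precision):
    """Map 'fp16'/'bf16' to the matching dtype; anything else means full fp32."""
    return _DTYPES.get(mixed_precision, torch.float32)


for mode in ("no", "fp16", "bf16", None):
    print(mode, resolve_weight_dtype(mode))
```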
        
    # Move models to the appropriate device and convert to the appropriate data type
    vae.to(accelerator.device)  # VAE uses fp32 for better stability
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    
    # Optimizer: only the adapter attention processors and HarmonyAttention are updated
    # (image_proj_model is not passed to the optimizer)
    params_to_opt = itertools.chain(ip_adapter.adapter_modules.parameters(), ip_adapter.composed_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
    accelerator.print("Trainable parameters: adapter_modules: {:.2f}M, composed_modules: {:.2f}M".format(
        count_model_params(ip_adapter.adapter_modules),
        count_model_params(ip_adapter.composed_modules)))
    # Create dataset and dataloader
    train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, image_root_path=args.data_root_path)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    
    # Prepare model, optimizer, and dataloader with accelerator
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
    
    # Start training loop
    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space using VAE
                with torch.no_grad():
                    # SDXL's VAE uses fp32 for better numerical stability
                    latents = vae.encode(batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor
                    latents = latents.to(accelerator.device, dtype=weight_dtype)

                # Generate random noise to add to the latent representation
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # Use noise offset technique to improve training stability
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(accelerator.device, dtype=weight_dtype)

                bsz = latents.shape[0]
                # Sample random timesteps for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise according to timesteps (forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
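`noise_scheduler.add_noise` applies the DDPM forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, and the optional noise offset adds a per-sample, per-channel constant to eps. A from-scratch sketch, assuming the `scaled_linear` beta schedule (beta from 0.00085 to 0.012 over 1000 steps) that SDXL's scheduler config commonly uses; check the checkpoint's `scheduler/scheduler_config.json` rather than trusting these numbers. The noise offset value is illustrative:

```python
import torch

num_train_timesteps = 1000
# scaled_linear schedule: betas are linear in sqrt-space, then squared
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def add_noise(latents, noise, timesteps):
    """DDPM forward process: sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps."""
    a_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)  # broadcast over C, H, W
    return a_bar.sqrt() * latents + (1.0 - a_bar).sqrt() * noise


latents = torch.randn(2, 4, 64, 64)  # VAE latents for 512x512 inputs (8x downsampling)
noise = torch.randn_like(latents)
noise_offset = 0.05                  # illustrative value
noise = noise + noise_offset * torch.randn(latents.shape[0], latents.shape[1], 1, 1)
timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],)).long()
noisy_latents = add_noise(latents, noise, timesteps)
print(noisy_latents.shape)  # torch.Size([2, 4, 64, 64])
```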
            
                # Get CLIP image embeddings
                with torch.no_grad():
                    image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
                
                # Apply image embedding dropout strategy
                image_embeds_ = []
                for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
                    if drop_image_embed == 1:
                        image_embeds_.append(torch.zeros_like(image_embed))
                    else:
                        image_embeds_.append(image_embed)
                image_embeds = torch.stack(image_embeds_)
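The loop zeroes the CLIP image embedding for samples flagged in `drop_image_embeds`, so the model also sees image-free conditioning, which classifier-free guidance relies on at inference. The same effect in one vectorized call, sketched with `torch.where` and assuming the flags arrive as a 0/1 tensor of shape `[B]`:

```python
import torch

image_embeds = torch.randn(4, 1280)
drop_image_embeds = torch.tensor([0, 1, 0, 1])  # 1 = drop this sample's image condition

# Loop version, as in the training script
looped = torch.stack([
    torch.zeros_like(e) if d == 1 else e
    for e, d in zip(image_embeds, drop_image_embeds)
])

# Vectorized version: broadcast the [B] mask over the feature dimension
keep = (drop_image_embeds != 1).unsqueeze(-1)  # [B, 1] boolean
vectorized = torch.where(keep, image_embeds, torch.zeros_like(image_embeds))

print(torch.equal(looped, vectorized))  # True
```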
            
                # Get text embeddings (SDXL uses two text encoders)
                with torch.no_grad():
                    encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True)
                    text_embeds = encoder_output.hidden_states[-2]
                    encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True)
                    pooled_text_embeds = encoder_output_2[0]
                    text_embeds_2 = encoder_output_2.hidden_states[-2]
                    text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1)  # Concatenate outputs of the two text encoders
                    
                    # Extra text: encoded the same way, with both text encoders
                    encoder_extra_output = text_encoder(batch['text_extra_input_ids'].to(accelerator.device), output_hidden_states=True)
                    text_extra_embeds = encoder_extra_output.hidden_states[-2]
                    encoder_extra_output_2 = text_encoder_2(batch['text_extra_input_ids_2'].to(accelerator.device), output_hidden_states=True)
                    text_extra_embeds_2 = encoder_extra_output_2.hidden_states[-2]
                    text_extra_embeds = torch.concat([text_extra_embeds, text_extra_embeds_2], dim=-1)
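Both the prompt and the extra text go through SDXL's two text encoders, and the penultimate hidden states are concatenated on the feature axis. That is where `text_context_dim=2048` in `HarmonyAttention` comes from: 768 features from CLIP ViT-L plus 1280 from OpenCLIP ViT-bigG. A shape-only sketch, with random tensors standing in for the encoder outputs (77 is the usual CLIP token length):

```python
import torch

batch_size, seq_len = 2, 77
hidden_l = torch.randn(batch_size, seq_len, 768)      # text_encoder hidden_states[-2]
hidden_bigg = torch.randn(batch_size, seq_len, 1280)  # text_encoder_2 hidden_states[-2]

text_extra_embeds = torch.concat([hidden_l, hidden_bigg], dim=-1)
print(text_extra_embeds.shape)  # torch.Size([2, 77, 2048])
```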
                                        
                # Add extra conditions required by SDXL (image size and crop info)
                add_time_ids = [
                    batch["original_size"].to(accelerator.device),
                    batch["crop_coords_top_left"].to(accelerator.device),
                    batch["target_size"].to(accelerator.device),
                ]
                add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
                unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
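SDXL's micro-conditioning packs three pairs (original size, crop top-left, target size) into six numbers per sample, which the UNet receives as `time_ids` next to the pooled text embedding. A sketch of the concatenation with illustrative values, assuming each batch field has shape `[B, 2]`:

```python
import torch

original_size = torch.tensor([[1024, 1024], [768, 1024]])
crop_coords_top_left = torch.tensor([[0, 0], [32, 0]])
target_size = torch.tensor([[512, 512], [512, 512]])

add_time_ids = torch.cat([original_size, crop_coords_top_left, target_size], dim=1)
add_time_ids = add_time_ids.to(dtype=torch.float32)  # the script casts to weight_dtype
print(add_time_ids.shape)         # torch.Size([2, 6])
print(add_time_ids[1].tolist())   # [768.0, 1024.0, 32.0, 0.0, 512.0, 512.0]
```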
                
                # Predict noise using IP-Adapter
                noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds, text_extra_embeds)
                
                # Calculate MSE loss (between predicted noise and actual added noise)
                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
            
                # Gather loss in distributed training
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
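Each process repeats its scalar loss `train_batch_size` times before `accelerator.gather`, so the mean of the gathered tensor equals the plain average of the per-process losses. It is used only for logging; the backward pass uses the local `loss`. A pure-Python simulation with two hypothetical processes:

```python
def gathered_mean(per_process_losses, train_batch_size):
    """Mimic mean(gather(loss.repeat(train_batch_size))) across processes."""
    gathered = []
    for loss in per_process_losses:  # one entry per process
        gathered.extend([loss] * train_batch_size)
    return sum(gathered) / len(gathered)


print(gathered_mean([0.25, 0.75], train_batch_size=8))  # 0.5
```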
                
                # Backpropagation and optimizer step
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                # Print training information
                if accelerator.is_main_process:
                    print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                        epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))
            
            global_step += 1
            
            # Save checkpoint periodically
            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path, safe_serialization=False)
            
            begin = time.perf_counter()
                
if __name__ == "__main__":
    main()
SYMBOL INDEX (142 symbols across 14 files)

FILE: baseline.py
  class QFormer (line 5) | class QFormer(nn.Module):
    method __init__ (line 6) | def __init__(self,
    method forward (line 31) | def forward(self, image_feat, text_feat):
  class MLP (line 67) | class MLP(nn.Module):
    method __init__ (line 68) | def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_hea...
    method forward (line 81) | def forward(self, image_feat, text_feat):
  class GatedAttentionFusion (line 103) | class GatedAttentionFusion(nn.Module):
    method __init__ (line 104) | def __init__(self, input_dim=768, hidden_dim=512):
    method forward (line 114) | def forward(self, img_feat, txt_feat):
  class AttentionFusionWrapper (line 126) | class AttentionFusionWrapper(nn.Module):
    method __init__ (line 127) | def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_hea...
    method forward (line 135) | def forward(self, image_feat, text_feat):

FILE: convert_bin.py
  function convert_checkpoint_to_ip_adapter (line 5) | def convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapt...

FILE: demo.py
  function generate_image (line 94) | def generate_image(uploaded_image: Image.Image, local_path: str, save_pa...

FILE: ip_adapter/attention_processor.py
  class Cross_Attention (line 12) | class Cross_Attention(nn.Module):
    method __init__ (line 13) | def __init__(self,
    method forward (line 35) | def forward(self, query_input, context_input):
  class AttnProcessor (line 60) | class AttnProcessor(nn.Module):
    method __init__ (line 65) | def __init__(
    method __call__ (line 72) | def __call__(
  class IPAttnProcessor (line 135) | class IPAttnProcessor(nn.Module):
    method __init__ (line 149) | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, n...
    method __call__ (line 161) | def __call__(
  class AttnProcessor2_0 (line 244) | class AttnProcessor2_0(torch.nn.Module):
    method __init__ (line 249) | def __init__(
    method __call__ (line 258) | def __call__(
  class IPAttnProcessor2_0 (line 335) | class IPAttnProcessor2_0(torch.nn.Module):
    method __init__ (line 349) | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, n...
    method __call__ (line 364) | def __call__(
  class CNAttnProcessor (line 469) | class CNAttnProcessor:
    method __init__ (line 474) | def __init__(self, num_tokens=4):
    method __call__ (line 477) | def __call__(self, attn, hidden_states, encoder_hidden_states=None, at...
  class CNAttnProcessor2_0 (line 534) | class CNAttnProcessor2_0:
    method __init__ (line 539) | def __init__(self, num_tokens=4):
    method __call__ (line 544) | def __call__(

FILE: ip_adapter/custom_pipelines.py
  class StableDiffusionXLCustomPipeline (line 16) | class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
    method set_scale (line 17) | def set_scale(self, scale):
    method __call__ (line 23) | def __call__(  # noqa: C901

FILE: ip_adapter/ip_adapter.py
  class ImageProjModel (line 28) | class ImageProjModel(torch.nn.Module):
    method __init__ (line 31) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 41) | def forward(self, image_embeds):
  class MLPProjModel (line 51) | class MLPProjModel(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
    method forward (line 64) | def forward(self, image_embeds):
  class IPAdapter (line 69) | class IPAdapter:
    method __init__ (line 70) | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_t...
    method init_proj (line 91) | def init_proj(self):
    method set_ip_adapter (line 99) | def set_ip_adapter(self):
    method load_ip_adapter (line 135) | def load_ip_adapter(self):
    method get_image_embeds (line 159) | def get_image_embeds(self, pil_image=None, clip_image_embeds=None, ext...
    method set_scale (line 179) | def set_scale(self, scale):
    method generate (line 184) | def generate(
  class IPAdapterXL (line 249) | class IPAdapterXL(IPAdapter):
    method __init__ (line 251) | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device,
    method generate (line 257) | def generate(
  class IPAdapterPlus (line 344) | class IPAdapterPlus(IPAdapter):
    method init_proj (line 347) | def init_proj(self):
    method get_image_embeds (line 363) | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
  class IPAdapterFull (line 378) | class IPAdapterFull(IPAdapterPlus):
    method init_proj (line 381) | def init_proj(self):
  class IPAdapterPlusXL (line 389) | class IPAdapterPlusXL(IPAdapter):
    method init_proj (line 392) | def init_proj(self):
    method get_image_embeds (line 406) | def get_image_embeds(self, pil_image):
    method generate (line 419) | def generate(

FILE: ip_adapter/ip_adapter_origin.py
  class ImageProjModel (line 28) | class ImageProjModel(torch.nn.Module):
    method __init__ (line 31) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 41) | def forward(self, image_embeds):
  class MLPProjModel (line 51) | class MLPProjModel(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
    method forward (line 64) | def forward(self, image_embeds):
  class IPAdapter (line 69) | class IPAdapter:
    method __init__ (line 70) | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_t...
    method init_proj (line 89) | def init_proj(self):
    method set_ip_adapter (line 97) | def set_ip_adapter(self):
    method load_ip_adapter (line 127) | def load_ip_adapter(self):
    method get_image_embeds (line 143) | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
    method set_scale (line 155) | def set_scale(self, scale):
    method generate (line 160) | def generate(
  class IPAdapterXL (line 225) | class IPAdapterXL(IPAdapter):
    method generate (line 228) | def generate(
  class IPAdapterPlus (line 291) | class IPAdapterPlus(IPAdapter):
    method init_proj (line 294) | def init_proj(self):
    method get_image_embeds (line 310) | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
  class IPAdapterFull (line 326) | class IPAdapterFull(IPAdapterPlus):
    method init_proj (line 329) | def init_proj(self):
  class IPAdapterPlusXL (line 337) | class IPAdapterPlusXL(IPAdapter):
    method init_proj (line 340) | def init_proj(self):
    method get_image_embeds (line 354) | def get_image_embeds(self, pil_image):
    method generate (line 367) | def generate(

FILE: ip_adapter/resampler.py
  function FeedForward (line 13) | def FeedForward(dim, mult=4):
  function reshape_tensor (line 23) | def reshape_tensor(x, heads):
  class PerceiverAttention (line 34) | class PerceiverAttention(nn.Module):
    method __init__ (line 35) | def __init__(self, *, dim, dim_head=64, heads=8):
    method forward (line 49) | def forward(self, x, latents):
  class Resampler (line 81) | class Resampler(nn.Module):
    method __init__ (line 82) | def __init__(
    method forward (line 127) | def forward(self, x):
  function masked_mean (line 150) | def masked_mean(t, *, dim, mask=None):

FILE: ip_adapter/shared_models.py
  class Cross_Attention (line 16) | class Cross_Attention(nn.Module):
    method __init__ (line 17) | def __init__(self,
    method forward (line 39) | def forward(self, query_input, context_input):
  class ImageProjModel (line 64) | class ImageProjModel(torch.nn.Module):
    method __init__ (line 67) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 77) | def forward(self, image_embeds):
  class Composed_Attention (line 88) | class Composed_Attention(torch.nn.Module):#Number_Class_crossAttention
    method __init__ (line 89) | def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):
    method forward (line 108) | def forward(self, text_embeds,image_embeds):
    method load_from_checkpoint (line 124) | def load_from_checkpoint(self, ckpt_path: str):

FILE: ip_adapter/test_resampler.py
  function main (line 13) | def main():

FILE: ip_adapter/utils.py
  function hook_fn (line 7) | def hook_fn(name):
  function register_cross_attention_hook (line 15) | def register_cross_attention_hook(unet):
  function upscale (line 22) | def upscale(attn_map, target_size):
  function get_net_attn_map (line 46) | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=Fals...
  function attnmaps2images (line 61) | def attnmaps2images(net_attn_maps):
  function is_torch2_available (line 80) | def is_torch2_available():
  function get_generator (line 83) | def get_generator(seed, device):

FILE: shared_models.py
  class Cross_Attention (line 16) | class Cross_Attention(nn.Module):
    method __init__ (line 17) | def __init__(self,
    method forward (line 39) | def forward(self, query_input, context_input):
  class ImageProjModel (line 64) | class ImageProjModel(torch.nn.Module):
    method __init__ (line 67) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 77) | def forward(self, image_embeds):
  class Composed_Attention (line 88) | class Composed_Attention(torch.nn.Module):#Number_Class_crossAttention
    method __init__ (line 89) | def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):
    method forward (line 108) | def forward(self, text_embeds,image_embeds):
    method load_from_checkpoint (line 124) | def load_from_checkpoint(self, ckpt_path: str):

FILE: test.py
  function generate_image (line 21) | def generate_image(input_path, prompt, extra_text, output_path="output.p...

FILE: train.py
  function count_model_params (line 35) | def count_model_params(model):
  class MyDataset (line 39) | class MyDataset(torch.utils.data.Dataset):
    method __init__ (line 41) | def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, cente...
    method __getitem__ (line 63) | def __getitem__(self, idx):
    method __len__ (line 155) | def __len__(self):
  function collate_fn (line 159) | def collate_fn(data):
  class HarmonyAttention (line 188) | class HarmonyAttention(nn.Module):
    method __init__ (line 189) | def __init__(self,
    method forward (line 243) | def forward(self, text_embeds, image_embeds):
  class IPAdapter (line 275) | class IPAdapter(torch.nn.Module):
    method __init__ (line 277) | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=...
    method forward (line 303) | def forward(self, noisy_latents, timesteps, encoder_hidden_states, une...
    method load_from_checkpoint (line 313) | def load_from_checkpoint(self, ckpt_path: str):
  function parse_args (line 344) | def parse_args():
  function main (line 485) | def main():
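Each entry in the symbol index follows a fixed shape: an indented kind (`class`, `method`, or `function`), the symbol name, its line number in the source file, and the first line of its definition after a `|`. The following is a minimal Python sketch for turning the index into structured records. It assumes that shape holds for every entry, which has only been checked against the lines shown here. `parse_symbol_index` and `Symbol` are hypothetical helpers and are not part of the repository.

```python
import re
from dataclasses import dataclass
from typing import List, Optional

# One index entry: indented kind, name, "(line N)", then "|" and the definition's first line.
ENTRY_RE = re.compile(
    r"^\s*(?P<kind>class|method|function)\s+(?P<name>\w+)"
    r"\s+\(line (?P<line>\d+)\)\s+\|\s+(?P<signature>.*)$"
)


@dataclass
class Symbol:
    file: str
    kind: str
    name: str
    line: int  # line number inside the source file, not inside this dump
    signature: str
    parent: Optional[str] = None  # enclosing class for methods, else None


def parse_symbol_index(text: str) -> List[Symbol]:
    """Parse 'FILE:' headers and their indented entries into Symbol records."""
    symbols: List[Symbol] = []
    current_file = ""
    current_class: Optional[str] = None
    for raw in text.splitlines():
        if raw.startswith("FILE: "):
            current_file = raw[len("FILE: "):].strip()
            current_class = None
            continue
        m = ENTRY_RE.match(raw)
        if not m:
            continue  # blank lines and anything else are skipped
        kind = m.group("kind")
        parent = current_class if kind == "method" else None
        if kind == "class":
            current_class = m.group("name")
        elif kind == "function":
            current_class = None  # a top-level function ends the previous class
        symbols.append(Symbol(current_file, kind, m.group("name"),
                              int(m.group("line")), m.group("signature"), parent))
    return symbols


SAMPLE = """FILE: ip_adapter/utils.py
  function get_generator (line 83) | def get_generator(seed, device):

FILE: train.py
  class IPAdapter (line 275) | class IPAdapter(torch.nn.Module):
    method forward (line 303) | def forward(self, noisy_latents, timesteps, encoder_hidden_states, une...
"""

if __name__ == "__main__":
    for sym in parse_symbol_index(SAMPLE):
        owner = f"{sym.parent}." if sym.parent else ""
        print(f"{sym.file}:{sym.line} {sym.kind} {owner}{sym.name}")
    # ip_adapter/utils.py:83 function get_generator
    # train.py:275 class IPAdapter
    # train.py:303 method IPAdapter.forward
```

The parser attaches each method to the most recent class in the same file. This matches the nesting shown by indentation in the index. The parser does not read the indentation itself.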

Condensed preview: 20 files, each showing path, character count, and a content snippet. The full structured JSON content is 187K characters.
[
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 5585,
    "preview": "# IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout\n\n\n\n<a href='https://revive234.githu"
  },
  {
    "path": "baseline.py",
    "chars": 5607,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass QFormer(nn.Module):\n    def __init__(self, \n  "
  },
  {
    "path": "convert_bin.py",
    "chars": 3818,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\n\ndef convert_checkpoint_to_ip_adapter(pytorch_model_path, out"
  },
  {
    "path": "demo.py",
    "chars": 7223,
    "preview": "import gradio as gr\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom PIL import Image\nfrom ip_adapter i"
  },
  {
    "path": "ip_adapter/__init__.py",
    "chars": 216,
    "preview": "from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull\n\n\n__all__ = [\n    \"IPAdapt"
  },
  {
    "path": "ip_adapter/attention_processor.py",
    "chars": 23382,
    "preview": "# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py\nimport to"
  },
  {
    "path": "ip_adapter/custom_pipelines.py",
    "chars": 22574,
    "preview": "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom diffusers import StableDiffusion"
  },
  {
    "path": "ip_adapter/ip_adapter.py",
    "chars": 19304,
    "preview": "import os\nfrom typing import List\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.c"
  },
  {
    "path": "ip_adapter/ip_adapter_origin.py",
    "chars": 17215,
    "preview": "import os\nfrom typing import List\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.c"
  },
  {
    "path": "ip_adapter/resampler.py",
    "chars": 5059,
    "preview": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://gith"
  },
  {
    "path": "ip_adapter/shared_models.py",
    "chars": 7392,
    "preview": "import torch\nfrom torch import nn\nimport math\nimport os\nimport random\nimport argparse\nfrom pathlib import Path\nimport js"
  },
  {
    "path": "ip_adapter/test_resampler.py",
    "chars": 1406,
    "preview": "import torch\nfrom resampler import Resampler\nfrom transformers import CLIPVisionModel\n\nBATCH_SIZE = 2\nOUTPUT_DIM = 1280\n"
  },
  {
    "path": "ip_adapter/utils.py",
    "chars": 2831,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\n\nattn_maps = {}\ndef hook_fn(name):"
  },
  {
    "path": "requirements.txt",
    "chars": 3190,
    "preview": "absl-py==2.1.0\naccelerate==1.0.1\naiofiles==23.2.1\naiohappyeyeballs==2.4.4\naiohttp==3.10.11\naiosignal==1.3.1\nannotated-ty"
  },
  {
    "path": "run.sh",
    "chars": 643,
    "preview": "accelerate launch --gpu_ids 0 --num_processes 1 --mixed_precision \"fp16\" \\\n  train.py \\\n  --pretrained_model_name_or_pat"
  },
  {
    "path": "sdxl-fine-tuning/data/train.json",
    "chars": 202,
    "preview": "[\n  {\n  \"image_file\": \"your image\", \n  \"text\": \" \", \n  \"extra_text\": \"your caption\", \n  \"comments\": {\n      \"image_file\""
  },
  {
    "path": "shared_models.py",
    "chars": 7392,
    "preview": "import torch\nfrom torch import nn\nimport math\nimport os\nimport random\nimport argparse\nfrom pathlib import Path\nimport js"
  },
  {
    "path": "test.py",
    "chars": 4243,
    "preview": "import torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom PIL import Image\nfrom ip_adapter import IPAdapterXL # "
  },
  {
    "path": "train.py",
    "chars": 30777,
    "preview": "import os\nimport random\nimport argparse\nfrom pathlib import Path\nimport json\nimport itertools\nimport time\nimport torch.n"
  }
]
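A small Python sketch can total the `chars` field of every manifest entry as a sanity check on the figures in this extraction. It assumes one character equals one byte, which holds for ASCII text, and that 1 KB equals 1024 bytes. Under those assumptions, the 20 entries sum to 179,416 characters, or about 175.2 KB. This agrees with the size reported for this extraction. The helper name `summarize` is hypothetical.

```python
import json
from collections import Counter
from typing import Dict, List, Tuple

# "path" and "chars" values copied from the condensed preview above;
# the "preview" snippets are omitted because they are truncated anyway.
MANIFEST_JSON = """[
  {"path": "LICENSE", "chars": 11357},
  {"path": "README.md", "chars": 5585},
  {"path": "baseline.py", "chars": 5607},
  {"path": "convert_bin.py", "chars": 3818},
  {"path": "demo.py", "chars": 7223},
  {"path": "ip_adapter/__init__.py", "chars": 216},
  {"path": "ip_adapter/attention_processor.py", "chars": 23382},
  {"path": "ip_adapter/custom_pipelines.py", "chars": 22574},
  {"path": "ip_adapter/ip_adapter.py", "chars": 19304},
  {"path": "ip_adapter/ip_adapter_origin.py", "chars": 17215},
  {"path": "ip_adapter/resampler.py", "chars": 5059},
  {"path": "ip_adapter/shared_models.py", "chars": 7392},
  {"path": "ip_adapter/test_resampler.py", "chars": 1406},
  {"path": "ip_adapter/utils.py", "chars": 2831},
  {"path": "requirements.txt", "chars": 3190},
  {"path": "run.sh", "chars": 643},
  {"path": "sdxl-fine-tuning/data/train.json", "chars": 202},
  {"path": "shared_models.py", "chars": 7392},
  {"path": "test.py", "chars": 4243},
  {"path": "train.py", "chars": 30777}
]"""


def summarize(entries: List[Dict]) -> Tuple[int, Counter]:
    """Return the total character count and a per-top-level-directory breakdown."""
    total = sum(e["chars"] for e in entries)
    per_dir: Counter = Counter()
    for e in entries:
        path = e["path"]
        top = path.split("/", 1)[0] if "/" in path else "."  # "." = repository root
        per_dir[top] += e["chars"]
    return total, per_dir


if __name__ == "__main__":
    entries = json.loads(MANIFEST_JSON)
    total, per_dir = summarize(entries)
    print(len(entries), total, round(total / 1024, 1))  # 20 179416 175.2
    for top, chars in per_dir.most_common():
        print(f"{top:>18} {chars}")
```

`ip_adapter/shared_models.py` and the root-level `shared_models.py` report the same character count (7,392). They also list identical symbols in the index. The total therefore counts that content twice.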

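The `sdxl-fine-tuning/data/train.json` template shows a JSON array of objects with `image_file`, `text`, and `extra_text` fields. The preview cuts off inside `comments`. Below is a minimal Python validation sketch. It assumes those three fields are required strings and treats `comments` as optional. `load_training_records` is a hypothetical helper, and `MyDataset` in `train.py` may read the file differently.

```python
import json
from typing import Dict, List

# Field names taken from the train.json preview; "comments" is truncated there,
# so it is not checked here.
REQUIRED_STRING_FIELDS = ("image_file", "text", "extra_text")


def load_training_records(raw: str) -> List[Dict]:
    """Parse a train.json-style string and check each record's required fields."""
    records = json.loads(raw)
    if not isinstance(records, list):
        raise ValueError("expected a top-level JSON array")
    for i, rec in enumerate(records):
        if not isinstance(rec, dict):
            raise ValueError(f"record {i}: expected a JSON object")
        for key in REQUIRED_STRING_FIELDS:
            if not isinstance(rec.get(key), str):
                raise ValueError(f"record {i}: missing or non-string field {key!r}")
    return records


# Placeholder values copied from the template in the repository.
EXAMPLE = '[{"image_file": "your image", "text": " ", "extra_text": "your caption"}]'

if __name__ == "__main__":
    for rec in load_training_records(EXAMPLE):
        print(rec["image_file"], "->", rec["extra_text"])
```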
About this extraction

This page contains the full source code of the muzishen/IMAGHarmony GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 20 files (175.2 KB), approximately 41.7k tokens, and a symbol index with 142 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, built by Nikandr Surkov.