Repository: muzishen/IMAGHarmony
Branch: main
Commit: 46895e751ff6
Files: 20
Total size: 175.2 KB
Directory structure:
gitextract_ymnget5p/
├── LICENSE
├── README.md
├── baseline.py
├── convert_bin.py
├── demo.py
├── ip_adapter/
│ ├── __init__.py
│ ├── attention_processor.py
│ ├── custom_pipelines.py
│ ├── ip_adapter.py
│ ├── ip_adapter_origin.py
│ ├── resampler.py
│ ├── shared_models.py
│ ├── test_resampler.py
│ └── utils.py
├── requirements.txt
├── run.sh
├── sdxl-fine-tuning/
│ └── data/
│ └── train.json
├── shared_models.py
├── test.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout
## 🗓️ Release
- [2025/5/30] 🔥 We released the [technical report](https://arxiv.org/pdf/2506.01949) of IMAGHarmony.
- [2025/5/28] 🔥 We released the training and inference code of IMAGHarmony.
- [2025/5/17] 🎉 We launched the [project page](https://revive234.github.io/IMAGHarmony.github.io/) of IMAGHarmony.
## 💡 Introduction
IMAGHarmony tackles the challenge of controllable image editing in multi-object scenes, where existing models struggle with aligning object quantity and spatial layout.
To this end, IMAGHarmony introduces a structure-aware framework for quantity-and-layout consistent image editing (QL-Edit), enabling precise control over object count, category, and arrangement.
We propose a harmony-aware (HA) module to jointly model object structure and semantics, and a preference-guided noise selection (PNS) strategy that stabilizes generation by selecting semantically aligned initial noise.
Our method is trained and evaluated on HarmonyBench, a newly curated benchmark with diverse editing scenarios.

## 🚀 HarmonyBench Dataset Demo

## 🚀 Examples


### Dual-Category Editing

## 🔧 Requirements
- Python>=3.8
- [PyTorch>=2.0.0](https://pytorch.org/)
- CUDA>=11.8
```
conda create --name IMAGHarmony python=3.8.18
conda activate IMAGHarmony
# Install requirements
pip install -r requirements.txt
```
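After installation, you can quickly check that the environment meets these requirements; in particular, the attention processors in `ip_adapter/attention_processor.py` rely on PyTorch 2.0's `scaled_dot_product_attention`. A minimal sanity check:
```python
import torch
import torch.nn.functional as F

# PyTorch >= 2.0 with CUDA is expected; SDPA is required by the IP-Adapter processors.
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("SDPA available:", hasattr(F, "scaled_dot_product_attention"))
```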
## 🌐 Download Models
You can download our models from [Huggingface](https://huggingface.co/kkkkggg/IMAGHarmony). The other component models can be downloaded from their original repositories, listed below.
- [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
- [stable-diffusion-XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
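For example, the weights can be fetched programmatically with `huggingface_hub` (the repository ids match the links above; the local directories are placeholders):
```python
from huggingface_hub import snapshot_download

# Download the IMAGHarmony weights and the SDXL base model to local folders (example paths).
imag_dir = snapshot_download("kkkkggg/IMAGHarmony", local_dir="./checkpoints/IMAGHarmony")
sdxl_dir = snapshot_download("stabilityai/stable-diffusion-xl-base-1.0", local_dir="./checkpoints/sdxl-base-1.0")
print(imag_dir, sdxl_dir)
```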
## 🚀 How to train
```
# Please download the HarmonyBench data first or prepare your own images,
# and modify the paths in run.sh
# Write the captions for your images in your train.json file
# start training
sh run.sh
```
## 🚀 How to test
```
# Please convert your checkpoints
python convert_bin.py
# Please fill in your paths in test.py,
# then run
python test.py
```
Alternatively, you can test it with the Gradio demo
```
python demo.py
```
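If you prefer to script the generation directly, the following is a minimal sketch based on the call made in `demo.py` (the image path is a placeholder; see `demo.py` for how `ip_model` and `number_class_crossattention` are constructed):
```python
from PIL import Image

# Assumes `ip_model` (IPAdapterXL) and `number_class_crossattention` are built as in demo.py.
image = Image.open("./examples/four_dogs.jpg").resize((512, 512))  # placeholder path
images = ip_model.generate(
    pil_image=image,
    prompt="four cats",          # target prompt
    extra_text="four dogs",      # caption describing the reference image
    negative_prompt="lowres, low quality, worst quality, blurry",
    scale=1.0,
    guidance_scale=10.0,
    num_samples=1,
    num_inference_steps=30,
    seed=8,
    number_class_crossattention=number_class_crossattention,
)
images[0].save("output.png")
```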
## Acknowledgement
We would like to thank the contributors to the [InstantStyle](https://github.com/instantX-research/InstantStyle) and [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) repositories for their open research and exploration.
The IMAGHarmony code is available for both academic and commercial use. Users are permitted to generate images using this tool, provided they comply with local laws and exercise responsible use. The developers disclaim all liability for any misuse or unlawful activity by users.
## Citation
If you find IMAGHarmony useful for your research and applications, please cite using this BibTeX:
```bibtex
@misc{shen2025imagharmonycontrollableimageediting,
title={IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout},
author={Fei Shen and Yutong Gao and Jian Yu and Xiaoyu Du and Jinhui Tang},
year={2025},
eprint={2506.01949},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2506.01949},
}
```
## 🕒 TODO List
- [x] Paper
- [x] Train Code
- [x] Inference Code
- [ ] HarmonyBench Dataset
- [ ] Model Weights
## 👉 **Our other projects:**
- [IMAGEdit](https://github.com/XWH-A/IMAGEdit): Training-Free Controllable Video Editing with Consistent Object Layout. [Controllable multi-object video editing]
- [IMAGDressing](https://github.com/muzishen/IMAGDressing): Controllable dressing generation. [Controllable dressing generation]
- [IMAGGarment](https://github.com/muzishen/IMAGGarment): Fine-grained controllable garment generation. [Controllable garment generation]
- [IMAGHarmony](https://github.com/muzishen/IMAGHarmony): Controllable image editing with consistent object layout. [Controllable multi-object image editing]
- [IMAGPose](https://github.com/muzishen/IMAGPose): Pose-guided person generation with high fidelity. [Controllable multi-mode person generation]
- [RCDMs](https://github.com/muzishen/RCDMs): Rich-contextual conditional diffusion for story visualization. [Controllable story generation]
- [PCDMs](https://github.com/tencent-ailab/PCDMs): Progressive conditional diffusion for pose-guided image synthesis. [Controllable person generation]
- [V-Express](https://github.com/tencent-ailab/V-Express/): Explores strong and weak conditional relationships for portrait video generation. [Controllable digital human generation]
- [FaceShot](https://github.com/open-mmlab/FaceShot/): Talkingface plugin for any character. [Controllable anime digital human generation]
- [CharacterShot](https://github.com/Jeoyal/CharacterShot): Controllable and consistent 4D character animation framework. [Controllable 4D character generation]
- [StyleTailor](https://github.com/mahb-THU/StyleTailor): An agent for personalized fashion styling. [Personalized fashion agent]
- [SignVip](https://github.com/umnooob/signvip/): Controllable sign language video generation. [Controllable sign language generation]
## 📨 Contact
If you have any questions, please feel free to contact us at shenfei140721@126.com or yutonggaokkk@njust.edu.cn.
================================================
FILE: baseline.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class QFormer(nn.Module):
def __init__(self,
hidden_dim=768,
num_queries=16,
num_layers=6,
num_heads=12,
image_feat_dim=320,
text_feat_dim=2048,
add_modality_embedding=True):
super(QFormer, self).__init__()
self.hidden_dim = hidden_dim
self.num_queries = num_queries
self.add_modality_embedding = add_modality_embedding
# Learnable query tokens: [1, num_queries, D]
self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim))
# Modality type embeddings (0=image, 1=text)
if add_modality_embedding:
self.modality_embed = nn.Embedding(2, hidden_dim)
self.image_proj = nn.Linear(image_feat_dim, hidden_dim)
self.text_proj = nn.Linear(text_feat_dim, hidden_dim)
# Transformer encoder: Q interacts with K/V (image + text)
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
def forward(self, image_feat, text_feat):
"""
image_feat: [B, T_img, D]
text_feat: [B, T_txt, D]
"""
B = image_feat.size(0)
print(image_feat.shape,text_feat.shape)
image_feat = self.image_proj(image_feat)
text_feat = self.text_proj(text_feat)
# Concatenate image + text features as K/V
kv = torch.cat([image_feat, text_feat], dim=1) # [B, T_img + T_txt, D]
# Add modality type embedding
if self.add_modality_embedding:
T_img = image_feat.size(1)
T_txt = text_feat.size(1)
modality_ids = torch.cat([
torch.zeros(T_img, dtype=torch.long),
torch.ones(T_txt, dtype=torch.long)
], dim=0).to(image_feat.device) # [T_img + T_txt]
modality_embed = self.modality_embed(modality_ids) # [T_img + T_txt, D]
kv = kv + modality_embed.unsqueeze(0) # broadcast to [B, T, D]
# Expand learnable query tokens to batch size
queries = self.query_tokens.expand(B, -1, -1) # [B, N_query, D]
# Q-Former: let queries attend to K/V
# Transformer requires concat(Q, K/V)
input_seq = torch.cat([queries, kv], dim=1) # [B, N_query + T, D]
output = self.transformer(input_seq) # [B, N_query + T, D]
# Return only the updated query tokens
return output[:, :self.num_queries, :] # [B, N_query, D]
class MLP(nn.Module):
def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header =16):
super().__init__()
self.image_proj = nn.Linear(image_dim, fused_dim)
self.text_proj = nn.Linear(text_dim, fused_dim)
self.num_header = num_header
self.mlp = nn.Sequential(
nn.Linear(2 * fused_dim, fused_dim),
nn.ReLU(),
nn.Linear(fused_dim, fused_dim),
nn.ReLU(),
nn.Linear(fused_dim, fused_dim * num_header)
)
self.fused_dim=fused_dim
def forward(self, image_feat, text_feat):
"""
image_feat: [B, T_img, image_dim]
text_feat: [B, T_txt, text_dim]
"""
image_repr = image_feat.mean(dim=1) # [B, image_dim]
text_repr = text_feat.mean(dim=1) # [B, text_dim]
image_proj = self.image_proj(image_repr) # [B, fused_dim]
text_proj = self.text_proj(text_repr) # [B, fused_dim]
fused = torch.cat([image_proj, text_proj], dim=-1) # [B, 2*fused_dim]
output = self.mlp(fused).reshape(-1, self.num_header, self.fused_dim) # [B, num_header, fused_dim]
return output
class GatedAttentionFusion(nn.Module):
def __init__(self, input_dim=768, hidden_dim=512):
super().__init__()
self.gate_mlp = nn.Sequential(
nn.Linear(2 * input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid() # alpha ∈ [0, 1]
)
def forward(self, img_feat, txt_feat):
"""
img_feat: [B, D]
txt_feat: [B, D]
"""
fused_input = torch.cat([img_feat, txt_feat], dim=-1) # [B, 2D]
alpha = self.gate_mlp(fused_input) # [B, 1]
fused = alpha * img_feat + (1 - alpha) * txt_feat # [B, D]
return fused
class AttentionFusionWrapper(nn.Module):
def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header=16):
super().__init__()
self.img_proj = nn.Linear(image_dim, fused_dim)
self.txt_proj = nn.Linear(text_dim, fused_dim)
self.fusion = GatedAttentionFusion(input_dim=fused_dim)
self.num_header = num_header
self.fused_dim = fused_dim
self.dim_transfer = nn.Linear(fused_dim,fused_dim*self.num_header)
def forward(self, image_feat, text_feat):
"""
image_feat: [B, T_img, 320]
text_feat: [B, T_txt, 2048]
"""
# Mean pooling
img_global = image_feat.mean(dim=1) # [B, 320]
txt_global = text_feat.mean(dim=1) # [B, 2048]
# Linear projection
img_proj = self.img_proj(img_global) # [B, 768]
txt_proj = self.txt_proj(txt_global) # [B, 768]
# Gated fusion
fused = self.fusion(img_proj, txt_proj) # [B, 768]
fused = self.dim_transfer(fused).reshape(-1,self.num_header,self.fused_dim)
return fused
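# Minimal smoke test with dummy tensors (illustrative; shapes follow the docstrings
# above: image features are [B, T_img, 320], text features are [B, T_txt, 2048],
# and each baseline returns 16 fused tokens of dimension 768).
if __name__ == "__main__":
    B, T_img, T_txt = 2, 16, 77
    image_feat = torch.randn(B, T_img, 320)
    text_feat = torch.randn(B, T_txt, 2048)
    for fusion in (QFormer(), MLP(), AttentionFusionWrapper()):
        out = fusion(image_feat, text_feat)
        print(type(fusion).__name__, tuple(out.shape))  # expected: (2, 16, 768)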
================================================
FILE: convert_bin.py
================================================
import os
import torch
from collections import OrderedDict
def convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path):
if not os.path.exists(pytorch_model_path):
print(f" [Warning] Source file not found, skipping: {pytorch_model_path}")
return False
print(f" Converting: {pytorch_model_path}")
try:
sd = torch.load(pytorch_model_path, map_location="cpu")
image_proj_sd = OrderedDict()
ip_sd = OrderedDict()
composed_sd = OrderedDict()
for k in sd:
if k.startswith("image_proj_model."):
image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
elif k.startswith("adapter_modules."):
ip_sd[k.replace("adapter_modules.", "")] = sd[k]
elif k.startswith("composed_modules."):
composed_sd[k.replace("composed_modules.", "")] = sd[k]
if not image_proj_sd and not ip_sd and not composed_sd:
print(f" [Warning] No expected keys (image_proj_model, adapter_modules, composed_modules) found in {pytorch_model_path}. Skipping save.")
return False
final_sd = {
"image_proj": image_proj_sd,
"ip_adapter": ip_sd,
'composed_adapter': composed_sd
}
torch.save(final_sd, output_ip_adapter_path)
print(f" Successfully saved: {output_ip_adapter_path}")
return True
except Exception as e:
print(f" [Error] Failed to convert {pytorch_model_path}: {e}")
return False
if __name__ == "__main__":
base_log_dir = "your fine_tuned model path"
total_converted = 0
total_skipped = 0
total_errors = 0
print(f"Starting conversion process in base directory: {base_log_dir}")
for training_run_dir_name in os.listdir(base_log_dir):
training_run_dir_path = os.path.join(base_log_dir, training_run_dir_name)
# Check if it's actually a directory
if os.path.isdir(training_run_dir_path):
print(f"\nProcessing training run: {training_run_dir_name}")
# Iterate through items inside the training run directory
for checkpoint_dir_name in os.listdir(training_run_dir_path):
if checkpoint_dir_name.startswith("checkpoint-") and \
os.path.isdir(os.path.join(training_run_dir_path, checkpoint_dir_name)):
checkpoint_dir_path = os.path.join(training_run_dir_path, checkpoint_dir_name)
print(f"- Found checkpoint directory: {checkpoint_dir_name}")
pytorch_model_path = os.path.join(checkpoint_dir_path, "pytorch_model.bin")
output_ip_adapter_path = os.path.join(checkpoint_dir_path, "ip_adapter.bin")
if os.path.exists(output_ip_adapter_path):
print(f" Output file already exists, skipping: {output_ip_adapter_path}")
total_skipped += 1
continue
success = convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path)
if success:
total_converted += 1
else:
if not os.path.exists(pytorch_model_path):
total_skipped += 1
else:
total_errors +=1
print("\n--- Conversion Summary ---")
print(f"Total checkpoints converted: {total_converted}")
print(f"Total checkpoints skipped (e.g., source missing): {total_skipped}")
print(f"Total errors during conversion: {total_errors}")
================================================
FILE: demo.py
================================================
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from ip_adapter import IPAdapterXL
from huggingface_hub import hf_hub_download
import os
import time
try:
from tutorial_train_sdxl_ori import ComposedAttention
except ImportError:
print("Error: Could not import ComposedAttention.")
print("Please ensure 'tutorial_train_sdxl_ori.py' is in the same directory as this script.")
exit()
print("Loading models, please wait...")
CKPT_INTER_DIM = 2560
CKPT_CROSS_HEADS = 8
CKPT_RESHAPE_BLOCKS = 8
CKPT_CROSS_VALUE_DIM = 64
BASE_MODEL_PATH = "/aigc_data_hdd/checkpoints/stable-diffusion-xl-base-1.0"
IMAGE_ENCODER_PATH = os.path.join(BASE_MODEL_PATH, "image_encoder")
if not os.path.exists(BASE_MODEL_PATH) or not os.path.exists(IMAGE_ENCODER_PATH):
print(f"Error: Model or image encoder path not found: {BASE_MODEL_PATH}")
exit()
IP_ADAPTER_REPO_ID = "kkkkggg/IMAGHarmony"
IP_ADAPTER_FILENAME = "IMAGHarmony_variant1.bin"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Downloading weights file {IP_ADAPTER_FILENAME} from repository {IP_ADAPTER_REPO_ID}...")
try:
fine_tuned_ckpt_path = hf_hub_download(
repo_id=IP_ADAPTER_REPO_ID,
filename=IP_ADAPTER_FILENAME,
)
print(f"Weights file downloaded to: {fine_tuned_ckpt_path}")
except Exception as e:
print(f"Failed to download weights: {e}")
exit()
print(f"Loading Stable Diffusion XL pipeline from local path: {BASE_MODEL_PATH}")
pipe = StableDiffusionXLPipeline.from_pretrained(
BASE_MODEL_PATH,
torch_dtype=torch.float16,
add_watermarker=False,
).to(DEVICE)
pipe.enable_vae_tiling()
print("Instantiating custom ComposedAttention module...")
number_class_crossattention = ComposedAttention(
image_hidden_size=1280,
text_context_dim=2048,
inter_dim=CKPT_INTER_DIM,
cross_heads=CKPT_CROSS_HEADS,
reshape_blocks=CKPT_RESHAPE_BLOCKS,
cross_value_dim=CKPT_CROSS_VALUE_DIM,
scale=1.0
).to(DEVICE).half()
print("Extracting and loading ComposedAttention weights from the main checkpoint...")
try:
state_dict = torch.load(fine_tuned_ckpt_path, map_location="cpu")
composed_attention_weights = state_dict["composed_adapter"]
number_class_crossattention.load_state_dict(composed_attention_weights)
print("Successfully loaded fine-tuned ComposedAttention weights.")
except KeyError:
print(f"Error: Key 'composed_adapter' not found in weights file {IP_ADAPTER_FILENAME}.")
exit()
except Exception as e:
print(f"An unknown error occurred while loading ComposedAttention weights: {e}")
exit()
print("Initializing IP-Adapter...")
ip_model = IPAdapterXL(
pipe,
IMAGE_ENCODER_PATH,
fine_tuned_ckpt_path,
DEVICE,
target_blocks=["down_blocks.1.attentions.1"],
num_tokens=4,
inference=True,
number_class_crossattention=number_class_crossattention
)
print("Models loaded. Gradio application is ready!")
def generate_image(uploaded_image: Image.Image, local_path: str, save_path: str,
prompt: str, extra_text: str, negative_prompt: str,
guidance_scale: float, num_inference_steps: int, seed: int, progress=gr.Progress()):
pil_image = None
if uploaded_image is not None:
pil_image = uploaded_image
elif local_path and local_path.strip():
try:
pil_image = Image.open(local_path.strip())
except FileNotFoundError:
raise gr.Error(f"File not found. Please check the path: {local_path.strip()}")
except Exception as e:
raise gr.Error(f"Cannot open image file. Error: {e}")
else:
raise gr.Error("Please upload a reference image or provide a valid local file path!")
input_image = pil_image.resize((512, 512))
progress(0, desc="Generating image...")
images = ip_model.generate(
pil_image=input_image,
prompt=prompt,
negative_prompt=negative_prompt,
scale=1.0,
guidance_scale=guidance_scale,
num_samples=1,
num_inference_steps=int(num_inference_steps),
seed=int(seed),
extra_text=extra_text,
number_class_crossattention=number_class_crossattention
)
generated_image = images[0]
progress(1, desc="Generation complete!")
if save_path and save_path.strip():
try:
save_dir = save_path.strip()
os.makedirs(save_dir, exist_ok=True)
timestamp = time.strftime("%Y%m%d-%H%M%S")
filename = f"output_{timestamp}_seed{seed}.png"
full_path = os.path.join(save_dir, filename)
generated_image.save(full_path)
gr.Info(f"Image successfully saved to: {full_path}")
except Exception as e:
gr.Warning(f"Could not save the image! Error: {e}")
print(f"Error saving image: {e}")
return generated_image
with gr.Blocks() as demo:
gr.Markdown("# IMAGHarmony: Image Generation Demo")
gr.Markdown(
"**Upload a reference image from your computer, or enter the full local path in the text box below.**\n"
"Then, enter a **Target Prompt** and a **Reference Content** description to generate a new image."
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="pil", label="Upload Your Reference Image")
local_path_input = gr.Textbox(
label="Or Enter Local Image Path",
placeholder="/home/user/images/photo.jpg",
info="If an image is uploaded, it will be prioritized over the path."
)
prompt = gr.Textbox(label="Target Prompt", value="four cats")
extra_text = gr.Textbox(
label="Reference Content",
info="Enter text that describes the reference image, typically the caption used during training.",
value="four dogs"
)
neg_prompt = gr.Textbox(
label="Negative Prompt",
value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry"
)
save_path_input = gr.Textbox(
label="Save to Local Directory (Optional)",
placeholder="/your/path",
info="If left empty, the image will not be saved."
)
run_button = gr.Button("Generate Image", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(type="pil", label="Generated Image")
with gr.Accordion("Advanced Settings", open=False):
guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=10.0, label="Guidance Scale")
num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
seed = gr.Slider(minimum=0, maximum=999999, step=1, value=8, label="Seed", randomize=True)
run_button.click(
fn=generate_image,
inputs=[input_image, local_path_input, save_path_input, prompt, extra_text, neg_prompt, guidance_scale, num_inference_steps, seed],
outputs=output_image
)
demo.launch(share=True)
================================================
FILE: ip_adapter/__init__.py
================================================
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
__all__ = [
"IPAdapter",
"IPAdapterPlus",
"IPAdapterPlusXL",
"IPAdapterXL",
"IPAdapterFull",
]
================================================
FILE: ip_adapter/attention_processor.py
================================================
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Cross_Attention(nn.Module):
def __init__(self,
query_dim, # Input dimension for Q projection
context_dim, # Input dimension for K/V projection
heads=8,
value_dim=None, # Dimension of V after projection (defaults to head_dim)
out_dim=None): # Output dimension
super().__init__()
self.query_dim = query_dim
self.heads = heads
self.head_dim = self.query_dim // self.heads
self.scale = math.sqrt(self.head_dim)
self.value_dim = value_dim if value_dim is not None else self.head_dim
self.out_dim = out_dim if out_dim is not None else heads * self.value_dim
# Linear projection layers
self.to_q = nn.Linear(query_dim, self.heads * self.head_dim)
self.to_k = nn.Linear(context_dim, self.heads * self.head_dim)
self.to_v = nn.Linear(context_dim, self.heads * self.value_dim)
# Optional output projection
self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)
def forward(self, query_input, context_input):
B = query_input.size(0)
# Project Q, K, V
q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, Q_len, head_dim]
k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, K_len, head_dim]
v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2) # [B, heads, K_len, v_dim]
# Attention scores (weights)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [B, heads, Q_len, K_len]
attn_probs = F.softmax(attn_scores, dim=-1)
# Weighted sum
attn_output = torch.matmul(attn_probs, v) # [B, heads, Q_len, v_dim]
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim) # [B, Q_len, heads * v_dim]
# Output projection
output = self.out_proj(attn_output) # [B, Q_len, out_dim]
return output
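# Illustrative shapes (assumed values): with query_dim=1280, context_dim=2048,
# heads=8 and value_dim=64 (so out_dim defaults to heads * value_dim = 512), a
# query of shape [B, N, 1280] attended over a context of shape [B, T, 2048]
# produces an output of shape [B, N, 512].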
class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapter.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.skip = skip
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
if not self.skip:
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.skip = skip
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if not self.skip:
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
#print(self.attn_map.shape)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
## for controlnet
class CNAttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __init__(self, num_tokens=4):
self.num_tokens = num_tokens
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CNAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self, num_tokens=4):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.num_tokens = num_tokens
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
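# Hedged sketch of how these processors are typically wired into an SDXL UNet
# (a simplified version of the standard IP-Adapter setup; see ip_adapter/ip_adapter.py
# for the actual implementation used in this repo):
#
#   attn_procs = {}
#   for name in unet.attn_processors.keys():
#       if name.endswith("attn1.processor"):      # self-attention -> default processor
#           attn_procs[name] = AttnProcessor2_0()
#       else:                                     # cross-attention -> IP-Adapter processor
#           hidden_size = ...                     # derived from the UNet block the layer sits in
#           attn_procs[name] = IPAttnProcessor2_0(
#               hidden_size=hidden_size,
#               cross_attention_dim=unet.config.cross_attention_dim,
#               num_tokens=4,
#           )
#   unet.set_attn_processor(attn_procs)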
================================================
FILE: ip_adapter/custom_pipelines.py
================================================
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from .utils import is_torch2_available
if is_torch2_available():
from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
from .attention_processor import IPAttnProcessor
class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
def set_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
@torch.no_grad()
def __call__( # noqa: C901
self,
prompt: Optional[Union[str, List[str]]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
control_guidance_start: float = 0.0,
control_guidance_end: float = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.7):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a target image resolution. It should be the same
as `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
control_guidance_start (`float`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
Examples:
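A minimal usage sketch (assuming `pipe` is an instance of this pipeline already loaded via
`from_pretrained` and moved to a CUDA device; the prompt and output path are placeholders):
>>> import torch
>>> generator = torch.Generator("cuda").manual_seed(42)
>>> image = pipe(
...     prompt="a photo of five cats sitting on a sofa",
...     num_inference_steps=30,
...     guidance_scale=5.0,
...     generator=generator,
... ).images[0]
>>> image.save("output.png")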
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# get init conditioning scale
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
conditioning_scale = attn_processor.scale
break
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end):
self.set_scale(0.0)
else:
self.set_scale(conditioning_scale)
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
else:
image = latents
if output_type != "latent":
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
================================================
FILE: ip_adapter/ip_adapter.py
================================================
import os
from typing import List
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from tutorial_train_sdxl_ori import HarmonyAttention
from .utils import is_torch2_available, get_generator
if is_torch2_available():
from .attention_processor import (
AttnProcessor2_0 as AttnProcessor,
)
from .attention_processor import (
CNAttnProcessor2_0 as CNAttnProcessor,
)
from .attention_processor import (
IPAttnProcessor2_0 as IPAttnProcessor,
)
else:
from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler
class ImageProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
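# Shape sketch (illustrative comment, assuming the default constructor values
# cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
# a pooled CLIP image embedding of shape [B, 1024] is projected to [B, 4 * 1024] and
# reshaped to [B, 4, 1024], i.e. four extra context tokens that are later concatenated
# with the text tokens along the sequence dimension.
#   proj = ImageProjModel()
#   tokens = proj(torch.randn(2, 1024))   # tokens.shape == torch.Size([2, 4, 1024])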
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim)
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class IPAdapter:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=None, number_class_crossattention=None):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
# self.target_blocks = target_blocks or ["down_blocks.2.attentions.1"]
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.number_class_crossattention=number_class_crossattention.to(self.device, dtype=torch.float16)
# image proj model
self.image_proj_model = self.init_proj()
self.load_ip_adapter()
def init_proj(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
if 'down_blocks.2.attentions.1' in name:
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
num_tokens=self.num_tokens, skip=False).to(self.device, dtype=torch.float16)
else:
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
num_tokens=self.num_tokens, skip=True).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if isinstance(self.pipe.controlnet, MultiControlNetModel):
for controlnet in self.pipe.controlnet.nets:
controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
else:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"image_proj_model": {}, "ip_adapter": {}, 'composed_modules': {} }
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
# if 'unet.up_blocks' not in key and 'unet.mid_blocks' not in key:
# print(key)
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("adapter_modules."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
elif key.startswith("composed_modules."):
state_dict["composed_modules"][key.replace("composed_modules.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
print(state_dict.keys())
self.image_proj_model.load_state_dict(state_dict["image_proj"])
self.number_class_crossattention.load_state_dict(state_dict["composed_modules"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None, extra_prompt_embeds=None):
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
# add composer
if extra_prompt_embeds is not None:
extra_prompt_embeds = extra_prompt_embeds.to(self.device,torch.float16)
output= self.number_class_crossattention(extra_prompt_embeds, clip_image_embeds)
clip_image_embeds = clip_image_embeds + output
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
def generate(
self,
pil_image=None,
clip_image_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
if pil_image is not None:
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
else:
num_prompts = clip_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_image, clip_image_embeds=clip_image_embeds
)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
class IPAdapterXL(IPAdapter):
"""SDXL"""
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device,
num_tokens=4, target_blocks=None, inference=False, number_class_crossattention=None):
self.inference = inference
super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device,
num_tokens=num_tokens, target_blocks=target_blocks, number_class_crossattention=number_class_crossattention)
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
extra_text=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
if extra_text is not None:
with torch.inference_mode():
(
extra_prompt_embeds,
extra_negative_prompt_embeds,
extra_pooled_prompt_embeds,
extra_negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
extra_text,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, extra_prompt_embeds = extra_prompt_embeds)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
self.generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=self.generator,
**kwargs,
).images
return images
class IPAdapterPlus(IPAdapter):
def init_proj(self):
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter with full features"""
def init_proj(self):
image_proj_model = MLPProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.hidden_size,
).to(self.device, dtype=torch.float16)
return image_proj_model
class IPAdapterPlusXL(IPAdapter):
"""SDXL"""
def init_proj(self):
image_proj_model = Resampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
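# Usage sketch (illustrative comment; the paths are placeholders and the choice of a plain
# StableDiffusionXLPipeline is an assumption -- any SDXL pipeline object accepted by
# IPAdapterXL should work):
#   import torch
#   from PIL import Image
#   from diffusers import StableDiffusionXLPipeline
#   from shared_models import Composed_Attention
#
#   pipe = StableDiffusionXLPipeline.from_pretrained("your sdxl path", torch_dtype=torch.float16)
#   composer = Composed_Attention(hidden_size=1280, cross_attention_dim=64, scale=1.0)
#   ip_model = IPAdapterXL(pipe, "your image encoder path", "your ip-adapter ckpt", "cuda",
#                          num_tokens=4, number_class_crossattention=composer)
#   images = ip_model.generate(pil_image=Image.open("five_cats.png"),
#                              prompt="five dogs on a sofa", extra_text="Five dogs",
#                              num_samples=1, num_inference_steps=30, seed=42)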
================================================
FILE: ip_adapter/ip_adapter_origin.py
================================================
import os
from typing import List
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .utils import is_torch2_available, get_generator
if is_torch2_available():
from .attention_processor import (
AttnProcessor2_0 as AttnProcessor,
)
from .attention_processor import (
CNAttnProcessor2_0 as CNAttnProcessor,
)
from .attention_processor import (
IPAttnProcessor2_0 as IPAttnProcessor,
)
else:
from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler
class ImageProjModel(torch.nn.Module):
"""投影模型 - 将CLIP图像特征转换为适合UNet交叉注意力的格式"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim # UNet cross-attention dimension
self.clip_extra_context_tokens = clip_extra_context_tokens # Number of extra context tokens
# Linear projection layer, converts CLIP embeddings into multiple extra context tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim) # Normalization layer
def forward(self, image_embeds):
# Project CLIP image embeddings into multiple context tokens
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class MLPProjModel(torch.nn.Module):
"""使用多层感知器的投影模型 - 用于IPAdapterFull变体"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
# MLP with two linear layers, GELU activation, and layer normalization
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim)
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class IPAdapter:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
# image proj model
self.image_proj_model = self.init_proj()
self.load_ip_adapter()
def init_proj(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if isinstance(self.pipe.controlnet, MultiControlNetModel):
for controlnet in self.pipe.controlnet.nets:
controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
else:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
def generate(
self,
pil_image=None,
clip_image_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
if pil_image is not None:
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
else:
num_prompts = clip_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_image, clip_image_embeds=clip_image_embeds
)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
class IPAdapterXL(IPAdapter):
"""SDXL"""
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
"""SDXL专用的生成方法,考虑了SDXL的pooled嵌入"""
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
self.generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=self.generator,
**kwargs,
).images
return images
class IPAdapterPlus(IPAdapter):
"""使用细粒度特征的IP-Adapter增强版本"""
def init_proj(self):
"""使用Resampler替代简单的线性投影"""
# Resampler是一种更强大的特征重采样模型,能更好地捕获图像中的细节
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
"""使用CLIP模型的倒数第二层隐藏状态作为更丰富的特征"""
# 使用CLIP的内部特征而非最终投影,提供更细粒度的图像理解
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter with full features"""
def init_proj(self):
image_proj_model = MLPProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.hidden_size,
).to(self.device, dtype=torch.float16)
return image_proj_model
class IPAdapterPlusXL(IPAdapter):
"""SDXL"""
def init_proj(self):
image_proj_model = Resampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
================================================
FILE: ip_adapter/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> contiguous (bs, n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
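# Shape sketch (illustrative comment, assuming the default constructor values dim=1024,
# num_queries=8, embedding_dim=768, output_dim=1024): the Resampler maps a sequence of CLIP
# hidden states to a fixed number of query tokens.
#   resampler = Resampler()
#   tokens = resampler(torch.randn(2, 257, 768))   # tokens.shape == torch.Size([2, 8, 1024])
# (plus num_latents_mean_pooled extra tokens when that option is enabled, as exercised in
# ip_adapter/test_resampler.py)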
================================================
FILE: ip_adapter/shared_models.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Cross_Attention(nn.Module):
def __init__(self,
query_dim, # Q projection input dimension
context_dim, # K/V projection input dimension
heads=8,
head_dim=64,
value_dim=None, # V dimension after dimensionality reduction (defaults to head_dim)
out_dim=None): # Output dimension
super().__init__()
self.heads = heads
self.head_dim = head_dim
self.scale = math.sqrt(head_dim)
self.value_dim = value_dim if value_dim is not None else head_dim
self.out_dim = out_dim if out_dim is not None else heads * self.value_dim
# Linear projection layers
self.to_q = nn.Linear(query_dim, heads * head_dim)
self.to_k = nn.Linear(context_dim, heads * head_dim)
self.to_v = nn.Linear(context_dim, heads * self.value_dim)
# Optional output projection
self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)
def forward(self, query_input, context_input):
"""
query_input: [B, Q_len, query_dim]
context_input: [B, K_len, context_dim]
"""
B = query_input.size(0)
# Project Q, K, V
q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, Q_len, head_dim]
k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, K_len, head_dim]
v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2) # [B, heads, K_len, v_dim]
# Attention weights
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [B, heads, Q_len, K_len]
attn_probs = F.softmax(attn_scores, dim=-1)
# Weighted sum
attn_output = torch.matmul(attn_probs, v) # [B, heads, Q_len, v_dim]
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim) # [B, Q_len, heads * v_dim]
# Output projection
output = self.out_proj(attn_output) # [B, Q_len, out_dim]
return output
class ImageProjModel(torch.nn.Module):
"""Projection model - converts CLIP image features into a format suitable for UNet cross-attention"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim # UNet cross-attention dimension
self.clip_extra_context_tokens = clip_extra_context_tokens # Number of extra context tokens
# Linear projection layer, converts CLIP embeddings into multiple extra context tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim) # Normalization layer
def forward(self, image_embeds):
# Project CLIP image embeddings into multiple context tokens
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class Composed_Attention(torch.nn.Module):#Number_Class_crossAttention
def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):
super().__init__()
# Cross Attention layer
self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)
self.scale=scale
# self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)
# Image projection from 1280 to 2560
self.fc1=nn.Linear(hidden_size, hidden_size*2)
# Layer Normalization layer
self.ln = nn.LayerNorm(hidden_size)
# FC layer 1280 to 1280
self.fc2 = nn.Linear(hidden_size, hidden_size)
def forward(self, text_embeds,image_embeds):
image_embeds=self.fc1(image_embeds) #[1, 2560]
image_embeds=image_embeds.reshape(1,4,640)
output = self.cross_attention(image_embeds, text_embeds) #[1,4,320]
output=output.reshape(1,1280)
# Normalize the output using Layer Normalization
output=self.ln(output)
# FC layer [1, 1280] -> [1, 1280]
output=self.fc2(output)
return output*self.scale
def load_from_checkpoint(self, ckpt_path: str):
from safetensors.torch import load_file
from collections import OrderedDict
# Load weights file
weights = load_file(ckpt_path)
# Initialize two dictionaries to store weights for different modules separately
image_proj_weights = OrderedDict()
attn_weights = OrderedDict()
# Separate weights into different modules
for k, v in weights.items():
# Process image_proj_model weights (match two possible key name prefixes)
if k.startswith("image_proj_model.") or k.startswith("image_proj."):
new_key = k.replace("image_proj_model.", "").replace("image_proj.", "")
if hasattr(self, "image_proj_model") and hasattr(self.image_proj_model, new_key.split('.')[0]):
image_proj_weights[new_key] = v
# Process target attention layer weights (match two possible key name formats)
elif "down_blocks.2.attentions.1" in k:
# Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1
new_key = k.replace("composed_modules.", "").replace("ip_adapter.", "")
if hasattr(self, new_key.split('.')[0]):
attn_weights[new_key] = v
# Load image_proj_model weights (strict mode)
if image_proj_weights:
self.image_proj_model.load_state_dict(image_proj_weights, strict=True)
print(f"Loaded image_proj_model weights: {len(image_proj_weights)} params")
# Load attention layer weights (non-strict mode)
if attn_weights:
# Create temporary ModuleDict to load weights
temp_dict = {k: v for k, v in self.named_modules()
if "down_blocks.2.attentions.1" in k}
temp_model = torch.nn.ModuleDict(temp_dict)
missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False)
if missing:
print(f"Missing keys in attention blocks: {missing}")
if unexpected:
print(f"Unexpected keys in attention blocks: {unexpected}")
print(f"Loaded attention weights: {len(attn_weights)} params")
print(f"Successfully loaded target modules from {ckpt_path}")
return self
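# Shape sketch (illustrative comment, assuming batch size 1 to match the hard-coded reshape
# above): Composed_Attention fuses a pooled CLIP image embedding with SDXL text-token
# embeddings; its output is added back onto the image embedding in IPAdapter.get_image_embeds.
#   composer = Composed_Attention(hidden_size=1280, cross_attention_dim=64, scale=1.0)
#   text_embeds = torch.randn(1, 77, 2048)     # SDXL prompt embeddings (both text encoders)
#   image_embeds = torch.randn(1, 1280)        # pooled CLIP image embedding
#   out = composer(text_embeds, image_embeds)  # out.shape == torch.Size([1, 1280])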
================================================
FILE: ip_adapter/test_resampler.py
================================================
import torch
from resampler import Resampler
from transformers import CLIPVisionModel
BATCH_SIZE = 2
OUTPUT_DIM = 1280
NUM_QUERIES = 8
NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
def main():
image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
embedding_dim = image_encoder.config.hidden_size
print(f"image_encoder hidden size: ", embedding_dim)
image_proj_model = Resampler(
dim=1024,
depth=2,
dim_head=64,
heads=16,
num_queries=NUM_QUERIES,
embedding_dim=embedding_dim,
output_dim=OUTPUT_DIM,
ff_mult=2,
max_seq_len=257,
apply_pos_emb=APPLY_POS_EMB,
num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
)
dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
with torch.no_grad():
image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
print("image_embds shape: ", image_embeds.shape)
with torch.no_grad():
ip_tokens = image_proj_model(image_embeds)
print("ip_tokens shape:", ip_tokens.shape)
assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
if __name__ == "__main__":
main()
================================================
FILE: ip_adapter/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
attn_maps = {}
def hook_fn(name):
def forward_hook(module, input, output):
if hasattr(module.processor, "attn_map"):
attn_maps[name] = module.processor.attn_map
del module.processor.attn_map
return forward_hook
def register_cross_attention_hook(unet):
for name, module in unet.named_modules():
if name.split('.')[-1].startswith('attn2'):
module.register_forward_hook(hook_fn(name))
return unet
def upscale(attn_map, target_size):
attn_map = torch.mean(attn_map, dim=0)
attn_map = attn_map.permute(1,0)
temp_size = None
for i in range(0,5):
scale = 2 ** i
if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
break
assert temp_size is not None, "temp_size cannot is None"
attn_map = attn_map.view(attn_map.shape[0], *temp_size)
attn_map = F.interpolate(
attn_map.unsqueeze(0).to(dtype=torch.float32),
size=target_size,
mode='bilinear',
align_corners=False
)[0]
attn_map = torch.softmax(attn_map, dim=0)
return attn_map
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
idx = 0 if instance_or_negative else 1
net_attn_maps = []
for name, attn_map in attn_maps.items():
attn_map = attn_map.cpu() if detach else attn_map
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
attn_map = upscale(attn_map, image_size)
net_attn_maps.append(attn_map)
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
return net_attn_maps
def attnmaps2images(net_attn_maps):
#total_attn_scores = 0
images = []
for attn_map in net_attn_maps:
attn_map = attn_map.cpu().numpy()
#total_attn_scores += attn_map.mean().item()
normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
normalized_attn_map = normalized_attn_map.astype(np.uint8)
#print("norm: ", normalized_attn_map.shape)
image = Image.fromarray(normalized_attn_map)
#image = fix_save_attn_map(attn_map)
images.append(image)
#print(total_attn_scores)
return images
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
def get_generator(seed, device):
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(device).manual_seed(seed)
else:
generator = None
return generator
================================================
FILE: requirements.txt
================================================
absl-py==2.1.0
accelerate==1.0.1
aiofiles==23.2.1
aiohappyeyeballs==2.4.4
aiohttp==3.10.11
aiosignal==1.3.1
annotated-types==0.7.0
anyio==3.7.1
async-timeout==5.0.1
asyncer==0.0.2
attrs==25.3.0
bidict==0.23.1
bitsandbytes==0.45.5
cachetools==5.5.0
certifi==2024.8.30
chainlit==1.1.402
charset-normalizer==3.4.0
chevron==0.14.0
click==8.1.8
contourpy==1.1.1
cycler==0.12.1
dataclasses-json==0.5.14
datasets==3.1.0
Deprecated==1.2.18
diffusers==0.30.0
dill==0.3.8
distro==1.9.0
einops==0.8.0
exceptiongroup==1.2.2
fastapi==0.110.3
ffmpy==0.5.0
filelock==3.16.1
filetype==1.2.0
fonttools==4.57.0
frozenlist==1.5.0
fsspec==2024.9.0
ftfy==6.2.3
google-auth==2.36.0
google-auth-oauthlib==1.0.0
googleapis-common-protos==1.69.2
gradio==4.44.1
gradio_client==1.3.0
grpcio==1.67.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.25.2
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
Jinja2==3.1.4
jiter==0.9.0
kiwisolver==1.4.7
Lazify==0.4.0
literalai==0.0.607
loguru==0.7.3
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.22.0
matplotlib==3.7.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
openai==1.71.0
opentelemetry-api==1.31.1
opentelemetry-exporter-otlp==1.31.1
opentelemetry-exporter-otlp-proto-common==1.31.1
opentelemetry-exporter-otlp-proto-grpc==1.31.1
opentelemetry-exporter-otlp-proto-http==1.31.1
opentelemetry-instrumentation==0.52b1
opentelemetry-proto==1.31.1
opentelemetry-sdk==1.31.1
opentelemetry-semantic-conventions==0.52b1
orjson==3.10.15
packaging==23.2
pandas==2.0.3
peft==0.13.2
pillow==10.4.0
propcache==0.2.0
protobuf==5.28.3
psutil==6.1.0
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pydantic==2.10.6
pydantic_core==2.27.2
pydub==0.25.1
Pygments==2.19.1
PyJWT==2.9.0
pyparsing==3.1.4
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-engineio==4.11.2
python-multipart==0.0.9
python-socketio==5.12.1
pytz==2024.2
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
requests-oauthlib==2.0.0
rich==14.0.0
rsa==4.9
ruff==0.11.8
safetensors==0.4.5
scipy==1.10.1
semantic-version==2.10.0
shellingham==1.5.4
simple-websocket==1.1.0
six==1.17.0
sniffio==1.3.1
sse-starlette==2.1.3
starlette==0.37.2
sympy==1.13.3
syncer==2.0.3
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboardX==2.6.2.2
timm==1.0.13
tokenizers==0.20.3
tomli==2.2.1
tomlkit==0.12.0
torch==2.4.1
torchaudio==2.4.1
torchvision==0.19.1
tqdm==4.66.5
transformers==4.45.0
triton==3.0.0
typer==0.15.3
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
uptrace==1.31.0
urllib3==2.2.3
uvicorn==0.25.0
watchfiles==0.20.0
wcwidth==0.2.13
websockets==12.0
Werkzeug==3.0.6
wrapt==1.17.2
wsproto==1.2.0
xformers==0.0.28.post1
xxhash==3.5.0
yarl==1.15.2
zipp==3.20.2
================================================
FILE: run.sh
================================================
accelerate launch --gpu_ids 0 --num_processes 1 --mixed_precision "fp16" \
train.py \
--pretrained_model_name_or_path="your path" \
--pretrained_ip_adapter_path="your path" \
--image_encoder_path="your path" \
--data_root_path='your path' \
--mixed_precision="fp16" \
--resolution=512 \
--train_batch_size=1 \
--dataloader_num_workers=4 \
--learning_rate=2.5e-04 \
--data_json_file="your path" \
--weight_decay=0.01 \
--output_dir="your path" \
--save_steps=100 \
--num_train_epochs=2100 \
--composed_inter_dim=2560 \
--composed_cross_heads=8 \
--composed_reshape_blocks=8 \
--composed_cross_value_dim=64
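# Note (added for clarity): the four --composed_* values above must match the
# ckpt_inter_dim / ckpt_cross_heads / ckpt_reshape_blocks / ckpt_cross_value_dim
# constants at the top of test.py before running inference on the resulting checkpoint.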
================================================
FILE: sdxl-fine-tuning/data/train.json
================================================
[
{
"image_file": "your image",
"text": " ",
"extra_text": "your caption",
"comments": {
"image_file": "five_cats.png",
"text": "",
"extra_text": "Five cats"
}
}
]
================================================
FILE: shared_models.py
================================================
import math
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.models.attention_processor import Attention
class Cross_Attention(nn.Module):
def __init__(self,
query_dim, # Q projection input dimension
context_dim, # K/V projection input dimension
heads=8,
head_dim=64,
value_dim=None, # V dimension after dimensionality reduction (defaults to head_dim)
out_dim=None): # Output dimension
super().__init__()
self.heads = heads
self.head_dim = head_dim
self.scale = math.sqrt(head_dim)
self.value_dim = value_dim if value_dim is not None else head_dim
self.out_dim = out_dim if out_dim is not None else heads * self.value_dim
# Linear projection layers
self.to_q = nn.Linear(query_dim, heads * head_dim)
self.to_k = nn.Linear(context_dim, heads * head_dim)
self.to_v = nn.Linear(context_dim, heads * self.value_dim)
# Optional output projection
self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)
def forward(self, query_input, context_input):
"""
query_input: [B, Q_len, query_dim]
context_input: [B, K_len, context_dim]
"""
B = query_input.size(0)
# Project Q, K, V
q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, Q_len, head_dim]
k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, K_len, head_dim]
v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2) # [B, heads, K_len, v_dim]
# Attention weights
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [B, heads, Q_len, K_len]
attn_probs = F.softmax(attn_scores, dim=-1)
# Weighted sum
attn_output = torch.matmul(attn_probs, v) # [B, heads, Q_len, v_dim]
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim) # [B, Q_len, heads * v_dim]
# Output projection
output = self.out_proj(attn_output) # [B, Q_len, out_dim]
return output
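# --- Shape sanity check for Cross_Attention (a minimal sketch; the sizes below
# mirror how Composed_Attention instantiates this class and are assumptions).
if __name__ == "__main__":
    _attn = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)
    _q = torch.randn(1, 4, 640)        # [B, Q_len, query_dim]
    _ctx = torch.randn(1, 77, 2048)    # [B, K_len, context_dim]
    print(_attn(_q, _ctx).shape)       # expected: torch.Size([1, 4, 320]) = heads * value_dim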
class ImageProjModel(torch.nn.Module):
"""Projection model - converts CLIP image features into a format suitable for UNet cross-attention"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim # UNet cross-attention dimension
self.clip_extra_context_tokens = clip_extra_context_tokens # Number of extra context tokens
# Linear projection layer, converts CLIP embeddings into multiple extra context tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim) # Normalization layer
def forward(self, image_embeds):
# Project CLIP image embeddings into multiple context tokens
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
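# --- Minimal sketch: ImageProjModel expands one pooled CLIP image embedding into
# clip_extra_context_tokens context tokens. The 1280 -> 4 x 2048 sizes below mirror
# a typical SDXL setup and are assumptions, not values read from a checkpoint.
if __name__ == "__main__":
    _proj = ImageProjModel(cross_attention_dim=2048, clip_embeddings_dim=1280,
                           clip_extra_context_tokens=4)
    print(_proj(torch.randn(2, 1280)).shape)   # expected: torch.Size([2, 4, 2048])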
class Composed_Attention(torch.nn.Module):  # Number_Class_crossAttention
def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):
super().__init__()
# Cross Attention layer
self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)
self.scale=scale
# self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)
# Image projection from 1280 to 2560
self.fc1=nn.Linear(hidden_size, hidden_size*2)
# Layer Normalization layer
self.ln = nn.LayerNorm(hidden_size)
# FC layer 1280 to 1280
self.fc2 = nn.Linear(hidden_size, hidden_size)
def forward(self, text_embeds,image_embeds):
image_embeds=self.fc1(image_embeds) #[1, 2560]
image_embeds=image_embeds.reshape(1,4,640)
output = self.cross_attention(image_embeds, text_embeds) #[1,4,320]
output=output.reshape(1,1280)
# Normalize the output using Layer Normalization
output=self.ln(output)
# FC layer [1, 1280] -> [1, 1280]
output=self.fc2(output)
return output*self.scale
def load_from_checkpoint(self, ckpt_path: str):
from safetensors.torch import load_file
from collections import OrderedDict
# Load weights file
weights = load_file(ckpt_path)
# Initialize two dictionaries to store weights for different modules separately
image_proj_weights = OrderedDict()
attn_weights = OrderedDict()
# Separate weights into different modules
for k, v in weights.items():
# Process image_proj_model weights (match two possible key name prefixes)
if k.startswith("image_proj_model.") or k.startswith("image_proj."):
new_key = k.replace("image_proj_model.", "").replace("image_proj.", "")
if hasattr(self, "image_proj_model") and hasattr(self.image_proj_model, new_key.split('.')[0]):
image_proj_weights[new_key] = v
# Process target attention layer weights (match two possible key name formats)
elif "down_blocks.2.attentions.1" in k:
# Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1
new_key = k.replace("composed_modules.", "").replace("ip_adapter.", "")
if hasattr(self, new_key.split('.')[0]):
attn_weights[new_key] = v
# Load image_proj_model weights (strict mode)
if image_proj_weights:
self.image_proj_model.load_state_dict(image_proj_weights, strict=True)
print(f"Loaded image_proj_model weights: {len(image_proj_weights)} params")
# Load attention layer weights (non-strict mode)
if attn_weights:
# Create temporary ModuleDict to load weights
temp_dict = {k: v for k, v in self.named_modules()
if "down_blocks.2.attentions.1" in k}
temp_model = torch.nn.ModuleDict(temp_dict)
missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False)
if missing:
print(f"Missing keys in attention blocks: {missing}")
if unexpected:
print(f"Unexpected keys in attention blocks: {unexpected}")
print(f"Loaded attention weights: {len(attn_weights)} params")
print(f"Successfully loaded target modules from {ckpt_path}")
return self
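# --- End-to-end shape check for Composed_Attention (a minimal sketch only).
# forward() hard-codes a batch size of 1 in its reshape calls, so the example
# uses a single sample; the 77-token text length is an assumption.
if __name__ == "__main__":
    _fuse = Composed_Attention(hidden_size=1280)
    _text = torch.randn(1, 77, 2048)    # concatenated SDXL text-encoder hidden states
    _img = torch.randn(1, 1280)         # pooled CLIP image embedding
    print(_fuse(_text, _img).shape)     # expected: torch.Size([1, 1280])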
================================================
FILE: test.py
================================================
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from ip_adapter import IPAdapterXL # Assuming IPAdapterXL is correctly defined in ip_adapter
from train import HarmonyAttention # Assuming HarmonyAttention is correctly defined
import os
'''
The following parameters must be manually adjusted to match the training parameters.
'''
ckpt_inter_dim = 2560
ckpt_cross_heads = 8
ckpt_reshape_blocks = 8
ckpt_cross_value_dim = 64
# Image generation function
def generate_image(input_path, prompt, extra_text, output_path="output.png"):
print(f"\nGenerating image for: {prompt} + {extra_text}")
# Prepare input image
input_image = Image.open(input_path).resize((512, 512))
# Generate image
images = ip_model.generate(
pil_image=input_image,
prompt=prompt,
negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
scale=1.0,
guidance_scale=5.0,
num_samples=1,
num_inference_steps=30,
seed=42,
extra_text=extra_text,
number_class_crossattention=number_class_crossattention # This should be the HarmonyAttention instance
)
# Save result
images[0].save(output_path)
print(f"Saved generated image to: {output_path}")
return images[0]
if __name__ == "__main__":
input_image = "your path to inputimage" # Replace with your actual input image path
device = "cuda:2"
# Model path configuration
base_model_path = "your path" # Replace with your actual base model path
image_encoder_path = "your path" # Replace with your actual image encoder path
fine_tuned_ckpt = "fine_tuned model path" # Path to the fine-tuned weights (ip_adapter.bin or similar)
save_root = os.path.join(
'your path', # Replace with your desired save directory
fine_tuned_ckpt.split('/')[3] if len(fine_tuned_ckpt.split('/')) > 3 else "default_folder1",
fine_tuned_ckpt.split('/')[4] if len(fine_tuned_ckpt.split('/')) > 4 else "default_folder2",
)
# Create path (if it doesn't exist)
os.makedirs(save_root, exist_ok=True)
print(f"Save directory created at: {save_root}")
# Load SDXL base model
print("Loading base SDXL model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
base_model_path,
torch_dtype=torch.float16,
add_watermarker=False,
)
pipe.enable_vae_tiling()
pipe.to(device)
# Define the fusion method to be used
fusion_method = "cross_attention" # Options: "cross_attention", "qformer", "mlp"
# Initialize HarmonyAttention (this is the `number_class_crossattention` module)
print("Initializing HarmonyAttention module...")
number_class_crossattention = HarmonyAttention( # Renamed for clarity, this is your custom attention module
image_hidden_size=1280, # Keep unchanged or match your training
text_context_dim=2048, # Keep unchanged or match your training
inter_dim=ckpt_inter_dim,
cross_heads=ckpt_cross_heads,
reshape_blocks=ckpt_reshape_blocks,
cross_value_dim=ckpt_cross_value_dim,
scale=1.0, # Keep unchanged or adjust as needed
fusion_method=fusion_method # Add fusion method selection
).to(device).half()
# Initialize IP-Adapter
print("Initializing IP-Adapter with target blocks...")
ip_model = IPAdapterXL(
pipe,
image_encoder_path,
fine_tuned_ckpt,
device,
target_blocks=["down_blocks.2.attentions.1"], # Or your specific target blocks
num_tokens=4, # Or your specific number of tokens
inference=True,
number_class_crossattention=number_class_crossattention # Pass the initialized HarmonyAttention module here
)
print("HarmonyAttention weights are expected to be loaded as part of the IP-Adapter checkpoint.")
generate_image(
input_path=input_image,
prompt="lions",
extra_text="eight sheep", # Use the caption from training
output_path=os.path.join(save_root, os.path.basename(input_image)) # Safer way to construct output path
)
================================================
FILE: train.py
================================================
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import torch.nn as nn
from diffusers.models.attention_processor import Attention
import torch
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
from safetensors import safe_open
from shared_models import ImageProjModel
from baseline import QFormer,MLP,AttentionFusionWrapper
from safetensors.torch import save_file, load_file
from ip_adapter.utils import is_torch2_available
if is_torch2_available():
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
from ip_adapter.attention_processor import Cross_Attention
def count_model_params(model):
return sum([p.numel() for p in model.parameters()]) / 1e6
# Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
super().__init__()
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.size = size
self.center_crop = center_crop
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
self.data = json.load(open(json_file)) # list of dicts: [{"image_file": "1.png", "text": "A dog", "extra_text": "one dog"}]
self.transform = transforms.Compose([
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
self.clip_image_processor = CLIPImageProcessor()
def __getitem__(self, idx):
item = self.data[idx]
text = item["text"]
text_extra = item['extra_text']
image_file = item["image_file"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
# original size
original_width, original_height = raw_image.size
original_size = torch.tensor([original_height, original_width])
image_tensor = self.transform(raw_image.convert("RGB"))
# random crop
delta_h = image_tensor.shape[1] - self.size
delta_w = image_tensor.shape[2] - self.size
assert not all([delta_h, delta_w])
if self.center_crop:
top = delta_h // 2
left = delta_w // 2
else:
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
image = transforms.functional.crop(
image_tensor, top=top, left=left, height=self.size, width=self.size
)
crop_coords_top_left = torch.tensor([top, left])
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
# drop
drop_image_embed = 0
rand_num = random.random()
if rand_num < self.i_drop_rate:
drop_image_embed = 1
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
text = ""
drop_image_embed = 1
# get text and tokenize
text_input_ids = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
text_input_ids_2 = self.tokenizer_2(
text,
max_length=self.tokenizer_2.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
#extra text
text_extra_input_ids=self.tokenizer(
text_extra,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
text_extra_input_ids_2 = self.tokenizer_2(
text_extra,
max_length=self.tokenizer_2.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
return {
"image": image,
"text_input_ids": text_input_ids,
"text_input_ids_2": text_input_ids_2,
"text_extra_input_ids":text_extra_input_ids,
"text_extra_input_ids_2":text_extra_input_ids_2,
"clip_image": clip_image,
"drop_image_embed": drop_image_embed,
"original_size": original_size,
"crop_coords_top_left": crop_coords_top_left,
"target_size": torch.tensor([self.size, self.size]),
}
def __len__(self):
return len(self.data)
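# --- Record format consumed by MyDataset (a documentation sketch mirroring
# sdxl-fine-tuning/data/train.json; the file name and captions are placeholders).
_EXAMPLE_RECORD = {
    "image_file": "five_cats.png",   # image path relative to --data_root_path
    "text": "",                      # main prompt fed to both SDXL text encoders
    "extra_text": "Five cats",       # quantity/category caption routed to HarmonyAttention
}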
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
drop_image_embeds = [example["drop_image_embed"] for example in data]
original_size = torch.stack([example["original_size"] for example in data])
crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
target_size = torch.stack([example["target_size"] for example in data])
#extra text
text_extra_input_ids=torch.stack([example["text_extra_input_ids"] for example in data])
text_extra_input_ids_2=torch.stack([example["text_extra_input_ids_2"] for example in data])
return {
"images": images,
"text_input_ids": text_input_ids,
"text_input_ids_2": text_input_ids_2,
"clip_images": clip_images,
"drop_image_embeds": drop_image_embeds,
"original_size": original_size,
"crop_coords_top_left": crop_coords_top_left,
"target_size": target_size,
"text_extra_input_ids":text_extra_input_ids,
"text_extra_input_ids_2":text_extra_input_ids_2,
}
class HarmonyAttention(nn.Module):
def __init__(self,
image_hidden_size=1280, # Input image feature dimension
text_context_dim=2048, # Input text context feature dimension
inter_dim=2560, # Intermediate projection dimension
cross_heads=10, # Number of cross-attention heads
reshape_blocks=8, # Number of image feature blocks
cross_value_dim=64, # Value dimension per head after dimensionality reduction
scale=1.0, # Output scaling factor
fusion_method="qformer"): # Fusion method selection: mlp, cross_attention, qformer
super().__init__()
self.scale = scale
self.reshape_blocks = reshape_blocks
self.cross_query_dim = inter_dim // reshape_blocks
self.fusion_method = fusion_method
self.image_hidden_size = image_hidden_size
self.text_context_dim = text_context_dim
# Image projection required by all methods
self.fc1 = nn.Linear(image_hidden_size, inter_dim)
print(fusion_method)
if fusion_method == "cross_attention":
# 1. Cross-attention
self.fusion_text_image = Cross_Attention(
query_dim=self.cross_query_dim,
context_dim=text_context_dim,
heads=cross_heads,
value_dim=cross_value_dim
)
elif fusion_method == "qformer":
# 2. Q-Former TODO
self.fusion_text_image = QFormer(
num_queries=16,
hidden_dim=self.cross_query_dim,
num_layers=1,
num_heads=cross_heads
)
elif fusion_method=='mlp':
#3. MLP TODO
self.fusion_text_image = MLP(
fused_dim=self.cross_query_dim
)
elif fusion_method=='gated-attention':
#4. Gated-attention
self.fusion_text_image = AttentionFusionWrapper(
fused_dim=self.cross_query_dim
)
flattened_dim = cross_value_dim * cross_heads * reshape_blocks
self.ln = nn.LayerNorm(flattened_dim)
self.fc2 = nn.Linear(flattened_dim, image_hidden_size)
def forward(self, text_embeds, image_embeds):
"""
Args:
text_embeds: [B, T, text_context_dim]
image_embeds: [B, image_hidden_size]
Returns:
out: [B, image_hidden_size], image features fused with text conditions
"""
B = image_embeds.size(0)
# Map image features to intermediate dimension and reshape into multiple blocks
x = self.fc1(image_embeds) # [B, inter_dim]
x = x.view(B, self.reshape_blocks, self.cross_query_dim) # [B, N_blocks, query_dim]
# Cross-Attention: interaction between image blocks and text
print(self.fusion_text_image)
attended = self.fusion_text_image(x, text_embeds) # [B, N_blocks, value_dim * heads]
print(attended.shape)
# Flatten, normalize, and project back
attended = attended.view(B, -1) # [B, flattened_dim]
out = self.ln(attended)
out = self.fc2(out) * self.scale
return out
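# --- Shape check for HarmonyAttention with the "cross_attention" fusion path.
# A documentation sketch (not called by the training loop); the dimensions mirror
# the run.sh configuration and are assumptions, not values from a trained checkpoint.
def _demo_harmony_attention():
    harmony = HarmonyAttention(image_hidden_size=1280, text_context_dim=2048,
                               inter_dim=2560, cross_heads=8, reshape_blocks=8,
                               cross_value_dim=64, fusion_method="cross_attention")
    text_embeds = torch.randn(1, 77, 2048)   # extra-text embeddings [B, T, 2048]
    image_embeds = torch.randn(1, 1280)      # pooled CLIP image embedding [B, 1280]
    return harmony(text_embeds, image_embeds).shape  # expected: torch.Size([1, 1280])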
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None,
inter_dim=None,
cross_heads=None,
reshape_blocks=None,
cross_value_dim=None,
fusion_method="cross_attention"): # Fusion method parameter
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
# Directly use the fusion_method parameter, no need to convert to a boolean flag
self.composed_modules = HarmonyAttention(
image_hidden_size=1280, # Image feature dimension fixed
text_context_dim=2048, # Text context dimension fixed
inter_dim=inter_dim,
cross_heads=cross_heads,
reshape_blocks=reshape_blocks,
cross_value_dim=cross_value_dim,
scale=1.0, # Scaling factor fixed
fusion_method=fusion_method # Directly pass the fusion method parameter
)
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds, text_extra_embeds):
composed_embeds = self.composed_modules(text_extra_embeds, image_embeds)
image_embeds = image_embeds + composed_embeds
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
if os.path.splitext(ckpt_path)[-1] == ".safetensors":
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument(
"--data_json_file",
type=str,
default=None,
required=True,
help="Training data",
)
parser.add_argument(
"--data_root_path",
type=str,
default="",
required=True,
help="Training data root path",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
required=True,
help="Path to CLIP image encoder",
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-ip_adapter",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images"
),
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use.",
)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=10000)
parser.add_argument(
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--save_steps",
type=int,
default=2000,
help=(
"Save a checkpoint of the training state every X updates"
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--composed_inter_dim",
type=int,
default=None,
help="HarmonyAttention's intermediate projection dimension [e.g., 1280, 2560]."
)
parser.add_argument(
"--composed_cross_heads",
type=int,
default=None,
help="Number of cross-attention heads in HarmonyAttention [e.g., 8, 10]."
)
parser.add_argument(
"--composed_reshape_blocks",
type=int,
default=None,
help="Number of image feature blocks in HarmonyAttention [e.g., 4, 8]."
)
parser.add_argument(
"--composed_cross_value_dim",
type=int,
default=None,
help="Value dimension per head after dimensionality reduction in HarmonyAttention [e.g., 32, 64]."
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def main():
# Parse command-line arguments
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
# Configure accelerate for mixed-precision and distributed training
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Create output directory
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load various pre-trained model components
# Noise addition control, tokenizer, text encoder, VAE, UNet, etc.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
# Freeze base model parameters, only train the adapter part
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
# Initialize IP-Adapter model components
# Image projection model maps CLIP image embeddings to UNet's cross-attention dimension
num_tokens = 4 # Number of extra context tokens
image_proj_model = ImageProjModel(
cross_attention_dim=unet.config.cross_attention_dim,
clip_embeddings_dim=image_encoder.config.projection_dim,
clip_extra_context_tokens=num_tokens,
)
# Initialize adapter attention processors
# Create appropriate attention processors for each attention block in UNet
# init adapter modules
attn_procs = {}
unet_sd = unet.state_dict()
# Initialize UNet's attention + IP-Adapter's cross_attention parameters
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
# Only the target block gets active IP-Adapter cross-attention parameters (initialized from the UNet's own K/V weights); all other blocks skip the image tokens
if 'down_blocks.2.attentions.1' in name:
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
num_tokens=num_tokens, skip=False)
attn_procs[name].load_state_dict(weights, strict=False)
else:
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
num_tokens=num_tokens, skip=True)
# Load all attention processing into UNet
unet.set_attn_processor(attn_procs)
# Extract all attention layers into a list
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
# Create the complete IP-Adapter model
ip_adapter = IPAdapter(
unet,
image_proj_model,
adapter_modules,
args.pretrained_ip_adapter_path,
inter_dim=args.composed_inter_dim, # Use command-line arguments
cross_heads=args.composed_cross_heads, # Use command-line arguments
reshape_blocks=args.composed_reshape_blocks, # Use command-line arguments
cross_value_dim=args.composed_cross_value_dim, # Use command-line arguments
fusion_method='cross_attention'
)
# ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)
# Set data type for mixed-precision training
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move models to the appropriate device and convert to the appropriate data type
vae.to(accelerator.device) # VAE uses fp32 for better stability
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder_2.to(accelerator.device, dtype=weight_dtype)
image_encoder.to(accelerator.device, dtype=weight_dtype)
# Set optimizer - only optimize IP-Adapter components
params_to_opt = itertools.chain(ip_adapter.adapter_modules.parameters(), ip_adapter.composed_modules.parameters())
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
accelerator.print("Trainable parameters: adapter_modules:{:.2f}M, composed_modules:{:.2f}M".format(
count_model_params(ip_adapter.adapter_modules),
count_model_params(ip_adapter.composed_modules)))
# Create dataset and dataloader
train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, image_root_path=args.data_root_path)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Prepare model, optimizer, and dataloader with accelerator
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
# Start training loop
global_step = 0
for epoch in range(0, args.num_train_epochs):
begin = time.perf_counter()
for step, batch in enumerate(train_dataloader):
load_data_time = time.perf_counter() - begin
with accelerator.accumulate(ip_adapter):
# Convert images to latent space using VAE
with torch.no_grad():
# SDXL's VAE uses fp32 for better numerical stability
latents = vae.encode(batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents = latents.to(accelerator.device, dtype=weight_dtype)
# Generate random noise to add to the latent representation
noise = torch.randn_like(latents)
if args.noise_offset:
# Use noise offset technique to improve training stability
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(accelerator.device, dtype=weight_dtype)
bsz = latents.shape[0]
# Sample random timesteps for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise according to timesteps (forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get CLIP image embeddings
with torch.no_grad():
image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
# Apply image embedding dropout strategy
image_embeds_ = []
for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
if drop_image_embed == 1:
image_embeds_.append(torch.zeros_like(image_embed))
else:
image_embeds_.append(image_embed)
image_embeds = torch.stack(image_embeds_)
# Get text embeddings (SDXL uses two text encoders)
with torch.no_grad():
encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True)
text_embeds = encoder_output.hidden_states[-2]
encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True)
pooled_text_embeds = encoder_output_2[0]
text_embeds_2 = encoder_output_2.hidden_states[-2]
text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # Concatenate outputs of the two text encoders
#extra text
encoder_extra_output = text_encoder(batch['text_extra_input_ids'].to(accelerator.device),output_hidden_states=True)
text_extra_embeds = encoder_extra_output.hidden_states[-2]
encoder_extra_output_2 = text_encoder_2(batch['text_extra_input_ids_2'].to(accelerator.device),output_hidden_states=True)
text_extra_embeds_2 = encoder_extra_output_2.hidden_states[-2]
text_extra_embeds=torch.concat([text_extra_embeds,text_extra_embeds_2],dim=-1)
# Add extra conditions required by SDXL (image size and crop info)
add_time_ids = [
batch["original_size"].to(accelerator.device),
batch["crop_coords_top_left"].to(accelerator.device),
batch["target_size"].to(accelerator.device),
]
add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
# Predict noise using IP-Adapter
noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds, text_extra_embeds)
# Calculate MSE loss (between predicted noise and actual added noise)
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather loss in distributed training
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
# Backpropagation and optimizer step
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Print training information
if accelerator.is_main_process:
print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))
global_step += 1
# Save checkpoint periodically
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path, safe_serialization=False)
begin = time.perf_counter()
if __name__ == "__main__":
main()
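# Note on resuming (a usage sketch, not wired into the script above): state
# directories produced by accelerator.save_state() can be restored in a later run
# by calling accelerator.load_state("<output_dir>/checkpoint-<step>") on the
# Accelerator created in main(), after accelerator.prepare(...).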