Repository: muzishen/IMAGHarmony Branch: main Commit: 46895e751ff6 Files: 20 Total size: 175.2 KB Directory structure: gitextract_ymnget5p/ ├── LICENSE ├── README.md ├── baseline.py ├── convert_bin.py ├── demo.py ├── ip_adapter/ │ ├── __init__.py │ ├── attention_processor.py │ ├── custom_pipelines.py │ ├── ip_adapter.py │ ├── ip_adapter_origin.py │ ├── resampler.py │ ├── shared_models.py │ ├── test_resampler.py │ └── utils.py ├── requirements.txt ├── run.sh ├── sdxl-fine-tuning/ │ └── data/ │ └── train.json ├── shared_models.py ├── test.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) 
The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

================================================
FILE: README.md
================================================

# IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout

## 🗓️ Release

- [2025/5/30] 🔥 We released the [technical report](https://arxiv.org/pdf/2506.01949) of IMAGHarmony.
- [2025/5/28] 🔥 We released the training and inference code of IMAGHarmony.
- [2025/5/17] 🎉 We launched the [project page](https://revive234.github.io/IMAGHarmony.github.io/) of IMAGHarmony.

## 💡 Introduction

IMAGHarmony tackles the challenge of controllable image editing in multi-object scenes, where existing models struggle to align object quantity and spatial layout. To this end, IMAGHarmony introduces a structure-aware framework for quantity-and-layout consistent image editing (QL-Edit), enabling precise control over object count, category, and arrangement. We propose a harmony-aware (HA) module to jointly model object structure and semantics, and a preference-guided noise selection (PNS) strategy that stabilizes generation by selecting semantically aligned initial noise. Our method is trained and evaluated on HarmonyBench, a newly curated benchmark with diverse editing scenarios.

![architecture](./assets/1.png)

## 🚀 HarmonyBench Dataset Demo

![dataset_demo](./assets/harmonybench.jpg)

## 🚀 Examples

![results_1](./assets/sotacomp.jpg)
![results_2](./assets/multi.jpg)

### Dual-Category Editing

![results_5](./assets/3edit.jpg)

## 🔧 Requirements

- Python >= 3.8
- [PyTorch >= 2.0.0](https://pytorch.org/)
- CUDA >= 11.8

```
conda create --name IMAGHarmony python=3.8.18
conda activate IMAGHarmony

# Install requirements
pip install -r requirements.txt
```

## 🌐 Download Models

You can download our models from [Hugging Face](https://huggingface.co/kkkkggg/IMAGHarmony). You can download the other component models from their original repositories, as follows.

- [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
- [stable-diffusion-XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)

## 🚀 How to train

```
# Please download the HarmonyBench data first or prepare your own images,
# and modify the paths in run.sh.
# Write the captions of your images in your train.json file.

# start training
sh run.sh
```

## 🚀 How to test

```
# Please convert your checkpoints first
python convert_bin.py

# Fill in your paths in test.py, then run
python test.py
```

Or you can test it with the Gradio demo:

```
python demo.py
```

(The layout of the converted `ip_adapter.bin` checkpoint is sketched at the end of this README.)

## Acknowledgement

We would like to thank the contributors to the [InstantStyle](https://github.com/instantX-research/InstantStyle) and [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) repositories for their open research and exploration.
The IMAGHarmony code is available for both academic and commercial use. Users are permitted to generate images with this tool, provided they comply with local laws and use it responsibly. The developers disclaim all liability for any misuse or unlawful activity by users.

## Citation

If you find IMAGHarmony useful for your research and applications, please cite it using this BibTeX:

```bibtex
@misc{shen2025imagharmonycontrollableimageediting,
      title={IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout},
      author={Fei Shen and Yutong Gao and Jian Yu and Xiaoyu Du and Jinhui Tang},
      year={2025},
      eprint={2506.01949},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.01949},
}
```

## 🕒 TODO List

- [x] Paper
- [x] Training Code
- [x] Inference Code
- [ ] HarmonyBench Dataset
- [ ] Model Weights

## 👉 **Our other projects:**

- [IMAGEdit](https://github.com/XWH-A/IMAGEdit): Training-free controllable video editing with consistent object layout. [Controllable multi-object video editing]
- [IMAGDressing](https://github.com/muzishen/IMAGDressing): Controllable dressing generation. [Controllable dressing generation]
- [IMAGGarment](https://github.com/muzishen/IMAGGarment): Fine-grained controllable garment generation. [Controllable garment generation]
- [IMAGHarmony](https://github.com/muzishen/IMAGHarmony): Controllable image editing with consistent object layout. [Controllable multi-object image editing]
- [IMAGPose](https://github.com/muzishen/IMAGPose): Pose-guided person generation with high fidelity. [Controllable multi-mode person generation]
- [RCDMs](https://github.com/muzishen/RCDMs): Rich-contextual conditional diffusion for story visualization. [Controllable story generation]
- [PCDMs](https://github.com/tencent-ailab/PCDMs): Progressive conditional diffusion for pose-guided image synthesis. [Controllable person generation]
- [V-Express](https://github.com/tencent-ailab/V-Express/): Explores strong and weak conditional relationships for portrait video generation. [Controllable digital human generation]
- [FaceShot](https://github.com/open-mmlab/FaceShot/): Talking-face plugin for any character. [Controllable anime digital human generation]
- [CharacterShot](https://github.com/Jeoyal/CharacterShot): Controllable and consistent 4D character animation framework. [Controllable 4D character generation]
- [StyleTailor](https://github.com/mahb-THU/StyleTailor): An agent for personalized fashion styling. [Personalized fashion agent]
- [SignVip](https://github.com/umnooob/signvip/): Controllable sign language video generation. [Controllable sign language generation]

## 📨 Contact

If you have any questions, please feel free to contact us at shenfei140721@126.com and yutonggaokkk@njust.edu.cn.
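## 📦 Appendix: converted checkpoint layout (sketch)

As referenced in "How to test" above, `convert_bin.py` repackages a trained `pytorch_model.bin` into an `ip_adapter.bin` holding three sub-state-dicts under the keys `image_proj`, `ip_adapter`, and `composed_adapter`, which `test.py` and `demo.py` then load. Below is a minimal sketch for inspecting such a file; the checkpoint path is a hypothetical placeholder.

```python
# Minimal sketch: inspect a checkpoint produced by convert_bin.py.
# The path below is a hypothetical placeholder; substitute your own.
import torch

ckpt_path = "logs/your_run/checkpoint-10000/ip_adapter.bin"
state = torch.load(ckpt_path, map_location="cpu")

# convert_bin.py stores three sub-state-dicts under these keys.
for group in ("image_proj", "ip_adapter", "composed_adapter"):
    sub_sd = state.get(group, {})
    n_params = sum(t.numel() for t in sub_sd.values())
    print(f"{group}: {len(sub_sd)} tensors, {n_params / 1e6:.2f}M parameters")
```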
================================================ FILE: baseline.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class QFormer(nn.Module): def __init__(self, hidden_dim=768, num_queries=16, num_layers=6, num_heads=12, image_feat_dim=320, text_feat_dim=2048, add_modality_embedding=True): super(QFormer, self).__init__() self.hidden_dim = hidden_dim self.num_queries = num_queries self.add_modality_embedding = add_modality_embedding # Learnable query tokens: [1, num_queries, D] self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim)) # Modality type embeddings (0=image, 1=text) if add_modality_embedding: self.modality_embed = nn.Embedding(2, hidden_dim) self.image_proj = nn.Linear(image_feat_dim, hidden_dim) self.text_proj = nn.Linear(text_feat_dim, hidden_dim) # Transformer encoder: Q interacts with K/V (image + text) encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) def forward(self, image_feat, text_feat): """ image_feat: [B, T_img, D] text_feat: [B, T_txt, D] """ B = image_feat.size(0) print(image_feat.shape,text_feat.shape) image_feat = self.image_proj(image_feat) text_feat = self.text_proj(text_feat) # Concatenate image + text features as K/V kv = torch.cat([image_feat, text_feat], dim=1) # [B, T_img + T_txt, D] # Add modality type embedding if self.add_modality_embedding: T_img = image_feat.size(1) T_txt = text_feat.size(1) modality_ids = torch.cat([ torch.zeros(T_img, dtype=torch.long), torch.ones(T_txt, dtype=torch.long) ], dim=0).to(image_feat.device) # [T_img + T_txt] modality_embed = self.modality_embed(modality_ids) # [T_img + T_txt, D] kv = kv + modality_embed.unsqueeze(0) # broadcast to [B, T, D] # Expand learnable query tokens to batch size queries = self.query_tokens.expand(B, -1, -1) # [B, N_query, D] # Q-Former: let queries attend to K/V # Transformer requires concat(Q, K/V) input_seq = torch.cat([queries, kv], dim=1) # [B, N_query + T, D] output = self.transformer(input_seq) # [B, N_query + T, D] # Return only the updated query tokens return output[:, :self.num_queries, :] # [B, N_query, D] class MLP(nn.Module): def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header =16): super().__init__() self.image_proj = nn.Linear(image_dim, fused_dim) self.text_proj = nn.Linear(text_dim, fused_dim) self.num_header = num_header self.mlp = nn.Sequential( nn.Linear(2 * fused_dim, fused_dim), nn.ReLU(), nn.Linear(fused_dim, fused_dim), nn.ReLU(), nn.Linear(fused_dim,fused_dim*16) ) self.fused_dim=fused_dim def forward(self, image_feat, text_feat): """ image_feat: [B, T_img, image_dim] text_feat: [B, T_txt, text_dim] """ image_repr = image_feat.mean(dim=1) # [B, image_dim] text_repr = text_feat.mean(dim=1) # [B, text_dim] image_proj = self.image_proj(image_repr) # [B, fused_dim] text_proj = self.text_proj(text_repr) # [B, fused_dim] fused = torch.cat([image_proj, text_proj], dim=-1) # [B, 2*fused_dim] output = self.mlp(fused).reshape(-1,self.num_header,self.fused_dim) # [B, fused_dim] return output class GatedAttentionFusion(nn.Module): def __init__(self, input_dim=768, hidden_dim=512): super().__init__() self.gate_mlp = nn.Sequential( nn.Linear(2 * input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1), nn.Sigmoid() # alpha ∈ [0, 1] ) def forward(self, img_feat, txt_feat): """ img_feat: [B, D] txt_feat: [B, D] """ fused_input = torch.cat([img_feat, 
txt_feat], dim=-1) # [B, 2D] alpha = self.gate_mlp(fused_input) # [B, 1] fused = alpha * img_feat + (1 - alpha) * txt_feat # [B, D] return fused class AttentionFusionWrapper(nn.Module): def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header=16): super().__init__() self.img_proj = nn.Linear(image_dim, fused_dim) self.txt_proj = nn.Linear(text_dim, fused_dim) self.fusion = GatedAttentionFusion(input_dim=fused_dim) self.num_header = num_header self.fused_dim = fused_dim self.dim_transfer = nn.Linear(fused_dim,fused_dim*self.num_header) def forward(self, image_feat, text_feat): """ image_feat: [B, T_img, 320] text_feat: [B, T_txt, 2048] """ # Mean pooling img_global = image_feat.mean(dim=1) # [B, 320] txt_global = text_feat.mean(dim=1) # [B, 2048] # Linear projection img_proj = self.img_proj(img_global) # [B, 768] txt_proj = self.txt_proj(txt_global) # [B, 768] # Gated fusion fused = self.fusion(img_proj, txt_proj) # [B, 768] fused = self.dim_transfer(fused).reshape(-1,self.num_header,self.fused_dim) return fused ================================================ FILE: convert_bin.py ================================================ import os import torch from collections import OrderedDict def convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path): if not os.path.exists(pytorch_model_path): print(f" [Warning] Source file not found, skipping: {pytorch_model_path}") return False print(f" Converting: {pytorch_model_path}") try: sd = torch.load(pytorch_model_path, map_location="cpu") image_proj_sd = OrderedDict() ip_sd = OrderedDict() composed_sd = OrderedDict() for k in sd: if k.startswith("image_proj_model."): image_proj_sd[k.replace("image_proj_model.", "")] = sd[k] elif k.startswith("adapter_modules."): ip_sd[k.replace("adapter_modules.", "")] = sd[k] elif k.startswith("composed_modules."): composed_sd[k.replace("composed_modules.", "")] = sd[k] if not image_proj_sd and not ip_sd and not composed_sd: print(f" [Warning] No expected keys (image_proj_model, adapter_modules, composed_modules) found in {pytorch_model_path}. 
Skipping save.") return False final_sd = { "image_proj": image_proj_sd, "ip_adapter": ip_sd, 'composed_adapter': composed_sd } torch.save(final_sd, output_ip_adapter_path) print(f" Successfully saved: {output_ip_adapter_path}") return True except Exception as e: print(f" [Error] Failed to convert {pytorch_model_path}: {e}") return False if __name__ == "__main__": base_log_dir = "your fine_tuned model path" total_converted = 0 total_skipped = 0 total_errors = 0 print(f"Starting conversion process in base directory: {base_log_dir}") for training_run_dir_name in os.listdir(base_log_dir): training_run_dir_path = os.path.join(base_log_dir, training_run_dir_name) # Check if it's actually a directory if os.path.isdir(training_run_dir_path): print(f"\nProcessing training run: {training_run_dir_name}") # Iterate through items inside the training run directory for checkpoint_dir_name in os.listdir(training_run_dir_path): if checkpoint_dir_name.startswith("checkpoint-") and \ os.path.isdir(os.path.join(training_run_dir_path, checkpoint_dir_name)): checkpoint_dir_path = os.path.join(training_run_dir_path, checkpoint_dir_name) print(f"- Found checkpoint directory: {checkpoint_dir_name}") pytorch_model_path = os.path.join(checkpoint_dir_path, "pytorch_model.bin") output_ip_adapter_path = os.path.join(checkpoint_dir_path, "ip_adapter.bin") if os.path.exists(output_ip_adapter_path): print(f" Output file already exists, skipping: {output_ip_adapter_path}") total_skipped += 1 continue success = convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path) if success: total_converted += 1 else: if not os.path.exists(pytorch_model_path): total_skipped += 1 else: total_errors +=1 print("\n--- Conversion Summary ---") print(f"Total checkpoints converted: {total_converted}") print(f"Total checkpoints skipped (e.g., source missing): {total_skipped}") print(f"Total errors during conversion: {total_errors}") ================================================ FILE: demo.py ================================================ import gradio as gr import torch from diffusers import StableDiffusionXLPipeline from PIL import Image from ip_adapter import IPAdapterXL from huggingface_hub import hf_hub_download import os import time try: from tutorial_train_sdxl_ori import ComposedAttention except ImportError: print("Error: Could not import ComposedAttention.") print("Please ensure 'tutorial_train_sdxl_ori.py' is in the same directory as this script.") exit() print("Loading models, please wait...") CKPT_INTER_DIM = 2560 CKPT_CROSS_HEADS = 8 CKPT_RESHAPE_BLOCKS = 8 CKPT_CROSS_VALUE_DIM = 64 BASE_MODEL_PATH = "/aigc_data_hdd/checkpoints/stable-diffusion-xl-base-1.0" IMAGE_ENCODER_PATH = os.path.join(BASE_MODEL_PATH, "image_encoder") if not os.path.exists(BASE_MODEL_PATH) or not os.path.exists(IMAGE_ENCODER_PATH): print(f"Error: Model or image encoder path not found: {BASE_MODEL_PATH}") exit() IP_ADAPTER_REPO_ID = "kkkkggg/IMAGHarmony" IP_ADAPTER_FILENAME = "IMAGHarmony_variant1.bin" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Downloading weights file {IP_ADAPTER_FILENAME} from repository {IP_ADAPTER_REPO_ID}...") try: fine_tuned_ckpt_path = hf_hub_download( repo_id=IP_ADAPTER_REPO_ID, filename=IP_ADAPTER_FILENAME, ) print(f"Weights file downloaded to: {fine_tuned_ckpt_path}") except Exception as e: print(f"Failed to download weights: {e}") exit() print(f"Loading Stable Diffusion XL pipeline from local path: {BASE_MODEL_PATH}") pipe = StableDiffusionXLPipeline.from_pretrained( BASE_MODEL_PATH, 
torch_dtype=torch.float16, add_watermarker=False, ).to(DEVICE) pipe.enable_vae_tiling() print("Instantiating custom ComposedAttention module...") number_class_crossattention = ComposedAttention( image_hidden_size=1280, text_context_dim=2048, inter_dim=CKPT_INTER_DIM, cross_heads=CKPT_CROSS_HEADS, reshape_blocks=CKPT_RESHAPE_BLOCKS, cross_value_dim=CKPT_CROSS_VALUE_DIM, scale=1.0 ).to(DEVICE).half() print("Extracting and loading ComposedAttention weights from the main checkpoint...") try: state_dict = torch.load(fine_tuned_ckpt_path, map_location="cpu") composed_attention_weights = state_dict["composed_adapter"] number_class_crossattention.load_state_dict(composed_attention_weights) print("Successfully loaded fine-tuned ComposedAttention weights.") except KeyError: print(f"Error: Key 'composed_adapter' not found in weights file {IP_ADAPTER_FILENAME}.") exit() except Exception as e: print(f"An unknown error occurred while loading ComposedAttention weights: {e}") exit() print("Initializing IP-Adapter...") ip_model = IPAdapterXL( pipe, IMAGE_ENCODER_PATH, fine_tuned_ckpt_path, DEVICE, target_blocks=["down_blocks.1.attentions.1"], num_tokens=4, inference=True, number_class_crossattention=number_class_crossattention ) print("Models loaded. Gradio application is ready!") def generate_image(uploaded_image: Image.Image, local_path: str, save_path: str, prompt: str, extra_text: str, negative_prompt: str, guidance_scale: float, num_inference_steps: int, seed: int, progress=gr.Progress()): pil_image = None if uploaded_image is not None: pil_image = uploaded_image elif local_path and local_path.strip(): try: pil_image = Image.open(local_path.strip()) except FileNotFoundError: raise gr.Error(f"File not found. Please check the path: {local_path.strip()}") except Exception as e: raise gr.Error(f"Cannot open image file. Error: {e}") else: raise gr.Error("Please upload a reference image or provide a valid local file path!") input_image = pil_image.resize((512, 512)) progress(0, desc="Generating image...") images = ip_model.generate( pil_image=input_image, prompt=prompt, negative_prompt=negative_prompt, scale=1.0, guidance_scale=guidance_scale, num_samples=1, num_inference_steps=int(num_inference_steps), seed=int(seed), extra_text=extra_text, number_class_crossattention=number_class_crossattention ) generated_image = images[0] progress(1, desc="Generation complete!") if save_path and save_path.strip(): try: save_dir = save_path.strip() os.makedirs(save_dir, exist_ok=True) timestamp = time.strftime("%Y%m%d-%H%M%S") filename = f"output_{timestamp}_seed{seed}.png" full_path = os.path.join(save_dir, filename) generated_image.save(full_path) gr.Info(f"Image successfully saved to: {full_path}") except Exception as e: gr.Warning(f"Could not save the image! Error: {e}") print(f"Error saving image: {e}") return generated_image with gr.Blocks() as demo: gr.Markdown("# IMAGHarmony: Image Generation Demo") gr.Markdown( "**Upload a reference image from your computer, or enter the full local path in the text box below.**\n" "Then, enter a **Target Prompt** and a **Reference Content** description to generate a new image." ) with gr.Row(): with gr.Column(scale=1): input_image = gr.Image(type="pil", label="Upload Your Reference Image") local_path_input = gr.Textbox( label="Or Enter Local Image Path", placeholder="/home/user/images/photo.jpg", info="If an image is uploaded, it will be prioritized over the path." 
) prompt = gr.Textbox(label="Target Prompt", value="four cats") extra_text = gr.Textbox( label="Reference Content", info="Enter text that describes the reference image, typically the caption used during training.", value="four dogs" ) neg_prompt = gr.Textbox( label="Negative Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry" ) save_path_input = gr.Textbox( label="Save to Local Directory (Optional)", placeholder="/your/path", info="If left empty, the image will not be saved." ) run_button = gr.Button("Generate Image", variant="primary") with gr.Column(scale=1): output_image = gr.Image(type="pil", label="Generated Image") with gr.Accordion("Advanced Settings", open=False): guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=10.0, label="Guidance Scale") num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps") seed = gr.Slider(minimum=0, maximum=999999, step=1, value=8, label="Seed", randomize=True) run_button.click( fn=generate_image, inputs=[input_image, local_path_input, save_path_input, prompt, extra_text, neg_prompt, guidance_scale, num_inference_steps, seed], outputs=output_image ) demo.launch(share=True) ================================================ FILE: ip_adapter/__init__.py ================================================ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull __all__ = [ "IPAdapter", "IPAdapterPlus", "IPAdapterPlusXL", "IPAdapterXL", "IPAdapterFull", ] ================================================ FILE: ip_adapter/attention_processor.py ================================================ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py import torch import torch.nn as nn import torch.nn.functional as F import math import torch import torch.nn as nn import torch.nn.functional as F import math class Cross_Attention(nn.Module): def __init__(self, query_dim, # Input dimension for Q projection context_dim, # Input dimension for K/V projection heads=8, value_dim=None, # Dimension of V after projection (defaults to head_dim) out_dim=None): # Output dimension super().__init__() self.query_dim = query_dim self.heads = heads self.head_dim = self.query_dim // self.heads self.scale = math.sqrt(self.head_dim) self.value_dim = value_dim if value_dim is not None else self.head_dim self.out_dim = out_dim if out_dim is not None else heads * self.value_dim # Linear projection layers self.to_q = nn.Linear(query_dim, self.heads * self.head_dim) self.to_k = nn.Linear(context_dim, self.heads * self.head_dim) self.to_v = nn.Linear(context_dim, self.heads * self.value_dim) # Optional output projection self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim) def forward(self, query_input, context_input): B = query_input.size(0) # Project Q, K, V q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, Q_len, head_dim] k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, K_len, head_dim] v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2) # [B, heads, K_len, v_dim] # Attention scores (weights) attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [B, heads, Q_len, K_len] attn_probs = F.softmax(attn_scores, dim=-1) # Weighted sum attn_output = torch.matmul(attn_probs, v) # [B, heads, Q_len, v_dim] # Concatenate heads 
attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim) # [B, Q_len, heads * v_dim] # Output projection output = self.out_proj(attn_output) # [B, Q_len, out_dim] return output class AttnProcessor(nn.Module): r""" Default processor for performing attention-related computations. """ def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAttnProcessor(nn.Module): r""" Attention processor for IP-Adapater. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. scale (`float`, defaults to 1.0): the weight scale of image prompt. num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): The context length of the image features. 
""" def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens self.skip = skip self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :], ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) if not self.skip: # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) ip_key = attn.head_to_batch_dim(ip_key) ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) self.attn_map = ip_attention_probs ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
""" def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class IPAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. scale (`float`, defaults to 1.0): the weight scale of image prompt. num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): The context length of the image features. 
""" def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens self.skip = skip self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: # get encoder_hidden_states, ip_hidden_states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :], ) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) if not self.skip: # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) with torch.no_grad(): self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) #print(self.attn_map.shape) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) hidden_states = hidden_states 
+ self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states ## for controlnet class CNAttnProcessor: r""" Default processor for performing attention-related computations. """ def __init__(self, num_tokens=4): self.num_tokens = num_tokens def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class CNAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
""" def __init__(self, num_tokens=4): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.num_tokens = num_tokens def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states else: end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states ================================================ FILE: ip_adapter/custom_pipelines.py ================================================ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from diffusers import StableDiffusionXLPipeline from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from .utils import is_torch2_available if is_torch2_available(): from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor else: from .attention_processor import IPAttnProcessor class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): def set_scale(self, scale): for attn_processor in self.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale 
@torch.no_grad() def __call__( # noqa: C901 self, prompt: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. 
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of that paper. The guidance rescale factor should fix overexposure when using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled. `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified, it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a target image resolution. It should be the same as `target_size` for most cases.
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. control_guidance_start (`float`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. Examples: Returns: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 7.1 Apply denoising_end if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # get init conditioning scale for attn_processor in self.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): conditioning_scale = attn_processor.scale break with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end): self.set_scale(0.0) else: self.set_scale(conditioning_scale) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if output_type != "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: ip_adapter/ip_adapter.py ================================================ import os from typing import List import torch from diffusers import StableDiffusionPipeline from diffusers.pipelines.controlnet import MultiControlNetModel from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from tutorial_train_sdxl_ori import HarmonyAttention from .utils import is_torch2_available, get_generator if is_torch2_available(): from .attention_processor import ( AttnProcessor2_0 as AttnProcessor, ) from .attention_processor import ( CNAttnProcessor2_0 as CNAttnProcessor, ) from .attention_processor import ( IPAttnProcessor2_0 as IPAttnProcessor, ) else: from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor from .resampler import Resampler class ImageProjModel(torch.nn.Module): def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() self.generator = None self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds): embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape( -1, self.clip_extra_context_tokens, self.cross_attention_dim ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class MLPProjModel(torch.nn.Module): def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): super().__init__() self.proj = torch.nn.Sequential( torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), torch.nn.GELU(), torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), torch.nn.LayerNorm(cross_attention_dim) ) def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class IPAdapter: def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=None, 
number_class_crossattention=None): self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt self.num_tokens = num_tokens # self.target_blocks = target_blocks or ["down_blocks.2.attentions.1"] self.pipe = sd_pipe.to(self.device) self.set_ip_adapter() # load image encoder self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( self.device, dtype=torch.float16 ) self.clip_image_processor = CLIPImageProcessor() self.number_class_crossattention=number_class_crossattention.to(self.device, dtype=torch.float16) # image proj model self.image_proj_model = self.init_proj() self.load_ip_adapter() def init_proj(self): image_proj_model = ImageProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, clip_embeddings_dim=self.image_encoder.config.projection_dim, clip_extra_context_tokens=self.num_tokens, ).to(self.device, dtype=torch.float16) return image_proj_model def set_ip_adapter(self): unet = self.pipe.unet attn_procs = {} for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = AttnProcessor() else: if 'down_blocks.2.attentions.1' in name: attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens, skip=False).to(self.device, dtype=torch.float16) else: attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens, skip=True).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) if hasattr(self.pipe, "controlnet"): if isinstance(self.pipe.controlnet, MultiControlNetModel): for controlnet in self.pipe.controlnet.nets: controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) else: self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) def load_ip_adapter(self): if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": state_dict = {"image_proj_model": {}, "ip_adapter": {}, 'composed_modules': {} } with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: for key in f.keys(): # if 'unet.up_blocks' not in key and 'unet.mid_blocks' not in key: # print(key) if key.startswith("image_proj."): state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) elif key.startswith("adapter_modules."): state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) elif key.startswith("composed_modules."): state_dict["composed_modules"][key.replace("composed_modules.", "")] = f.get_tensor(key) else: state_dict = torch.load(self.ip_ckpt, map_location="cpu") print(state_dict.keys()) self.image_proj_model.load_state_dict(state_dict["image_proj"]) self.number_class_crossattention.load_state_dict(state_dict["composed_adapter"]) ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) ip_layers.load_state_dict(state_dict["ip_adapter"]) @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None, extra_prompt_embeds=None): if pil_image is not None: if isinstance(pil_image, 
Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds else: clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) # add composer if extra_prompt_embeds is not None: extra_prompt_embeds = extra_prompt_embeds.to(self.device,torch.float16) output= self.number_class_crossattention(extra_prompt_embeds, clip_image_embeds) clip_image_embeds = clip_image_embeds + output image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale def generate( self, pil_image=None, clip_image_embeds=None, prompt=None, negative_prompt=None, scale=1.0, num_samples=4, seed=None, guidance_scale=7.5, num_inference_steps=30, **kwargs, ): self.set_scale(scale) if pil_image is not None: num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) else: num_prompts = clip_image_embeds.size(0) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( pil_image=pil_image, clip_image_embeds=clip_image_embeds ) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, **kwargs, ).images return images class IPAdapterXL(IPAdapter): """SDXL""" def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=None, inference=False, number_class_crossattention=None): self.inference = inference super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=num_tokens, target_blocks=target_blocks, number_class_crossattention=number_class_crossattention) def generate( self, pil_image, prompt=None, negative_prompt=None, extra_text=None, scale=1.0, num_samples=4, seed=None, num_inference_steps=30, **kwargs, ): self.set_scale(scale) num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, 
lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts if extra_text is not None: with torch.inference_mode(): ( extra_prompt_embeds, extra_negative_prompt_embeds, extra_pooled_prompt_embeds, extra_negative_pooled_prompt_embeds, ) = self.pipe.encode_prompt( extra_text, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, extra_prompt_embeds = extra_prompt_embeds) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.pipe.encode_prompt( prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) self.generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, generator=self.generator, **kwargs, ).images return images class IPAdapterPlus(IPAdapter): def init_proj(self): image_proj_model = Resampler( dim=self.pipe.unet.config.cross_attention_dim, depth=4, dim_head=64, heads=12, num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, ff_mult=4, ).to(self.device, dtype=torch.float16) return image_proj_model @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = clip_image.to(self.device, dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds class IPAdapterFull(IPAdapterPlus): """IP-Adapter with full features""" def init_proj(self): image_proj_model = MLPProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, clip_embeddings_dim=self.image_encoder.config.hidden_size, ).to(self.device, dtype=torch.float16) return image_proj_model class IPAdapterPlusXL(IPAdapter): """SDXL""" def init_proj(self): image_proj_model = Resampler( dim=1280, depth=4, dim_head=64, heads=20, num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, ff_mult=4, ).to(self.device, dtype=torch.float16) 
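# Descriptive note: the Resampler consumes the CLIP vision encoder's penultimate hidden states,
# typically (batch, 257, hidden_size) for a 224x224 ViT (256 patch tokens + 1 CLS token), and returns
# num_queries (= self.num_tokens) learned tokens of width output_dim, i.e. the SDXL UNet's
# cross_attention_dim, giving (batch, num_tokens, cross_attention_dim).
# See ip_adapter/test_resampler.py for a standalone shape check.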
return image_proj_model @torch.inference_mode() def get_image_embeds(self, pil_image): if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = clip_image.to(self.device, dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds def generate( self, pil_image, prompt=None, negative_prompt=None, scale=1.0, num_samples=4, seed=None, num_inference_steps=30, **kwargs, ): self.set_scale(scale) num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.pipe.encode_prompt( prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, generator=generator, **kwargs, ).images return images ================================================ FILE: ip_adapter/ip_adapter_origin.py ================================================ import os from typing import List import torch from diffusers import StableDiffusionPipeline from diffusers.pipelines.controlnet import MultiControlNetModel from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from .utils import is_torch2_available, get_generator if is_torch2_available(): from .attention_processor import ( AttnProcessor2_0 as AttnProcessor, ) from .attention_processor import ( CNAttnProcessor2_0 as CNAttnProcessor, ) from .attention_processor import ( IPAttnProcessor2_0 as IPAttnProcessor, ) else: from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor from .resampler import Resampler class ImageProjModel(torch.nn.Module): """投影模型 - 将CLIP图像特征转换为适合UNet交叉注意力的格式""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 
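# English gloss of the class docstring above: "Projection model - converts CLIP image features into a
# format suitable for UNet cross-attention." Concretely, it maps (batch, clip_embeddings_dim) to
# (batch, clip_extra_context_tokens, cross_attention_dim).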
super().__init__() self.generator = None self.cross_attention_dim = cross_attention_dim # UNet交叉注意力的维度 self.clip_extra_context_tokens = clip_extra_context_tokens # 额外上下文token数量 # 线性投影层,将CLIP嵌入转换为多个额外的上下文token self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) # 标准化层 def forward(self, image_embeds): # 投影CLIP图像嵌入到多个上下文token embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape( -1, self.clip_extra_context_tokens, self.cross_attention_dim ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class MLPProjModel(torch.nn.Module): """使用多层感知器的投影模型 - 用于IPAdapterFull变体""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): super().__init__() # 多层感知器,包含两个线性层、GELU激活和层标准化 self.proj = torch.nn.Sequential( torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), torch.nn.GELU(), torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), torch.nn.LayerNorm(cross_attention_dim) ) def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class IPAdapter: def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt self.num_tokens = num_tokens self.pipe = sd_pipe.to(self.device) self.set_ip_adapter() # load image encoder self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( self.device, dtype=torch.float16 ) self.clip_image_processor = CLIPImageProcessor() # image proj model self.image_proj_model = self.init_proj() self.load_ip_adapter() def init_proj(self): image_proj_model = ImageProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, clip_embeddings_dim=self.image_encoder.config.projection_dim, clip_extra_context_tokens=self.num_tokens, ).to(self.device, dtype=torch.float16) return image_proj_model def set_ip_adapter(self): unet = self.pipe.unet attn_procs = {} for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = AttnProcessor() else: attn_procs[name] = IPAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens, ).to(self.device, dtype=torch.float16) unet.set_attn_processor(attn_procs) if hasattr(self.pipe, "controlnet"): if isinstance(self.pipe.controlnet, MultiControlNetModel): for controlnet in self.pipe.controlnet.nets: controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) else: self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) def load_ip_adapter(self): if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": state_dict = {"image_proj": {}, "ip_adapter": {}} with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: for key in f.keys(): if key.startswith("image_proj."): state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) elif 
key.startswith("ip_adapter."): state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) else: state_dict = torch.load(self.ip_ckpt, map_location="cpu") self.image_proj_model.load_state_dict(state_dict["image_proj"]) ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) ip_layers.load_state_dict(state_dict["ip_adapter"]) @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): if pil_image is not None: if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds else: clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale def generate( self, pil_image=None, clip_image_embeds=None, prompt=None, negative_prompt=None, scale=1.0, num_samples=4, seed=None, guidance_scale=7.5, num_inference_steps=30, **kwargs, ): self.set_scale(scale) if pil_image is not None: num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) else: num_prompts = clip_image_embeds.size(0) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( pil_image=pil_image, clip_image_embeds=clip_image_embeds ) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, **kwargs, ).images return images class IPAdapterXL(IPAdapter): """SDXL""" def generate( self, pil_image, prompt=None, negative_prompt=None, scale=1.0, num_samples=4, seed=None, num_inference_steps=30, **kwargs, ): """SDXL专用的生成方法,考虑了SDXL的pooled嵌入""" self.set_scale(scale) num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): 
prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.pipe.encode_prompt( prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) self.generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, generator=self.generator, **kwargs, ).images return images class IPAdapterPlus(IPAdapter): """使用细粒度特征的IP-Adapter增强版本""" def init_proj(self): """使用Resampler替代简单的线性投影""" # Resampler是一种更强大的特征重采样模型,能更好地捕获图像中的细节 image_proj_model = Resampler( dim=self.pipe.unet.config.cross_attention_dim, depth=4, dim_head=64, heads=12, num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, ff_mult=4, ).to(self.device, dtype=torch.float16) return image_proj_model @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): """使用CLIP模型的倒数第二层隐藏状态作为更丰富的特征""" # 使用CLIP的内部特征而非最终投影,提供更细粒度的图像理解 if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = clip_image.to(self.device, dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds class IPAdapterFull(IPAdapterPlus): """IP-Adapter with full features""" def init_proj(self): image_proj_model = MLPProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, clip_embeddings_dim=self.image_encoder.config.hidden_size, ).to(self.device, dtype=torch.float16) return image_proj_model class IPAdapterPlusXL(IPAdapter): """SDXL""" def init_proj(self): image_proj_model = Resampler( dim=1280, depth=4, dim_head=64, heads=20, num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, ff_mult=4, ).to(self.device, dtype=torch.float16) return image_proj_model @torch.inference_mode() def get_image_embeds(self, pil_image): if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = clip_image.to(self.device, 
dtype=torch.float16) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds def generate( self, pil_image, prompt=None, negative_prompt=None, scale=1.0, num_samples=4, seed=None, num_inference_steps=30, **kwargs, ): self.set_scale(scale) num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) if prompt is None: prompt = "best quality, high quality" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.pipe.encode_prompt( prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, generator=generator, **kwargs, ).images return images ================================================ FILE: ip_adapter/resampler.py ================================================ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py import math import torch import torch.nn as nn from einops import rearrange from einops.layers.torch import Rearrange # FFN def FeedForward(dim, mult=4): inner_dim = int(dim * mult) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias=False), nn.GELU(), nn.Linear(inner_dim, dim, bias=False), ) def reshape_tensor(x, heads): bs, length, width = x.shape # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, length, heads, -1) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) x = x.reshape(bs, heads, length, -1) return x class PerceiverAttention(nn.Module): def __init__(self, *, dim, dim_head=64, heads=8): super().__init__() self.scale = dim_head**-0.5 self.dim_head = dim_head self.heads = heads inner_dim = dim_head * heads self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, 
bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) def forward(self, x, latents): """ Args: x (torch.Tensor): image features shape (b, n1, D) latent (torch.Tensor): latent features shape (b, n2, D) """ x = self.norm1(x) latents = self.norm2(latents) b, l, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) k, v = self.to_kv(kv_input).chunk(2, dim=-1) q = reshape_tensor(q, self.heads) k = reshape_tensor(k, self.heads) v = reshape_tensor(v, self.heads) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v out = out.permute(0, 2, 1, 3).reshape(b, l, -1) return self.to_out(out) class Resampler(nn.Module): def __init__( self, dim=1024, depth=8, dim_head=64, heads=16, num_queries=8, embedding_dim=768, output_dim=1024, ff_mult=4, max_seq_len: int = 257, # CLIP tokens + CLS token apply_pos_emb: bool = False, num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence ): super().__init__() self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) self.proj_in = nn.Linear(embedding_dim, dim) self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) self.to_latents_from_mean_pooled_seq = ( nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, dim * num_latents_mean_pooled), Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), ) if num_latents_mean_pooled > 0 else None ) self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( nn.ModuleList( [ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), FeedForward(dim=dim, mult=ff_mult), ] ) ) def forward(self, x): if self.pos_emb is not None: n, device = x.shape[1], x.device pos_emb = self.pos_emb(torch.arange(n, device=device)) x = x + pos_emb latents = self.latents.repeat(x.size(0), 1, 1) x = self.proj_in(x) if self.to_latents_from_mean_pooled_seq: meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) latents = torch.cat((meanpooled_latents, latents), dim=-2) for attn, ff in self.layers: latents = attn(x, latents) + latents latents = ff(latents) + latents latents = self.proj_out(latents) return self.norm_out(latents) def masked_mean(t, *, dim, mask=None): if mask is None: return t.mean(dim=dim) denom = mask.sum(dim=dim, keepdim=True) mask = rearrange(mask, "b n -> b n 1") masked_t = t.masked_fill(~mask, 0.0) return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) ================================================ FILE: ip_adapter/shared_models.py ================================================ import torch from torch import nn import math import os import random import argparse from pathlib import Path import json import itertools import time import torch.nn as nn from diffusers.models.attention_processor import Attention import torch import torch.nn.functional as F import numpy as np class Cross_Attention(nn.Module): def __init__(self, query_dim, # Q projection input dimension context_dim, # K/V projection input dimension heads=8, head_dim=64, value_dim=None, # V dimension after dimensionality reduction (defaults to head_dim) out_dim=None): # Output 
dimension super().__init__() self.heads = heads self.head_dim = head_dim self.scale = math.sqrt(head_dim) self.value_dim = value_dim if value_dim is not None else head_dim self.out_dim = out_dim if out_dim is not None else heads * self.value_dim # Linear projection layers self.to_q = nn.Linear(query_dim, heads * head_dim) self.to_k = nn.Linear(context_dim, heads * head_dim) self.to_v = nn.Linear(context_dim, heads * self.value_dim) # Optional output projection self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim) def forward(self, query_input, context_input): """ query_input: [B, Q_len, query_dim] context_input: [B, K_len, context_dim] """ B = query_input.size(0) # Project Q, K, V q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, Q_len, head_dim] k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, K_len, head_dim] v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2) # [B, heads, K_len, v_dim] # Attention weights attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [B, heads, Q_len, K_len] attn_probs = F.softmax(attn_scores, dim=-1) # Weighted sum attn_output = torch.matmul(attn_probs, v) # [B, heads, Q_len, v_dim] # Concatenate heads attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim) # [B, Q_len, heads * v_dim] # Output projection output = self.out_proj(attn_output) # [B, Q_len, out_dim] return output class ImageProjModel(torch.nn.Module): """Projection model - converts CLIP image features into a format suitable for UNet cross-attention""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() self.generator = None self.cross_attention_dim = cross_attention_dim # UNet cross-attention dimension self.clip_extra_context_tokens = clip_extra_context_tokens # Number of extra context tokens # Linear projection layer, converts CLIP embeddings into multiple extra context tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) # Normalization layer def forward(self, image_embeds): # Project CLIP image embeddings into multiple context tokens embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape( -1, self.clip_extra_context_tokens, self.cross_attention_dim ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class Composed_Attention(torch.nn.Module):#Number_Class_crossAttention def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0): super().__init__() # Cross Attention layer self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32) self.scale=scale # self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64) # Image projection from 1280 to 2560 self.fc1=nn.Linear(hidden_size, hidden_size*2) # Layer Normalization layer self.ln = nn.LayerNorm(hidden_size) # FC layer 1280 to 1280 self.fc2 = nn.Linear(hidden_size, hidden_size) def forward(self, text_embeds,image_embeds): image_embeds=self.fc1(image_embeds) #[1, 2560] image_embeds=image_embeds.reshape(1,4,640) output = self.cross_attention(image_embeds, text_embeds) #[1,4,320] output=output.reshape(1,1280) # Normalize the output using Layer Normalization output=self.ln(output) # FC layer [1,4,2048]->[1,4,2048] output=self.fc2(output) return 
output*self.scale def load_from_checkpoint(self, ckpt_path: str): from safetensors.torch import load_file from collections import OrderedDict # Load weights file weights = load_file(ckpt_path) # Initialize two dictionaries to store weights for different modules separately image_proj_weights = OrderedDict() attn_weights = OrderedDict() # Separate weights into different modules for k, v in weights.items(): # Process image_proj_model weights (match two possible key name prefixes) if k.startswith("image_proj_model.") or k.startswith("image_proj."): new_key = k.replace("image_proj_model.", "").replace("image_proj.", "") if hasattr(self, "image_proj_model") and hasattr(self.image_proj_model, new_key.split('.')[0]): image_proj_weights[new_key] = v # Process target attention layer weights (match two possible key name formats) elif "down_blocks.2.attentions.1" in k: # Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1 new_key = k.replace("composed_modules.", "").replace("ip_adapter.", "") if hasattr(self, new_key.split('.')[0]): attn_weights[new_key] = v # Load image_proj_model weights (strict mode) if image_proj_weights: self.image_proj_model.load_state_dict(image_proj_weights, strict=True) print(f"Loaded image_proj_model weights: {len(image_proj_weights)} params") # Load attention layer weights (non-strict mode) if attn_weights: # Create temporary ModuleDict to load weights temp_dict = {k: v for k, v in self.named_modules() if "down_blocks.2.attentions.1" in k} temp_model = torch.nn.ModuleDict(temp_dict) missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False) if missing: print(f"Missing keys in attention blocks: {missing}") if unexpected: print(f"Unexpected keys in attention blocks: {unexpected}") print(f"Loaded attention weights: {len(attn_weights)} params") print(f"Successfully loaded target modules from {ckpt_path}") return self ================================================ FILE: ip_adapter/test_resampler.py ================================================ import torch from resampler import Resampler from transformers import CLIPVisionModel BATCH_SIZE = 2 OUTPUT_DIM = 1280 NUM_QUERIES = 8 NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior) APPLY_POS_EMB = True # False for no positional embeddings (previous behavior) IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" def main(): image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH) embedding_dim = image_encoder.config.hidden_size print(f"image_encoder hidden size: ", embedding_dim) image_proj_model = Resampler( dim=1024, depth=2, dim_head=64, heads=16, num_queries=NUM_QUERIES, embedding_dim=embedding_dim, output_dim=OUTPUT_DIM, ff_mult=2, max_seq_len=257, apply_pos_emb=APPLY_POS_EMB, num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED, ) dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224) with torch.no_grad(): image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2] print("image_embds shape: ", image_embeds.shape) with torch.no_grad(): ip_tokens = image_proj_model(image_embeds) print("ip_tokens shape:", ip_tokens.shape) assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM) if __name__ == "__main__": main() ================================================ FILE: ip_adapter/utils.py ================================================ import torch import torch.nn.functional as F import numpy as np from PIL import Image attn_maps = {} def hook_fn(name): def 
forward_hook(module, input, output): if hasattr(module.processor, "attn_map"): attn_maps[name] = module.processor.attn_map del module.processor.attn_map return forward_hook def register_cross_attention_hook(unet): for name, module in unet.named_modules(): if name.split('.')[-1].startswith('attn2'): module.register_forward_hook(hook_fn(name)) return unet def upscale(attn_map, target_size): attn_map = torch.mean(attn_map, dim=0) attn_map = attn_map.permute(1,0) temp_size = None for i in range(0,5): scale = 2 ** i if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) break assert temp_size is not None, "temp_size cannot is None" attn_map = attn_map.view(attn_map.shape[0], *temp_size) attn_map = F.interpolate( attn_map.unsqueeze(0).to(dtype=torch.float32), size=target_size, mode='bilinear', align_corners=False )[0] attn_map = torch.softmax(attn_map, dim=0) return attn_map def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): idx = 0 if instance_or_negative else 1 net_attn_maps = [] for name, attn_map in attn_maps.items(): attn_map = attn_map.cpu() if detach else attn_map attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() attn_map = upscale(attn_map, image_size) net_attn_maps.append(attn_map) net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) return net_attn_maps def attnmaps2images(net_attn_maps): #total_attn_scores = 0 images = [] for attn_map in net_attn_maps: attn_map = attn_map.cpu().numpy() #total_attn_scores += attn_map.mean().item() normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 normalized_attn_map = normalized_attn_map.astype(np.uint8) #print("norm: ", normalized_attn_map.shape) image = Image.fromarray(normalized_attn_map) #image = fix_save_attn_map(attn_map) images.append(image) #print(total_attn_scores) return images def is_torch2_available(): return hasattr(F, "scaled_dot_product_attention") def get_generator(seed, device): if seed is not None: if isinstance(seed, list): generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] else: generator = torch.Generator(device).manual_seed(seed) else: generator = None return generator ================================================ FILE: requirements.txt ================================================ absl-py==2.1.0 accelerate==1.0.1 aiofiles==23.2.1 aiohappyeyeballs==2.4.4 aiohttp==3.10.11 aiosignal==1.3.1 annotated-types==0.7.0 anyio==3.7.1 async-timeout==5.0.1 asyncer==0.0.2 attrs==25.3.0 bidict==0.23.1 bitsandbytes==0.45.5 cachetools==5.5.0 certifi==2024.8.30 chainlit==1.1.402 charset-normalizer==3.4.0 chevron==0.14.0 click==8.1.8 contourpy==1.1.1 cycler==0.12.1 dataclasses-json==0.5.14 datasets==3.1.0 Deprecated==1.2.18 diffusers==0.30.0 dill==0.3.8 distro==1.9.0 einops==0.8.0 exceptiongroup==1.2.2 fastapi==0.110.3 ffmpy==0.5.0 filelock==3.16.1 filetype==1.2.0 fonttools==4.57.0 frozenlist==1.5.0 fsspec==2024.9.0 ftfy==6.2.3 google-auth==2.36.0 google-auth-oauthlib==1.0.0 googleapis-common-protos==1.69.2 gradio==4.44.1 gradio_client==1.3.0 grpcio==1.67.1 h11==0.14.0 httpcore==1.0.7 httpx==0.28.1 huggingface-hub==0.25.2 idna==3.10 importlib_metadata==8.5.0 importlib_resources==6.4.5 Jinja2==3.1.4 jiter==0.9.0 kiwisolver==1.4.7 Lazify==0.4.0 literalai==0.0.607 loguru==0.7.3 Markdown==3.7 markdown-it-py==3.0.0 MarkupSafe==2.1.5 marshmallow==3.22.0 matplotlib==3.7.5 mdurl==0.1.2 mpmath==1.3.0 multidict==6.1.0 
multiprocess==0.70.16 mypy-extensions==1.0.0 nest-asyncio==1.6.0 networkx==3.1 numpy==1.24.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.6.77 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 openai==1.71.0 opentelemetry-api==1.31.1 opentelemetry-exporter-otlp==1.31.1 opentelemetry-exporter-otlp-proto-common==1.31.1 opentelemetry-exporter-otlp-proto-grpc==1.31.1 opentelemetry-exporter-otlp-proto-http==1.31.1 opentelemetry-instrumentation==0.52b1 opentelemetry-proto==1.31.1 opentelemetry-sdk==1.31.1 opentelemetry-semantic-conventions==0.52b1 orjson==3.10.15 packaging==23.2 pandas==2.0.3 peft==0.13.2 pillow==10.4.0 propcache==0.2.0 protobuf==5.28.3 psutil==6.1.0 pyarrow==17.0.0 pyasn1==0.6.1 pyasn1_modules==0.4.1 pydantic==2.10.6 pydantic_core==2.27.2 pydub==0.25.1 Pygments==2.19.1 PyJWT==2.9.0 pyparsing==3.1.4 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 python-engineio==4.11.2 python-multipart==0.0.9 python-socketio==5.12.1 pytz==2024.2 PyYAML==6.0.2 regex==2024.9.11 requests==2.32.3 requests-oauthlib==2.0.0 rich==14.0.0 rsa==4.9 ruff==0.11.8 safetensors==0.4.5 scipy==1.10.1 semantic-version==2.10.0 shellingham==1.5.4 simple-websocket==1.1.0 six==1.17.0 sniffio==1.3.1 sse-starlette==2.1.3 starlette==0.37.2 sympy==1.13.3 syncer==2.0.3 tensorboard==2.14.0 tensorboard-data-server==0.7.2 tensorboardX==2.6.2.2 timm==1.0.13 tokenizers==0.20.3 tomli==2.2.1 tomlkit==0.12.0 torch==2.4.1 torchaudio==2.4.1 torchvision==0.19.1 tqdm==4.66.5 transformers==4.45.0 triton==3.0.0 typer==0.15.3 typing-inspect==0.9.0 typing_extensions==4.12.2 tzdata==2024.2 uptrace==1.31.0 urllib3==2.2.3 uvicorn==0.25.0 watchfiles==0.20.0 wcwidth==0.2.13 websockets==12.0 Werkzeug==3.0.6 wrapt==1.17.2 wsproto==1.2.0 xformers==0.0.28.post1 xxhash==3.5.0 yarl==1.15.2 zipp==3.20.2 ================================================ FILE: run.sh ================================================ accelerate launch --gpu_ids 0 --num_processes 1 --mixed_precision "fp16" \ train.py \ --pretrained_model_name_or_path="your path" \ --pretrained_ip_adapter_path="your path" \ --image_encoder_path="your path" \ --data_root_path='your path' \ --mixed_precision="fp16" \ --resolution=512 \ --train_batch_size=1 \ --dataloader_num_workers=4 \ --learning_rate=2.5e-04 \ --data_json_file="your path" \ --weight_decay=0.01 \ --output_dir="your path" \ --save_steps=100 \ --num_train_epochs 2100 \ --composed_inter_dim=2560 \ --composed_cross_heads=8 \ --composed_reshape_blocks=8 \ --composed_cross_value_dim=64 ================================================ FILE: sdxl-fine-tuning/data/train.json ================================================ [ { "image_file": "your image", "text": " ", "extra_text": "your caption", "comments": { "image_file": "five_cats.png", "text": "", "extra_text": "Five cats" } } ] ================================================ FILE: shared_models.py ================================================ import torch from torch import nn import math import os import random import argparse from pathlib import Path import json import itertools import time import torch.nn as nn from diffusers.models.attention_processor import Attention import torch import torch.nn.functional as F import numpy as np class Cross_Attention(nn.Module): def __init__(self, 
query_dim, # Q projection input dimension context_dim, # K/V projection input dimension heads=8, head_dim=64, value_dim=None, # V dimension after dimensionality reduction (defaults to head_dim) out_dim=None): # Output dimension super().__init__() self.heads = heads self.head_dim = head_dim self.scale = math.sqrt(head_dim) self.value_dim = value_dim if value_dim is not None else head_dim self.out_dim = out_dim if out_dim is not None else heads * self.value_dim # Linear projection layers self.to_q = nn.Linear(query_dim, heads * head_dim) self.to_k = nn.Linear(context_dim, heads * head_dim) self.to_v = nn.Linear(context_dim, heads * self.value_dim) # Optional output projection self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim) def forward(self, query_input, context_input): """ query_input: [B, Q_len, query_dim] context_input: [B, K_len, context_dim] """ B = query_input.size(0) # Project Q, K, V q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, Q_len, head_dim] k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2) # [B, heads, K_len, head_dim] v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2) # [B, heads, K_len, v_dim] # Attention weights attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [B, heads, Q_len, K_len] attn_probs = F.softmax(attn_scores, dim=-1) # Weighted sum attn_output = torch.matmul(attn_probs, v) # [B, heads, Q_len, v_dim] # Concatenate heads attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim) # [B, Q_len, heads * v_dim] # Output projection output = self.out_proj(attn_output) # [B, Q_len, out_dim] return output class ImageProjModel(torch.nn.Module): """Projection model - converts CLIP image features into a format suitable for UNet cross-attention""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() self.generator = None self.cross_attention_dim = cross_attention_dim # UNet cross-attention dimension self.clip_extra_context_tokens = clip_extra_context_tokens # Number of extra context tokens # Linear projection layer, converts CLIP embeddings into multiple extra context tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) # Normalization layer def forward(self, image_embeds): # Project CLIP image embeddings into multiple context tokens embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape( -1, self.clip_extra_context_tokens, self.cross_attention_dim ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class Composed_Attention(torch.nn.Module):#Number_Class_crossAttention def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0): super().__init__() # Cross Attention layer self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32) self.scale=scale # self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64) # Image projection from 1280 to 2560 self.fc1=nn.Linear(hidden_size, hidden_size*2) # Layer Normalization layer self.ln = nn.LayerNorm(hidden_size) # FC layer 1280 to 1280 self.fc2 = nn.Linear(hidden_size, hidden_size) def forward(self, text_embeds,image_embeds): image_embeds=self.fc1(image_embeds) #[1, 2560] image_embeds=image_embeds.reshape(1,4,640) output = 
self.cross_attention(image_embeds, text_embeds) #[1,4,320] output=output.reshape(1,1280) # Normalize the output using Layer Normalization output=self.ln(output) # FC layer [1,4,2048]->[1,4,2048] output=self.fc2(output) return output*self.scale def load_from_checkpoint(self, ckpt_path: str): from safetensors.torch import load_file from collections import OrderedDict # Load weights file weights = load_file(ckpt_path) # Initialize two dictionaries to store weights for different modules separately image_proj_weights = OrderedDict() attn_weights = OrderedDict() # Separate weights into different modules for k, v in weights.items(): # Process image_proj_model weights (match two possible key name prefixes) if k.startswith("image_proj_model.") or k.startswith("image_proj."): new_key = k.replace("image_proj_model.", "").replace("image_proj.", "") if hasattr(self, "image_proj_model") and hasattr(self.image_proj_model, new_key.split('.')[0]): image_proj_weights[new_key] = v # Process target attention layer weights (match two possible key name formats) elif "down_blocks.2.attentions.1" in k: # Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1 new_key = k.replace("composed_modules.", "").replace("ip_adapter.", "") if hasattr(self, new_key.split('.')[0]): attn_weights[new_key] = v # Load image_proj_model weights (strict mode) if image_proj_weights: self.image_proj_model.load_state_dict(image_proj_weights, strict=True) print(f"Loaded image_proj_model weights: {len(image_proj_weights)} params") # Load attention layer weights (non-strict mode) if attn_weights: # Create temporary ModuleDict to load weights temp_dict = {k: v for k, v in self.named_modules() if "down_blocks.2.attentions.1" in k} temp_model = torch.nn.ModuleDict(temp_dict) missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False) if missing: print(f"Missing keys in attention blocks: {missing}") if unexpected: print(f"Unexpected keys in attention blocks: {unexpected}") print(f"Loaded attention weights: {len(attn_weights)} params") print(f"Successfully loaded target modules from {ckpt_path}") return self ================================================ FILE: test.py ================================================ import torch from diffusers import StableDiffusionXLPipeline from PIL import Image from ip_adapter import IPAdapterXL # Assuming IPAdapterXL is correctly defined in ip_adapter from train import HarmonyAttention # Assuming HarmonyAttention is correctly defined import os ''' The following parameters must be manually adjusted to match the training parameters. 
''' ckpt_inter_dim = 2560 ckpt_cross_heads = 8 ckpt_reshape_blocks = 8 ckpt_cross_value_dim = 64 # Image generation function def generate_image(input_path, prompt, extra_text, output_path="output.png"): print(f"\nGenerating image for: {prompt} + {extra_text}") # Prepare input image input_image = Image.open(input_path).resize((512, 512)) # Generate image images = ip_model.generate( pil_image=input_image, prompt=prompt, negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry", scale=1.0, guidance_scale=5.0, num_samples=1, num_inference_steps=30, seed=42, extra_text=extra_text, number_class_crossattention=number_class_crossattention # This should be the HarmonyAttention instance ) # Save result images[0].save(output_path) print(f"Saved generated image to: {output_path}") return images[0] if __name__ == "__main__": input_image = "your path to inputimage" # Replace with your actual input image path device = "cuda:2" # Model path configuration base_model_path = "your path" # Replace with your actual base model path image_encoder_path = "your path" # Replace with your actual image encoder path fine_tuned_ckpt = "fine_tuned model path" # Path to the fine-tuned weights (ip_adapter.bin or similar) save_root = os.path.join( 'your path', # Replace with your desired save directory fine_tuned_ckpt.split('/')[3] if len(fine_tuned_ckpt.split('/')) > 3 else "default_folder1", fine_tuned_ckpt.split('/')[4] if len(fine_tuned_ckpt.split('/')) > 4 else "default_folder2", ) # Create path (if it doesn't exist) os.makedirs(save_root, exist_ok=True) print(f"Save directory created at: {save_root}") # Load SDXL base model print("Loading base SDXL model...") pipe = StableDiffusionXLPipeline.from_pretrained( base_model_path, torch_dtype=torch.float16, add_watermarker=False, ) pipe.enable_vae_tiling() pipe.to(device) # Define the fusion method to be used fusion_method = "cross_attention" # Options: "cross_attention", "qformer", "mlp" # Initialize HarmonyAttention (this is the `number_class_crossattention` module) print("Initializing HarmonyAttention module...") number_class_crossattention = HarmonyAttention( # Renamed for clarity, this is your custom attention module image_hidden_size=1280, # Keep unchanged or match your training text_context_dim=2048, # Keep unchanged or match your training inter_dim=ckpt_inter_dim, cross_heads=ckpt_cross_heads, reshape_blocks=ckpt_reshape_blocks, cross_value_dim=ckpt_cross_value_dim, scale=1.0, # Keep unchanged or adjust as needed fusion_method=fusion_method # Add fusion method selection ).to(device).half() # Initialize IP-Adapter print("Initializing IP-Adapter with target blocks...") ip_model = IPAdapterXL( pipe, image_encoder_path, fine_tuned_ckpt, device, target_blocks=["down_blocks.2.attentions.1"], # Or your specific target blocks num_tokens=4, # Or your specific number of tokens inference=True, number_class_crossattention=number_class_crossattention # Pass the initialized HarmonyAttention module here ) print("HarmonyAttention weights are expected to be loaded as part of the IP-Adapter checkpoint.") generate_image( input_path=input_image, prompt="lions", extra_text="eight sheep", # Use the caption from training output_path=os.path.join(save_root, os.path.basename(input_image)) # Safer way to construct output path ) ================================================ FILE: train.py ================================================ import os import random import argparse from pathlib import Path import json import 
itertools import time import torch.nn as nn from diffusers.models.attention_processor import Attention import torch import torch.nn.functional as F import numpy as np from collections import OrderedDict from torchvision import transforms from PIL import Image from transformers import CLIPImageProcessor from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection from safetensors import safe_open from shared_models import ImageProjModel from baseline import QFormer,MLP,AttentionFusionWrapper from safetensors.torch import save_file, load_file from ip_adapter.utils import is_torch2_available if is_torch2_available(): from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor else: from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor from ip_adapter.attention_processor import Cross_Attention def count_model_params(model): return sum([p.numel() for p in model.parameters()]) / 1e6 # Dataset class MyDataset(torch.utils.data.Dataset): def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""): super().__init__() self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.size = size self.center_crop = center_crop self.i_drop_rate = i_drop_rate self.t_drop_rate = t_drop_rate self.ti_drop_rate = ti_drop_rate self.image_root_path = image_root_path self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] self.transform = transforms.Compose([ transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) self.clip_image_processor = CLIPImageProcessor() def __getitem__(self, idx): item = self.data[idx] text = item["text"] text_extra = item['extra_text'] image_file = item["image_file"] # read image raw_image = Image.open(os.path.join(self.image_root_path, image_file)) # original size original_width, original_height = raw_image.size original_size = torch.tensor([original_height, original_width]) image_tensor = self.transform(raw_image.convert("RGB")) # random crop delta_h = image_tensor.shape[1] - self.size delta_w = image_tensor.shape[2] - self.size assert not all([delta_h, delta_w]) if self.center_crop: top = delta_h // 2 left = delta_w // 2 else: top = np.random.randint(0, delta_h + 1) left = np.random.randint(0, delta_w + 1) image = transforms.functional.crop( image_tensor, top=top, left=left, height=self.size, width=self.size ) crop_coords_top_left = torch.tensor([top, left]) clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values # drop drop_image_embed = 0 rand_num = random.random() if rand_num < self.i_drop_rate: drop_image_embed = 1 elif rand_num < (self.i_drop_rate + self.t_drop_rate): text = "" elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): text = "" drop_image_embed = 1 # get text and tokenize text_input_ids = self.tokenizer( text, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids text_input_ids_2 = self.tokenizer_2( text, max_length=self.tokenizer_2.model_max_length, padding="max_length", truncation=True, 
return_tensors="pt" ).input_ids #extra text text_extra_input_ids=self.tokenizer( text_extra, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids text_extra_input_ids_2 = self.tokenizer_2( text_extra, max_length=self.tokenizer_2.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids return { "image": image, "text_input_ids": text_input_ids, "text_input_ids_2": text_input_ids_2, "text_extra_input_ids":text_extra_input_ids, "text_extra_input_ids_2":text_extra_input_ids_2, "clip_image": clip_image, "drop_image_embed": drop_image_embed, "original_size": original_size, "crop_coords_top_left": crop_coords_top_left, "target_size": torch.tensor([self.size, self.size]), } def __len__(self): return len(self.data) def collate_fn(data): images = torch.stack([example["image"] for example in data]) text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0) clip_images = torch.cat([example["clip_image"] for example in data], dim=0) drop_image_embeds = [example["drop_image_embed"] for example in data] original_size = torch.stack([example["original_size"] for example in data]) crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data]) target_size = torch.stack([example["target_size"] for example in data]) #extra text text_extra_input_ids=torch.stack([example["text_extra_input_ids"] for example in data]) text_extra_input_ids_2=torch.stack([example["text_extra_input_ids_2"] for example in data]) return { "images": images, "text_input_ids": text_input_ids, "text_input_ids_2": text_input_ids_2, "clip_images": clip_images, "drop_image_embeds": drop_image_embeds, "original_size": original_size, "crop_coords_top_left": crop_coords_top_left, "target_size": target_size, "text_extra_input_ids":text_extra_input_ids, "text_extra_input_ids_2":text_extra_input_ids_2, } class HarmonyAttention(nn.Module): def __init__(self, image_hidden_size=1280, # Input image feature dimension text_context_dim=2048, # Input text context feature dimension inter_dim=2560, # Intermediate projection dimension cross_heads=10, # Number of cross-attention heads reshape_blocks=8, # Number of image feature blocks cross_value_dim=64, # Value dimension per head after dimensionality reduction scale=1.0, # Output scaling factor fusion_method="qformer"): # Fusion method selection: mlp, cross_attention, qformer super().__init__() self.scale = scale self.reshape_blocks = reshape_blocks self.cross_query_dim = inter_dim // reshape_blocks self.fusion_method = fusion_method self.image_hidden_size = image_hidden_size self.text_context_dim = text_context_dim # Image projection required by all methods self.fc1 = nn.Linear(image_hidden_size, inter_dim) print(fusion_method) if fusion_method == "cross_attention": # 1. Cross-attention self.fusion_text_image = Cross_Attention( query_dim=self.cross_query_dim, context_dim=text_context_dim, heads=cross_heads, value_dim=cross_value_dim ) elif fusion_method == "qformer": # 2. Q-Former TODO self.fusion_text_image = QFormer( num_queries=16, hidden_dim=self.cross_query_dim, num_layers=1, num_heads=cross_heads ) elif fusion_method=='mlp': #3. MLP TODO self.fusion_text_image = MLP( fused_dim=self.cross_query_dim ) elif fusion_method=='gated-attention': #4. 
Gated-attention self.fusion_text_image = AttentionFusionWrapper( fused_dim=self.cross_query_dim ) flattened_dim = cross_value_dim * cross_heads * reshape_blocks self.ln = nn.LayerNorm(flattened_dim) self.fc2 = nn.Linear(flattened_dim, image_hidden_size) def forward(self, text_embeds, image_embeds): """ Args: text_embeds: [B, T, text_context_dim] image_embeds: [B, image_hidden_size] Returns: out: [B, image_hidden_size], image features fused with text conditions """ B = image_embeds.size(0) # Map image features to intermediate dimension and reshape into multiple blocks x = self.fc1(image_embeds) # [B, inter_dim] x = x.view(B, self.reshape_blocks, self.cross_query_dim) # [B, N_blocks, query_dim] # Cross-Attention: interaction between image blocks and text print(self.fusion_text_image) attended = self.fusion_text_image(x, text_embeds) # [B, N_blocks, value_dim * heads] print(attended.shape) # Flatten, normalize, and project back attended = attended.view(B, -1) # [B, flattened_dim] out = self.ln(attended) out = self.fc2(out) * self.scale return out class IPAdapter(torch.nn.Module): """IP-Adapter""" def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None, inter_dim=None, cross_heads=None, reshape_blocks=None, cross_value_dim=None, fusion_method="cross_attention"): # Fusion method parameter super().__init__() self.unet = unet self.image_proj_model = image_proj_model self.adapter_modules = adapter_modules # Directly use the fusion_method parameter, no need to convert to a boolean flag self.composed_modules = HarmonyAttention( image_hidden_size=1280, # Image feature dimension fixed text_context_dim=2048, # Text context dimension fixed inter_dim=inter_dim, cross_heads=cross_heads, reshape_blocks=reshape_blocks, cross_value_dim=cross_value_dim, scale=1.0, # Scaling factor fixed fusion_method=fusion_method # Directly pass the fusion method parameter ) if ckpt_path is not None: self.load_from_checkpoint(ckpt_path) def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds, text_extra_embeds): composed_embeds = self.composed_modules(text_extra_embeds, image_embeds) image_embeds = image_embeds + composed_embeds ip_tokens = self.image_proj_model(image_embeds) encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) # Predict the noise residual noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample return noise_pred def load_from_checkpoint(self, ckpt_path: str): # Calculate original checksums orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) if os.path.splitext(ckpt_path)[-1] == ".safetensors": state_dict = {"image_proj": {}, "ip_adapter": {}} with safe_open(ckpt_path, framework="pt", device="cpu") as f: for key in f.keys(): if key.startswith("image_proj."): state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) elif key.startswith("ip_adapter."): state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) else: state_dict = torch.load(ckpt_path, map_location="cpu") # Load state dict for image_proj_model and adapter_modules self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False) # Calculate new checksums new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in 
self.image_proj_model.parameters()])) new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) # Verify if the weights have changed assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" print(f"Successfully loaded weights from checkpoint {ckpt_path}") def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--pretrained_ip_adapter_path", type=str, default=None, help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", ) parser.add_argument( "--data_json_file", type=str, default=None, required=True, help="Training data", ) parser.add_argument( "--data_root_path", type=str, default="", required=True, help="Training data root path", ) parser.add_argument( "--image_encoder_path", type=str, default=None, required=True, help="Path to CLIP image encoder", ) parser.add_argument( "--output_dir", type=str, default="sd-ip_adapter", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images" ), ) parser.add_argument( "--learning_rate", type=float, default=1e-4, help="Learning rate to use.", ) parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--num_train_epochs", type=int, default=10000) parser.add_argument( "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--noise_offset", type=float, default=None, help="noise offset") parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument( "--save_steps", type=int, default=2000, help=( "Save a checkpoint of the training state every X updates" ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument( "--composed_inter_dim", type=int, default=None, help="HarmonyAttention's intermediate projection dimension [e.g., 1280, 2560]." 
) parser.add_argument( "--composed_cross_heads", type=int, default=None, help="Number of cross-attention heads in HarmonyAttention [e.g., 8, 10]." ) parser.add_argument( "--composed_reshape_blocks", type=int, default=None, help="Number of image feature blocks in HarmonyAttention [e.g., 4, 8]." ) parser.add_argument( "--composed_cross_value_dim", type=int, default=None, help="Value dimension per head after dimensionality reduction in HarmonyAttention [e.g., 32, 64]." ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank return args def main(): # Parse command-line arguments args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) # Configure accelerate for mixed-precision and distributed training accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, ) # Create output directory if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) # Load various pre-trained model components # Noise addition control, tokenizer, text encoder, VAE, UNet, etc. noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2") vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) # Freeze base model parameters, only train the adapter part unet.requires_grad_(False) vae.requires_grad_(False) text_encoder.requires_grad_(False) text_encoder_2.requires_grad_(False) image_encoder.requires_grad_(False) # Initialize IP-Adapter model components # Image projection model maps CLIP image embeddings to UNet's cross-attention dimension num_tokens = 4 # Number of extra context tokens image_proj_model = ImageProjModel( cross_attention_dim=unet.config.cross_attention_dim, clip_embeddings_dim=image_encoder.config.projection_dim, clip_extra_context_tokens=num_tokens, ) # Initialize adapter attention processors # Create appropriate attention processors for each attention block in UNet # init adapter modules attn_procs = {} unet_sd = unet.state_dict() # Initialize UNet's attention + IP-Adapter's cross_attention parameters for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = 
AttnProcessor() else: layer_name = name.split(".processor")[0] # Layers that need to add IP join additional attention parameters if 'down_blocks.2.attentions.1' in name: weights = { "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], } attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens, skip=False) attn_procs[name].load_state_dict(weights, strict=False) else: attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens, skip=True) # Load all attention processing into UNet unet.set_attn_processor(attn_procs) # Extract all attention layers into a list adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) # Create the complete IP-Adapter model ip_adapter = IPAdapter( unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path, inter_dim=args.composed_inter_dim, # Use command-line arguments cross_heads=args.composed_cross_heads, # Use command-line arguments reshape_blocks=args.composed_reshape_blocks, # Use command-line arguments cross_value_dim=args.composed_cross_value_dim, # Use command-line arguments fusion_method='cross_attention' ) # ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules) # Set data type for mixed-precision training weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move models to the appropriate device and convert to the appropriate data type vae.to(accelerator.device) # VAE uses fp32 for better stability text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) image_encoder.to(accelerator.device, dtype=weight_dtype) # Set optimizer - only optimize IP-Adapter components params_to_opt = itertools.chain(ip_adapter.adapter_modules.parameters(), ip_adapter.composed_modules.parameters()) optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) accelerator.print("Trainable parameters: adapter_modules:{:.2f}M, composed_modules:{:.2f}M".format( count_model_params(ip_adapter.adapter_modules), count_model_params(ip_adapter.composed_modules))) # Create dataset and dataloader train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, image_root_path=args.data_root_path) train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) # Prepare model, optimizer, and dataloader with accelerator ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) # Start training loop global_step = 0 for epoch in range(0, args.num_train_epochs): begin = time.perf_counter() for step, batch in enumerate(train_dataloader): load_data_time = time.perf_counter() - begin with accelerator.accumulate(ip_adapter): # Convert images to latent space using VAE with torch.no_grad(): # SDXL's VAE uses fp32 for better numerical stability latents = vae.encode(batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample() latents = latents * vae.config.scaling_factor latents = latents.to(accelerator.device, dtype=weight_dtype) # Generate random noise to add to the latent representation noise = torch.randn_like(latents) if args.noise_offset: 
# Use noise offset technique to improve training stability noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(accelerator.device, dtype=weight_dtype) bsz = latents.shape[0] # Sample random timesteps for each image timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise according to timesteps (forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get CLIP image embeddings with torch.no_grad(): image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds # Apply image embedding dropout strategy image_embeds_ = [] for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]): if drop_image_embed == 1: image_embeds_.append(torch.zeros_like(image_embed)) else: image_embeds_.append(image_embed) image_embeds = torch.stack(image_embeds_) # Get text embeddings (SDXL uses two text encoders) with torch.no_grad(): encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True) text_embeds = encoder_output.hidden_states[-2] encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True) pooled_text_embeds = encoder_output_2[0] text_embeds_2 = encoder_output_2.hidden_states[-2] text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # Concatenate outputs of the two text encoders #extra text encoder_extra_output = text_encoder(batch['text_extra_input_ids'].to(accelerator.device),output_hidden_states=True) text_extra_embeds = encoder_extra_output.hidden_states[-2] encoder_extra_output_2 = text_encoder_2(batch['text_extra_input_ids_2'].to(accelerator.device),output_hidden_states=True) text_extra_embeds_2 = encoder_extra_output_2.hidden_states[-2] text_extra_embeds=torch.concat([text_extra_embeds,text_extra_embeds_2],dim=-1) # Add extra conditions required by SDXL (image size and crop info) add_time_ids = [ batch["original_size"].to(accelerator.device), batch["crop_coords_top_left"].to(accelerator.device), batch["target_size"].to(accelerator.device), ] add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype) unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids} # Predict noise using IP-Adapter noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds, text_extra_embeds) # Calculate MSE loss (between predicted noise and actual added noise) loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") # Gather loss in distributed training avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() # Backpropagation and optimizer step accelerator.backward(loss) optimizer.step() optimizer.zero_grad() # Print training information if accelerator.is_main_process: print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( epoch, step, load_data_time, time.perf_counter() - begin, avg_loss)) global_step += 1 # Save checkpoint periodically if global_step % args.save_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path, safe_serialization=False) begin = time.perf_counter() if __name__ == "__main__": main()
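================================================
NOTES: illustrative usage sketches (not repository files)
================================================

The spatial-size search in upscale() from ip_adapter/utils.py finds the power-of-two scale at which the pixel-space area, divided by 64 (the 8x8 VAE downsampling), matches the number of attention positions, and then reshapes the flattened map to that latent grid before bilinear upsampling. A worked example with concrete numbers, using 1024 positions, i.e. a 32x32 attention grid of the kind the deeper SDXL blocks produce for a 1024x1024 image:

# Reproduces the temp_size search from ip_adapter/utils.py upscale() with concrete numbers.
target_size = (1024, 1024)
num_positions = 1024            # attn_map.shape[1]: a 32x32 attention grid
for i in range(0, 5):
    scale = 2 ** i
    if (target_size[0] // scale) * (target_size[1] // scale) == num_positions * 64:
        temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
        break
print(scale, temp_size)         # 4 (32, 32): the map is viewed as 32x32 before F.interpolate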
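Cross_Attention in shared_models.py is a plain multi-head cross-attention with an optionally reduced value dimension. A minimal shape check, run from the repository root so that shared_models.py is importable, with the same settings Composed_Attention instantiates (query_dim=640, context_dim=2048, heads=10, value_dim=32):

# Shape check for shared_models.Cross_Attention; dimensions mirror Composed_Attention's usage.
import torch
from shared_models import Cross_Attention

attn = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)
queries = torch.randn(1, 4, 640)    # four reshaped image-feature blocks
context = torch.randn(1, 77, 2048)  # concatenated SDXL text-encoder hidden states
out = attn(queries, context)
print(out.shape)                    # torch.Size([1, 4, 320]) = heads * value_dim per query position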
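ImageProjModel (also in shared_models.py) expands one pooled CLIP image embedding into clip_extra_context_tokens context tokens at the UNet cross-attention width; for SDXL that width is 2048, and 1280 is used here as an example image-encoder projection dimension:

# Shape check for shared_models.ImageProjModel (dimensions illustrative).
import torch
from shared_models import ImageProjModel

proj = ImageProjModel(cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4)
image_embeds = torch.randn(2, 1280)   # pooled CLIP image embeddings
ip_tokens = proj(image_embeds)
print(ip_tokens.shape)                # torch.Size([2, 4, 2048])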
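HarmonyAttention in train.py follows the same pattern at a configurable size: fc1 lifts the 1280-dim image embedding to inter_dim, the result is split into reshape_blocks query blocks, fused with the extra-text embeddings, then flattened, normalized, and projected back to 1280. With the run.sh settings (inter_dim=2560, cross_heads=8, reshape_blocks=8, cross_value_dim=64) the shapes work out as sketched below; this assumes ip_adapter.attention_processor.Cross_Attention has the same interface as the Cross_Attention in shared_models.py:

# Shape walk-through for train.HarmonyAttention with the run.sh hyperparameters (illustrative).
import torch
from train import HarmonyAttention

module = HarmonyAttention(
    image_hidden_size=1280, text_context_dim=2048,
    inter_dim=2560, cross_heads=8, reshape_blocks=8, cross_value_dim=64,
    scale=1.0, fusion_method="cross_attention",
)
text_extra_embeds = torch.randn(1, 77, 2048)   # encoded "extra_text" caption, both encoders concatenated
image_embeds = torch.randn(1, 1280)            # pooled CLIP image embedding
out = module(text_extra_embeds, image_embeds)
print(out.shape)                               # expected torch.Size([1, 1280])

IPAdapter.forward then adds this output onto the original image embedding before ImageProjModel turns it into the four IP-Adapter tokens.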
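MyDataset expects the train.json layout shown under sdxl-fine-tuning/data/: each record carries an image_file, a (possibly empty) text prompt, and the counting caption in extra_text. A self-contained smoke test with a synthetic image; the tokenizer id and temporary paths are illustrative, and in real training both tokenizers come from the SDXL checkpoint's tokenizer and tokenizer_2 subfolders:

# Smoke test for train.MyDataset with a synthetic image and a one-entry train.json.
import json, os, tempfile
from PIL import Image
from transformers import CLIPTokenizer
from train import MyDataset

root = tempfile.mkdtemp()
Image.new("RGB", (768, 512), "white").save(os.path.join(root, "five_cats.png"))
with open(os.path.join(root, "train.json"), "w") as f:
    json.dump([{"image_file": "five_cats.png", "text": " ", "extra_text": "Five cats"}], f)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # illustrative stand-in
dataset = MyDataset(os.path.join(root, "train.json"), tokenizer=tokenizer, tokenizer_2=tokenizer,
                    size=512, image_root_path=root)
sample = dataset[0]
print(sample["image"].shape)                 # torch.Size([3, 512, 512]) after resize + center crop
print(sample["clip_image"].shape)            # torch.Size([1, 3, 224, 224]) from CLIPImageProcessor
print(sample["text_extra_input_ids"].shape)  # torch.Size([1, 77])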
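In train.main(), every UNet attention processor is replaced: self-attention (attn1) keeps a plain AttnProcessor, cross-attention gets an IPAttnProcessor whose hidden_size is read off the UNet's block_out_channels, and only processors under down_blocks.2.attentions.1 are created with skip=False (their to_k_ip/to_v_ip weights seeded from the frozen to_k/to_v). The lookup logic, shown stand-alone with block_out_channels = [320, 640, 1280] as the typical SDXL configuration:

# Stand-alone version of the hidden_size lookup used when building attn_procs in train.main().
block_out_channels = [320, 640, 1280]   # typical SDXL UNet config

def hidden_size_for(name):
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        return list(reversed(block_out_channels))[int(name[len("up_blocks.")])]
    if name.startswith("down_blocks"):
        return block_out_channels[int(name[len("down_blocks.")])]

print(hidden_size_for("down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor"))  # 1280
print(hidden_size_for("up_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"))    # 1280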
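The drop_image_embeds flags produced by the dataset's i_drop_rate / ti_drop_rate are applied in the training loop by zeroing the corresponding CLIP embeddings, so the model also learns an unconditional image branch in the usual classifier-free-guidance fashion. In isolation:

# The per-sample image-condition dropout applied in the training loop.
import torch

image_embeds = torch.randn(4, 1280)
drop_image_embeds = [0, 1, 0, 1]        # flags from MyDataset.__getitem__
image_embeds = torch.stack([
    torch.zeros_like(embed) if drop else embed
    for embed, drop in zip(image_embeds, drop_image_embeds)
])
print(image_embeds[1].abs().max())      # tensor(0.) -- dropped samples are fully zeroed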
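The SDXL micro-conditioning assembled just before the UNet call is the per-sample original size, crop offset, and target size concatenated into a [B, 6] tensor and passed through unet_added_cond_kwargs as time_ids. For one 768x512 image center-cropped to 512x512:

# The add_time_ids tensor built in the training loop (values illustrative).
import torch

original_size = torch.tensor([[512, 768]])        # [height, width] before cropping
crop_coords_top_left = torch.tensor([[0, 128]])   # [top, left] returned by the dataset
target_size = torch.tensor([[512, 512]])
add_time_ids = torch.cat([original_size, crop_coords_top_left, target_size], dim=1)
print(add_time_ids)   # tensor([[512, 768, 0, 128, 512, 512]])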
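Finally, the optional --noise_offset flag adds a small per-sample, per-channel constant to the sampled noise, a trick commonly used to help diffusion models reach very dark or very bright images. Stripped of device and dtype handling, the guarded branch in the loop amounts to:

# Effect of --noise_offset on the sampled noise (values illustrative).
import torch

latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
noise_offset = 0.05
# The offset broadcasts over the spatial dimensions: one constant per (sample, channel).
noise = noise + noise_offset * torch.randn(latents.shape[0], latents.shape[1], 1, 1)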