[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout\n\n\n\n<a href='https://revive234.github.io/IMAGHarmony.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>\n<a href='https://arxiv.org/pdf/2506.01949'><img src='https://img.shields.io/badge/Technique-Report-red'></a>\n<a href='https://huggingface.co/kkkkggg/IMAGHarmony'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>\n<a href=''><img src='https://img.shields.io/badge/Dataset-HarmonyBench-orange'></a>\n\n\n\n## 🗓️ Release\n- [2025/5/30] 🔥 We released the [technical report](https://arxiv.org/pdf/2506.01949) of IMAGHarmony.\n- [2025/5/28] 🔥 We release the train and inference code of IMAGHarmony.\n- [2025/5/17] 🎉 We launch the [project page](https://revive234.github.io/IMAGHarmony.github.io/) of IMAGHarmony.\n\n\n\n\n\n\n\n\n## 💡 Introduction\nIMAGHarmony tackles the challenge of controllable image editing in multi-object scenes, where existing models struggle with aligning object quantity and spatial layout.\nTo this end, IMAGHarmony introduces a structure-aware framework for quantity-and-layout consistent image editing (QL-Edit), enabling precise control over object count, category, and arrangement.\nWe propose a harmony aware (HA) mudule to jointly model object structure and semantics, and a preference-guided noise selection (PNS) strategy to stabilize generation by selecting semantically aligned initial noise.\nOur method is trained and evaluated on HarmonyBench, a newly curated benchmark with diverse editing scenarios.\n\n![architecture](./assets/1.png)\n\n## 🚀 HarmonyBench Dataset Demo\n\n\n![dataset_demo](./assets/harmonybench.jpg)\n## 🚀 Examples\n\n![results_1](./assets/sotacomp.jpg)\n\n![results_2](./assets/multi.jpg)\n\n\n### Dual-Category Editing\n![results_5](./assets/3edit.jpg)\n\n\n\n\n\n## 🔧 Requirements\n\n- Python>=3.8\n- [PyTorch>=2.0.0](https://pytorch.org/)\n- cuda>=11.8\n```\nconda create --name IMAGHarmony python=3.8.18\nconda activate IMAGHarmony\n\n# Install requirements\npip install -r requirements.txt\n```\n## 🌐 Download Models\n\nYou can download our models from [Huggingface](https://huggingface.co/kkkkggg/IMAGHarmony). You can download the other component models from the original repository, as follows.\n- [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)\n- [stable-diffusion-XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n\n## 🚀 How to train\n```\n# Please download the HarmonyBench data first or prepare your own images\n# and modify the path in run.sh\n## Write caption of your image in your train.json file \n# start training\n\nsh train.sh\n```\n## 🚀 How to test\n```\n#Please convert your checkpionts\npython conver_bin.py\n\n#Please fill in your path in test.py\n#then run\n\npython test.py\n```\nOr you may like to test it on gradio\n```\npython demo.py\n```\n\n\n## Acknowledgement\nWe would like to thank the contributors to the [Instantstyle](https://github.com/instantX-research/InstantStyle) and [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) repositories, for their open research and exploration.\n\nThe IMAGHarmony code is available for both academic and commercial use. Users are permitted to generate images using this tool, provided they comply with local laws and exercise responsible use. 
The developers disclaim all liability for any misuse or unlawful activity by users.\n\n## Citation\nIf you find IMAGHarmony useful for your research and applications, please cite using this BibTeX:\n\n```bibtex\n@misc{shen2025imagharmonycontrollableimageediting,\n      title={IMAGHarmony: Controllable Image Editing with Consistent Object Quantity and Layout}, \n      author={Fei Shen and Yutong Gao and Jian Yu and Xiaoyu Du and Jinhui Tang},\n      year={2025},\n      eprint={2506.01949},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      url={https://arxiv.org/abs/2506.01949}, \n}\n```\n\n## 🕒 TODO List\n- [x] Paper\n- [x] Training Code\n- [x] Inference Code\n- [ ] HarmonyBench Dataset\n- [ ] Model Weights\n\n## 👉 **Our other projects:**  \n- [IMAGEdit](https://github.com/XWH-A/IMAGEdit): Training-Free Controllable Video Editing with Consistent Object Layout.  [可控多目标视频编辑]\n- [IMAGDressing](https://github.com/muzishen/IMAGDressing): Controllable dressing generation. [可控穿衣生成]\n- [IMAGGarment](https://github.com/muzishen/IMAGGarment): Fine-grained controllable garment generation.  [可控服装生成]\n- [IMAGHarmony](https://github.com/muzishen/IMAGHarmony): Controllable image editing with consistent object layout.  [可控多目标图像编辑]\n- [IMAGPose](https://github.com/muzishen/IMAGPose): Pose-guided person generation with high fidelity.  [可控多模式人物生成]\n- [RCDMs](https://github.com/muzishen/RCDMs): Rich-contextual conditional diffusion for story visualization.  [可控故事生成]\n- [PCDMs](https://github.com/tencent-ailab/PCDMs): Progressive conditional diffusion for pose-guided image synthesis. [可控人物生成]\n- [V-Express](https://github.com/tencent-ailab/V-Express/): Explores strong and weak conditional relationships for portrait video generation. [可控数字人生成]\n- [FaceShot](https://github.com/open-mmlab/FaceShot/): Talkingface plugin for any character. [可控动漫数字人生成]\n- [CharacterShot](https://github.com/Jeoyal/CharacterShot): Controllable and consistent 4D character animation framework. [可控4D角色生成]\n- [StyleTailor](https://github.com/mahb-THU/StyleTailor): An Agent for personalized fashion styling. [个性化时尚Agent]\n- [SignVip](https://github.com/umnooob/signvip/): Controllable sign language video generation. [可控手语生成]\n\n## 📨 Contact\nIf you have any questions, please feel free to contact us at shenfei140721@126.com and yutonggaokkk@njust.edu.cn.\n"
  },
  {
    "path": "baseline.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass QFormer(nn.Module):\n    def __init__(self, \n                 hidden_dim=768,\n                 num_queries=16,\n                 num_layers=6,\n                 num_heads=12,\n                 image_feat_dim=320,\n                 text_feat_dim=2048,\n                 add_modality_embedding=True):\n        super(QFormer, self).__init__()\n        \n        self.hidden_dim = hidden_dim\n        self.num_queries = num_queries\n        self.add_modality_embedding = add_modality_embedding\n\n        # Learnable query tokens: [1, num_queries, D]\n        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim))\n\n        # Modality type embeddings (0=image, 1=text)\n        if add_modality_embedding:\n            self.modality_embed = nn.Embedding(2, hidden_dim)\n        self.image_proj = nn.Linear(image_feat_dim, hidden_dim)\n        self.text_proj = nn.Linear(text_feat_dim, hidden_dim)\n        # Transformer encoder: Q interacts with K/V (image + text)\n        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)\n        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n    def forward(self, image_feat, text_feat):\n        \"\"\"\n        image_feat: [B, T_img, D]\n        text_feat:  [B, T_txt, D]\n        \"\"\"\n        \n        B = image_feat.size(0)\n        print(image_feat.shape,text_feat.shape)\n        image_feat = self.image_proj(image_feat)\n        text_feat = self.text_proj(text_feat)\n        # Concatenate image + text features as K/V\n        kv = torch.cat([image_feat, text_feat], dim=1)  # [B, T_img + T_txt, D]\n\n        # Add modality type embedding\n        if self.add_modality_embedding:\n            T_img = image_feat.size(1)\n            T_txt = text_feat.size(1)\n            modality_ids = torch.cat([\n                torch.zeros(T_img, dtype=torch.long),\n                torch.ones(T_txt, dtype=torch.long)\n            ], dim=0).to(image_feat.device)  # [T_img + T_txt]\n            modality_embed = self.modality_embed(modality_ids)  # [T_img + T_txt, D]\n            kv = kv + modality_embed.unsqueeze(0)  # broadcast to [B, T, D]\n\n        # Expand learnable query tokens to batch size\n        queries = self.query_tokens.expand(B, -1, -1)  # [B, N_query, D]\n\n        # Q-Former: let queries attend to K/V\n        # Transformer requires concat(Q, K/V)\n        input_seq = torch.cat([queries, kv], dim=1)  # [B, N_query + T, D]\n        output = self.transformer(input_seq)  # [B, N_query + T, D]\n\n        # Return only the updated query tokens\n        return output[:, :self.num_queries, :]  # [B, N_query, D]\n    \n\nclass MLP(nn.Module):\n    def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header =16):\n        super().__init__()\n        self.image_proj = nn.Linear(image_dim, fused_dim)\n        self.text_proj = nn.Linear(text_dim, fused_dim)\n        self.num_header = num_header\n        self.mlp = nn.Sequential(\n            nn.Linear(2 * fused_dim, fused_dim),\n            nn.ReLU(),\n            nn.Linear(fused_dim, fused_dim),\n            nn.ReLU(),\n            nn.Linear(fused_dim,fused_dim*16)\n        )\n        self.fused_dim=fused_dim\n    def forward(self, image_feat, text_feat):\n        \"\"\"\n        image_feat: [B, T_img, image_dim]\n        text_feat: [B, T_txt, text_dim]\n        \"\"\"\n \n        image_repr = image_feat.mean(dim=1)  # [B, 
image_dim]\n        text_repr = text_feat.mean(dim=1)    # [B, text_dim]\n\n\n        image_proj = self.image_proj(image_repr)  # [B, fused_dim]\n        text_proj = self.text_proj(text_repr)     # [B, fused_dim]\n\n\n        fused = torch.cat([image_proj, text_proj], dim=-1)  # [B, 2*fused_dim]\n        \n        output = self.mlp(fused).reshape(-1,self.num_header,self.fused_dim)  # [B, fused_dim]\n        return output\n    \n\n\n\nclass GatedAttentionFusion(nn.Module):\n    def __init__(self, input_dim=768, hidden_dim=512):\n        super().__init__()\n\n        self.gate_mlp = nn.Sequential(\n            nn.Linear(2 * input_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, 1),\n            nn.Sigmoid()  #  alpha ∈ [0, 1]\n        )\n\n    def forward(self, img_feat, txt_feat):\n        \"\"\"\n        img_feat: [B, D]\n        txt_feat: [B, D]\n        \"\"\"\n        fused_input = torch.cat([img_feat, txt_feat], dim=-1)  # [B, 2D]\n        alpha = self.gate_mlp(fused_input)  # [B, 1]\n\n\n        fused = alpha * img_feat + (1 - alpha) * txt_feat  # [B, D]\n        return fused\n\nclass AttentionFusionWrapper(nn.Module):\n    def __init__(self, image_dim=320, text_dim=2048, fused_dim=768,num_header=16):\n        super().__init__()\n        self.img_proj = nn.Linear(image_dim, fused_dim)\n        self.txt_proj = nn.Linear(text_dim, fused_dim)\n        self.fusion = GatedAttentionFusion(input_dim=fused_dim)\n        self.num_header = num_header\n        self.fused_dim = fused_dim\n        self.dim_transfer = nn.Linear(fused_dim,fused_dim*self.num_header)\n    def forward(self, image_feat, text_feat):\n        \"\"\"\n        image_feat: [B, T_img, 320]\n        text_feat:  [B, T_txt, 2048]\n        \"\"\"\n        # Mean pooling\n        img_global = image_feat.mean(dim=1)  # [B, 320]\n        txt_global = text_feat.mean(dim=1)  # [B, 2048]\n\n        # Linear projection\n        img_proj = self.img_proj(img_global)  # [B, 768]\n        txt_proj = self.txt_proj(txt_global)  # [B, 768]\n\n        # Gated fusion\n        fused = self.fusion(img_proj, txt_proj)  # [B, 768]\n        fused = self.dim_transfer(fused).reshape(-1,self.num_header,self.fused_dim)\n        return fused\n"
  },
  {
    "path": "convert_bin.py",
    "content": "import os\nimport torch\nfrom collections import OrderedDict\n\ndef convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path):\n    \n    if not os.path.exists(pytorch_model_path):\n        print(f\"  [Warning] Source file not found, skipping: {pytorch_model_path}\")\n        return False\n\n    print(f\"  Converting: {pytorch_model_path}\")\n    try:\n\n        sd = torch.load(pytorch_model_path, map_location=\"cpu\")\n\n        image_proj_sd = OrderedDict()\n        ip_sd = OrderedDict()\n        composed_sd = OrderedDict()\n        \n\n        for k in sd:\n            if k.startswith(\"image_proj_model.\"):\n                image_proj_sd[k.replace(\"image_proj_model.\", \"\")] = sd[k]\n            elif k.startswith(\"adapter_modules.\"):\n\n                ip_sd[k.replace(\"adapter_modules.\", \"\")] = sd[k] \n            elif k.startswith(\"composed_modules.\"):\n                composed_sd[k.replace(\"composed_modules.\", \"\")] = sd[k]\n\n\n        if not image_proj_sd and not ip_sd and not composed_sd:\n             print(f\"  [Warning] No expected keys (image_proj_model, adapter_modules, composed_modules) found in {pytorch_model_path}. Skipping save.\")\n             return False\n\n\n        final_sd = {\n            \"image_proj\": image_proj_sd, \n            \"ip_adapter\": ip_sd, \n            'composed_adapter': composed_sd\n        }\n        \n\n        torch.save(final_sd, output_ip_adapter_path)\n        print(f\"  Successfully saved: {output_ip_adapter_path}\")\n        return True\n\n    except Exception as e:\n        print(f\"  [Error] Failed to convert {pytorch_model_path}: {e}\")\n        return False\n\n\n\n\n\n\n\n\nif __name__ == \"__main__\":\n    base_log_dir = \"your fine_tuned model path\"\n    total_converted = 0\n    total_skipped = 0\n    total_errors = 0\n\n    print(f\"Starting conversion process in base directory: {base_log_dir}\")\n\n\n    for training_run_dir_name in os.listdir(base_log_dir):\n        training_run_dir_path = os.path.join(base_log_dir, training_run_dir_name)\n        \n        # Check if it's actually a directory\n        if os.path.isdir(training_run_dir_path):\n            print(f\"\\nProcessing training run: {training_run_dir_name}\")\n            \n            # Iterate through items inside the training run directory\n            for checkpoint_dir_name in os.listdir(training_run_dir_path):\n     \n                if checkpoint_dir_name.startswith(\"checkpoint-\") and \\\n                   os.path.isdir(os.path.join(training_run_dir_path, checkpoint_dir_name)):\n                    \n                    checkpoint_dir_path = os.path.join(training_run_dir_path, checkpoint_dir_name)\n                    print(f\"- Found checkpoint directory: {checkpoint_dir_name}\")\n                    \n\n                    pytorch_model_path = os.path.join(checkpoint_dir_path, \"pytorch_model.bin\")\n                    output_ip_adapter_path = os.path.join(checkpoint_dir_path, \"ip_adapter.bin\")\n                    \n\n                    if os.path.exists(output_ip_adapter_path):\n                         print(f\"  Output file already exists, skipping: {output_ip_adapter_path}\")\n                         total_skipped += 1\n                         continue \n\n\n                    success = convert_checkpoint_to_ip_adapter(pytorch_model_path, output_ip_adapter_path)\n                    if success:\n                        total_converted += 1\n                    else:\n                       \n       
                  if not os.path.exists(pytorch_model_path):\n                            total_skipped += 1 \n                         else:\n                            total_errors +=1 \n\n    print(\"\\n--- Conversion Summary ---\")\n    print(f\"Total checkpoints converted: {total_converted}\")\n    print(f\"Total checkpoints skipped (e.g., source missing): {total_skipped}\")\n    print(f\"Total errors during conversion: {total_errors}\")\n"
  },
  {
    "path": "demo.py",
    "content": "import gradio as gr\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom PIL import Image\nfrom ip_adapter import IPAdapterXL\nfrom huggingface_hub import hf_hub_download\nimport os\nimport time\n\ntry:\n    from tutorial_train_sdxl_ori import ComposedAttention\nexcept ImportError:\n    print(\"Error: Could not import ComposedAttention.\")\n    print(\"Please ensure 'tutorial_train_sdxl_ori.py' is in the same directory as this script.\")\n    exit()\n\nprint(\"Loading models, please wait...\")\n\nCKPT_INTER_DIM = 2560\nCKPT_CROSS_HEADS = 8\nCKPT_RESHAPE_BLOCKS = 8\nCKPT_CROSS_VALUE_DIM = 64\n\nBASE_MODEL_PATH = \"/aigc_data_hdd/checkpoints/stable-diffusion-xl-base-1.0\"\nIMAGE_ENCODER_PATH = os.path.join(BASE_MODEL_PATH, \"image_encoder\")\n\nif not os.path.exists(BASE_MODEL_PATH) or not os.path.exists(IMAGE_ENCODER_PATH):\n    print(f\"Error: Model or image encoder path not found: {BASE_MODEL_PATH}\")\n    exit()\n\nIP_ADAPTER_REPO_ID = \"kkkkggg/IMAGHarmony\"\nIP_ADAPTER_FILENAME = \"IMAGHarmony_variant1.bin\"\n\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nprint(f\"Downloading weights file {IP_ADAPTER_FILENAME} from repository {IP_ADAPTER_REPO_ID}...\")\ntry:\n    fine_tuned_ckpt_path = hf_hub_download(\n        repo_id=IP_ADAPTER_REPO_ID,\n        filename=IP_ADAPTER_FILENAME,\n    )\n    print(f\"Weights file downloaded to: {fine_tuned_ckpt_path}\")\nexcept Exception as e:\n    print(f\"Failed to download weights: {e}\")\n    exit()\n\nprint(f\"Loading Stable Diffusion XL pipeline from local path: {BASE_MODEL_PATH}\")\npipe = StableDiffusionXLPipeline.from_pretrained(\n    BASE_MODEL_PATH,\n    torch_dtype=torch.float16,\n    add_watermarker=False,\n).to(DEVICE)\npipe.enable_vae_tiling()\n\nprint(\"Instantiating custom ComposedAttention module...\")\nnumber_class_crossattention = ComposedAttention(\n    image_hidden_size=1280,\n    text_context_dim=2048,\n    inter_dim=CKPT_INTER_DIM,\n    cross_heads=CKPT_CROSS_HEADS,\n    reshape_blocks=CKPT_RESHAPE_BLOCKS,\n    cross_value_dim=CKPT_CROSS_VALUE_DIM,\n    scale=1.0\n).to(DEVICE).half()\n\nprint(\"Extracting and loading ComposedAttention weights from the main checkpoint...\")\ntry:\n    state_dict = torch.load(fine_tuned_ckpt_path, map_location=\"cpu\")\n    composed_attention_weights = state_dict[\"composed_adapter\"]\n    number_class_crossattention.load_state_dict(composed_attention_weights)\n    print(\"Successfully loaded fine-tuned ComposedAttention weights.\")\nexcept KeyError:\n    print(f\"Error: Key 'composed_adapter' not found in weights file {IP_ADAPTER_FILENAME}.\")\n    exit()\nexcept Exception as e:\n    print(f\"An unknown error occurred while loading ComposedAttention weights: {e}\")\n    exit()\n\nprint(\"Initializing IP-Adapter...\")\nip_model = IPAdapterXL(\n    pipe,\n    IMAGE_ENCODER_PATH,\n    fine_tuned_ckpt_path,\n    DEVICE,\n    target_blocks=[\"down_blocks.1.attentions.1\"],\n    num_tokens=4,\n    inference=True,\n    number_class_crossattention=number_class_crossattention\n)\n\nprint(\"Models loaded. 
Gradio application is ready!\")\n\n\ndef generate_image(uploaded_image: Image.Image, local_path: str, save_path: str,\n                   prompt: str, extra_text: str, negative_prompt: str,\n                   guidance_scale: float, num_inference_steps: int, seed: int, progress=gr.Progress()):\n    \n    pil_image = None\n    if uploaded_image is not None:\n        pil_image = uploaded_image\n    elif local_path and local_path.strip():\n        try:\n            pil_image = Image.open(local_path.strip())\n        except FileNotFoundError:\n            raise gr.Error(f\"File not found. Please check the path: {local_path.strip()}\")\n        except Exception as e:\n            raise gr.Error(f\"Cannot open image file. Error: {e}\")\n    else:\n        raise gr.Error(\"Please upload a reference image or provide a valid local file path!\")\n\n    input_image = pil_image.resize((512, 512))\n    progress(0, desc=\"Generating image...\")\n\n    images = ip_model.generate(\n        pil_image=input_image,\n        prompt=prompt,\n        negative_prompt=negative_prompt,\n        scale=1.0,\n        guidance_scale=guidance_scale,\n        num_samples=1,\n        num_inference_steps=int(num_inference_steps),\n        seed=int(seed),\n        extra_text=extra_text,\n        number_class_crossattention=number_class_crossattention\n    )\n    generated_image = images[0]\n    progress(1, desc=\"Generation complete!\")\n\n    if save_path and save_path.strip():\n        try:\n            save_dir = save_path.strip()\n            os.makedirs(save_dir, exist_ok=True)\n            \n            timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n            filename = f\"output_{timestamp}_seed{seed}.png\"\n            full_path = os.path.join(save_dir, filename)\n            \n            generated_image.save(full_path)\n            gr.Info(f\"Image successfully saved to: {full_path}\")\n        except Exception as e:\n            gr.Warning(f\"Could not save the image! 
Error: {e}\")\n            print(f\"Error saving image: {e}\")\n\n    return generated_image\n\nwith gr.Blocks() as demo:\n    gr.Markdown(\"# IMAGHarmony: Image Generation Demo\")\n    gr.Markdown(\n        \"**Upload a reference image from your computer, or enter the full local path in the text box below.**\\n\"\n        \"Then, enter a **Target Prompt** and a **Reference Content** description to generate a new image.\"\n    )\n\n    with gr.Row():\n        with gr.Column(scale=1):\n            input_image = gr.Image(type=\"pil\", label=\"Upload Your Reference Image\")\n            local_path_input = gr.Textbox(\n                label=\"Or Enter Local Image Path\",\n                placeholder=\"/home/user/images/photo.jpg\",\n                info=\"If an image is uploaded, it will be prioritized over the path.\"\n            )\n            prompt = gr.Textbox(label=\"Target Prompt\", value=\"four cats\")\n            extra_text = gr.Textbox(\n                label=\"Reference Content\",\n                info=\"Enter text that describes the reference image, typically the caption used during training.\",\n                value=\"four dogs\"\n            )\n            neg_prompt = gr.Textbox(\n                label=\"Negative Prompt\",\n                value=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\"\n            )\n            save_path_input = gr.Textbox(\n                label=\"Save to Local Directory (Optional)\",\n                placeholder=\"/your/path\",\n                info=\"If left empty, the image will not be saved.\"\n            )\n            run_button = gr.Button(\"Generate Image\", variant=\"primary\")\n\n        with gr.Column(scale=1):\n            output_image = gr.Image(type=\"pil\", label=\"Generated Image\")\n\n    with gr.Accordion(\"Advanced Settings\", open=False):\n        guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=10.0, label=\"Guidance Scale\")\n        num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label=\"Inference Steps\")\n        seed = gr.Slider(minimum=0, maximum=999999, step=1, value=8, label=\"Seed\", randomize=True)\n\n    run_button.click(\n        fn=generate_image,\n        inputs=[input_image, local_path_input, save_path_input, prompt, extra_text, neg_prompt, guidance_scale, num_inference_steps, seed],\n        outputs=output_image\n    )\n\ndemo.launch(share=True)"
  },
  {
    "path": "ip_adapter/__init__.py",
    "content": "from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull\n\n\n__all__ = [\n    \"IPAdapter\",\n    \"IPAdapterPlus\",\n    \"IPAdapterPlusXL\",\n    \"IPAdapterXL\",\n    \"IPAdapterFull\",\n    \n]\n"
  },
  {
    "path": "ip_adapter/attention_processor.py",
    "content": "# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\n\nclass Cross_Attention(nn.Module):\n    def __init__(self,\n                 query_dim,         # Input dimension for Q projection\n                 context_dim,       # Input dimension for K/V projection\n                 heads=8,\n                 value_dim=None,    # Dimension of V after projection (defaults to head_dim)\n                 out_dim=None):     # Output dimension\n        super().__init__()\n        self.query_dim = query_dim\n        self.heads = heads\n        self.head_dim = self.query_dim // self.heads\n        self.scale = math.sqrt(self.head_dim)\n        self.value_dim = value_dim if value_dim is not None else self.head_dim\n        self.out_dim = out_dim if out_dim is not None else heads * self.value_dim\n\n        # Linear projection layers\n        self.to_q = nn.Linear(query_dim, self.heads * self.head_dim)\n        self.to_k = nn.Linear(context_dim, self.heads * self.head_dim)\n        self.to_v = nn.Linear(context_dim, self.heads * self.value_dim)\n\n        # Optional output projection\n        self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)\n\n    def forward(self, query_input, context_input):\n\n        B = query_input.size(0)\n\n        # Project Q, K, V\n        q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, Q_len, head_dim]\n        k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, K_len, head_dim]\n        v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2)  # [B, heads, K_len, v_dim]\n\n        # Attention scores (weights)\n        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [B, heads, Q_len, K_len]\n        attn_probs = F.softmax(attn_scores, dim=-1)\n\n        # Weighted sum\n        attn_output = torch.matmul(attn_probs, v)  # [B, heads, Q_len, v_dim]\n\n        # Concatenate heads\n        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim)  # [B, Q_len, heads * v_dim]\n\n        # Output projection\n        output = self.out_proj(attn_output)  # [B, Q_len, out_dim]\n        return output\n    \n    \n    \nclass AttnProcessor(nn.Module):\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        *args,\n        **kwargs,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = 
attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAttnProcessor(nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapater.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            the weight scale of image prompt.\n        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n        self.skip = skip\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        
else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        if not self.skip:\n            # for ip-adapter\n            ip_key = self.to_k_ip(ip_hidden_states)\n            ip_value = self.to_v_ip(ip_hidden_states)\n\n            ip_key = attn.head_to_batch_dim(ip_key)\n            ip_value = attn.head_to_batch_dim(ip_value)\n\n            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)\n            self.attn_map = ip_attention_probs\n            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)\n            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)\n\n            hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass AttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        *args,\n        **kwargs,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n          
  attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapater for PyTorch 2.0.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            the weight scale of image prompt.\n        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):\n        super().__init__()\n\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n        self.skip = skip\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width 
= hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if not self.skip:\n            # for ip-adapter\n            ip_key = self.to_k_ip(ip_hidden_states)\n            ip_value = self.to_v_ip(ip_hidden_states)\n\n            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n            # the output of sdp = (batch, num_heads, seq_len, head_dim)\n            # TODO: add support for attn.scale when we move to Torch 2.1\n            ip_hidden_states = F.scaled_dot_product_attention(\n                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False\n            )\n            with torch.no_grad():\n                self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)\n                #print(self.attn_map.shape)\n\n            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n            ip_hidden_states = ip_hidden_states.to(query.dtype)\n\n            hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 
4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\n## for controlnet\nclass CNAttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(self, num_tokens=4):\n        self.num_tokens = num_tokens\n\n    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass CNAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self, num_tokens=4):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n        self.num_tokens = num_tokens\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        *args,\n        **kwargs,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = 
attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n"
  },
  {
    "path": "ip_adapter/custom_pipelines.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg\n\nfrom .utils import is_torch2_available\n\nif is_torch2_available():\n    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor\nelse:\n    from .attention_processor import IPAttnProcessor\n\n\nclass StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):\n    def set_scale(self, scale):\n        for attn_processor in self.unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                attn_processor.scale = scale\n\n    @torch.no_grad()\n    def __call__(  # noqa: C901\n        self,\n        prompt: Optional[Union[str, List[str]]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        control_guidance_start: float = 0.0,\n        control_guidance_end: float = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. 
This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will ge generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.7):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be as same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            control_guidance_start (`float`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 7.1 Apply denoising_end\n        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # get init conditioning scale\n        for attn_processor in self.unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                conditioning_scale = attn_processor.scale\n                break\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end):\n                    self.set_scale(0.0)\n                else:\n                    self.set_scale(conditioning_scale)\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if output_type != \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "ip_adapter/ip_adapter.py",
    "content": "import os\nfrom typing import List\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.controlnet import MultiControlNetModel\nfrom PIL import Image\nfrom safetensors import safe_open\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\nfrom  tutorial_train_sdxl_ori import HarmonyAttention\nfrom .utils import is_torch2_available, get_generator\n\nif is_torch2_available():\n    from .attention_processor import (\n        AttnProcessor2_0 as AttnProcessor,\n    )\n    from .attention_processor import (\n        CNAttnProcessor2_0 as CNAttnProcessor,\n    )\n    from .attention_processor import (\n        IPAttnProcessor2_0 as IPAttnProcessor,\n    )\nelse:\n    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor\nfrom .resampler import Resampler\n\n\nclass ImageProjModel(torch.nn.Module):\n\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.generator = None\n        self.cross_attention_dim = cross_attention_dim  \n        self.clip_extra_context_tokens = clip_extra_context_tokens  \n \n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)  \n\n    def forward(self, image_embeds):\n \n        embeds = image_embeds\n        clip_extra_context_tokens = self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\nclass MLPProjModel(torch.nn.Module):\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):\n        super().__init__()\n        \n\n        self.proj = torch.nn.Sequential(\n            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),\n            torch.nn.GELU(),\n            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),\n            torch.nn.LayerNorm(cross_attention_dim)\n        )\n        \n    def forward(self, image_embeds):\n        clip_extra_context_tokens = self.proj(image_embeds)\n        return clip_extra_context_tokens\n\n\nclass IPAdapter:\n    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=None, number_class_crossattention=None):\n        self.device = device\n        self.image_encoder_path = image_encoder_path\n        self.ip_ckpt = ip_ckpt\n        self.num_tokens = num_tokens\n        # self.target_blocks = target_blocks or [\"down_blocks.2.attentions.1\"] \n\n        self.pipe = sd_pipe.to(self.device)\n        self.set_ip_adapter()\n\n        # load image encoder\n        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(\n            self.device, dtype=torch.float16\n        )\n        self.clip_image_processor = CLIPImageProcessor()\n        self.number_class_crossattention=number_class_crossattention.to(self.device, dtype=torch.float16)\n        # image proj model\n        self.image_proj_model = self.init_proj()\n\n        self.load_ip_adapter()\n\n    def init_proj(self):\n        image_proj_model = ImageProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.projection_dim,\n            clip_extra_context_tokens=self.num_tokens,\n        
).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    def set_ip_adapter(self):\n        unet = self.pipe.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n\n            if cross_attention_dim is None:\n                attn_procs[name] = AttnProcessor()\n            else:\n                # 'down_blocks.2.attentions.1' is created with skip=False (image tokens applied);\n                # every other cross-attention layer is created with skip=True\n                if 'down_blocks.2.attentions.1' in name:\n                    attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,\n                                                    num_tokens=self.num_tokens, skip=False).to(self.device, dtype=torch.float16)\n                else:\n                    attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,\n                                                    num_tokens=self.num_tokens, skip=True).to(self.device, dtype=torch.float16)\n\n        unet.set_attn_processor(attn_procs)\n\n        if hasattr(self.pipe, \"controlnet\"):\n            if isinstance(self.pipe.controlnet, MultiControlNetModel):\n                for controlnet in self.pipe.controlnet.nets:\n                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))\n            else:\n                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))\n\n    def load_ip_adapter(self):\n        if os.path.splitext(self.ip_ckpt)[-1] == \".safetensors\":\n            # group the checkpoint tensors by prefix; the dict keys must match the names\n            # used when loading the sub-modules below\n            state_dict = {\"image_proj\": {}, \"ip_adapter\": {}, \"composed_modules\": {}}\n            with safe_open(self.ip_ckpt, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if key.startswith(\"image_proj.\"):\n                        state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"adapter_modules.\"):\n                        state_dict[\"ip_adapter\"][key.replace(\"adapter_modules.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"composed_modules.\"):\n                        state_dict[\"composed_modules\"][key.replace(\"composed_modules.\", \"\")] = f.get_tensor(key)\n        else:\n            state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"])\n        self.number_class_crossattention.load_state_dict(state_dict[\"composed_modules\"])\n        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())\n        ip_layers.load_state_dict(state_dict[\"ip_adapter\"])\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, 
clip_image_embeds=None, extra_prompt_embeds=None):\n        if pil_image is not None:\n            if isinstance(pil_image, Image.Image):\n                pil_image = [pil_image]\n            clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n        else:\n            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)\n        \n        \n        # add composer\n        if extra_prompt_embeds is not None:\n            extra_prompt_embeds = extra_prompt_embeds.to(self.device,torch.float16)\n            output= self.number_class_crossattention(extra_prompt_embeds, clip_image_embeds)\n            clip_image_embeds = clip_image_embeds + output\n            \n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)   \n        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def set_scale(self, scale):\n        for attn_processor in self.pipe.unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                attn_processor.scale = scale\n\n    def generate(\n        self,\n        pil_image=None,\n        clip_image_embeds=None,\n        prompt=None,\n        negative_prompt=None,\n        scale=1.0,\n        num_samples=4,\n        seed=None,\n        guidance_scale=7.5,\n        num_inference_steps=30,\n        **kwargs,\n    ):\n        self.set_scale(scale)\n\n        if pil_image is not None:\n            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n        else:\n            num_prompts = clip_image_embeds.size(0)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(\n            pil_image=pil_image, clip_image_embeds=clip_image_embeds\n        )\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(\n                prompt,\n                device=self.device,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            
negative_prompt_embeds=negative_prompt_embeds,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\nclass IPAdapterXL(IPAdapter):\n    \"\"\"SDXL\"\"\"\n\n    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device,\n                 num_tokens=4, target_blocks=None, inference=False, number_class_crossattention=None):\n        self.inference = inference\n        super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device,\n                         num_tokens=num_tokens, target_blocks=target_blocks, number_class_crossattention=number_class_crossattention)\n\n    def generate(\n        self,\n        pil_image,\n        prompt=None,\n        negative_prompt=None,\n        extra_text=None,\n        scale=1.0,\n        num_samples=4,\n        seed=None,\n        num_inference_steps=30,\n        **kwargs,\n    ):\n        self.set_scale(scale)\n\n        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        # encode the extra text prompt (if any) so it can be composed with the image\n        # embedding; otherwise leave it as None so get_image_embeds skips the composer\n        extra_prompt_embeds = None\n        if extra_text is not None:\n            with torch.inference_mode():\n                (\n                    extra_prompt_embeds,\n                    extra_negative_prompt_embeds,\n                    extra_pooled_prompt_embeds,\n                    extra_negative_pooled_prompt_embeds,\n                ) = self.pipe.encode_prompt(\n                    extra_text,\n                    num_images_per_prompt=num_samples,\n                    do_classifier_free_guidance=True,\n                    negative_prompt=negative_prompt,\n                )\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, extra_prompt_embeds=extra_prompt_embeds)\n\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)\n\n        self.generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            
pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=self.generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\n\nclass IPAdapterPlus(IPAdapter):\n\n\n    def init_proj(self):\n\n\n        image_proj_model = Resampler(\n            dim=self.pipe.unet.config.cross_attention_dim,\n            depth=4,\n            dim_head=64,\n            heads=12,\n            num_queries=self.num_tokens,\n            embedding_dim=self.image_encoder.config.hidden_size,\n            output_dim=self.pipe.unet.config.cross_attention_dim,\n            ff_mult=4,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):\n\n        if isinstance(pil_image, Image.Image):\n            pil_image = [pil_image]\n        clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        clip_image = clip_image.to(self.device, dtype=torch.float16)\n        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_clip_image_embeds = self.image_encoder(\n            torch.zeros_like(clip_image), output_hidden_states=True\n        ).hidden_states[-2]\n        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n\nclass IPAdapterFull(IPAdapterPlus):\n    \"\"\"IP-Adapter with full features\"\"\"\n\n    def init_proj(self):\n        image_proj_model = MLPProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.hidden_size,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n\nclass IPAdapterPlusXL(IPAdapter):\n    \"\"\"SDXL\"\"\"\n\n    def init_proj(self):\n        image_proj_model = Resampler(\n            dim=1280,\n            depth=4,\n            dim_head=64,\n            heads=20,\n            num_queries=self.num_tokens,\n            embedding_dim=self.image_encoder.config.hidden_size,\n            output_dim=self.pipe.unet.config.cross_attention_dim,\n            ff_mult=4,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image):\n        if isinstance(pil_image, Image.Image):\n            pil_image = [pil_image]\n        clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        clip_image = clip_image.to(self.device, dtype=torch.float16)\n        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_clip_image_embeds = self.image_encoder(\n            torch.zeros_like(clip_image), output_hidden_states=True\n        ).hidden_states[-2]\n        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def generate(\n        self,\n        pil_image,\n        prompt=None,\n        negative_prompt=None,\n        scale=1.0,\n        num_samples=4,\n        seed=None,\n        num_inference_steps=30,\n   
     **kwargs,\n    ):\n        self.set_scale(scale)\n\n        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            **kwargs,\n        ).images\n\n        return images\n"
  },
  {
    "path": "ip_adapter/ip_adapter_origin.py",
    "content": "import os\nfrom typing import List\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.controlnet import MultiControlNetModel\nfrom PIL import Image\nfrom safetensors import safe_open\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom .utils import is_torch2_available, get_generator\n\nif is_torch2_available():\n    from .attention_processor import (\n        AttnProcessor2_0 as AttnProcessor,\n    )\n    from .attention_processor import (\n        CNAttnProcessor2_0 as CNAttnProcessor,\n    )\n    from .attention_processor import (\n        IPAttnProcessor2_0 as IPAttnProcessor,\n    )\nelse:\n    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor\nfrom .resampler import Resampler\n\n\nclass ImageProjModel(torch.nn.Module):\n    \"\"\"投影模型 - 将CLIP图像特征转换为适合UNet交叉注意力的格式\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.generator = None\n        self.cross_attention_dim = cross_attention_dim  # UNet交叉注意力的维度\n        self.clip_extra_context_tokens = clip_extra_context_tokens  # 额外上下文token数量\n        # 线性投影层，将CLIP嵌入转换为多个额外的上下文token\n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)  # 标准化层\n\n    def forward(self, image_embeds):\n        # 投影CLIP图像嵌入到多个上下文token\n        embeds = image_embeds\n        clip_extra_context_tokens = self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\nclass MLPProjModel(torch.nn.Module):\n    \"\"\"使用多层感知器的投影模型 - 用于IPAdapterFull变体\"\"\"\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):\n        super().__init__()\n        \n        # 多层感知器，包含两个线性层、GELU激活和层标准化\n        self.proj = torch.nn.Sequential(\n            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),\n            torch.nn.GELU(),\n            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),\n            torch.nn.LayerNorm(cross_attention_dim)\n        )\n        \n    def forward(self, image_embeds):\n        clip_extra_context_tokens = self.proj(image_embeds)\n        return clip_extra_context_tokens\n\n\nclass IPAdapter:\n    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):\n        self.device = device\n        self.image_encoder_path = image_encoder_path\n        self.ip_ckpt = ip_ckpt\n        self.num_tokens = num_tokens\n\n        self.pipe = sd_pipe.to(self.device)\n        self.set_ip_adapter()\n\n        # load image encoder\n        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(\n            self.device, dtype=torch.float16\n        )\n        self.clip_image_processor = CLIPImageProcessor()\n        # image proj model\n        self.image_proj_model = self.init_proj()\n\n        self.load_ip_adapter()\n\n    def init_proj(self):\n        image_proj_model = ImageProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.projection_dim,\n            clip_extra_context_tokens=self.num_tokens,\n        ).to(self.device, dtype=torch.float16)\n        return 
image_proj_model\n\n    def set_ip_adapter(self):\n        unet = self.pipe.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None:\n                attn_procs[name] = AttnProcessor()\n            else:\n                attn_procs[name] = IPAttnProcessor(\n                    hidden_size=hidden_size,\n                    cross_attention_dim=cross_attention_dim,\n                    scale=1.0,\n                    num_tokens=self.num_tokens,\n                ).to(self.device, dtype=torch.float16)\n        unet.set_attn_processor(attn_procs)\n        if hasattr(self.pipe, \"controlnet\"):\n            if isinstance(self.pipe.controlnet, MultiControlNetModel):\n                for controlnet in self.pipe.controlnet.nets:\n                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))\n            else:\n                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))\n    \n    def load_ip_adapter(self):\n        if os.path.splitext(self.ip_ckpt)[-1] == \".safetensors\":\n            state_dict = {\"image_proj\": {}, \"ip_adapter\": {}}\n            with safe_open(self.ip_ckpt, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if key.startswith(\"image_proj.\"):\n                        state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"ip_adapter.\"):\n                        state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n        else:\n            state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"])\n        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())\n        ip_layers.load_state_dict(state_dict[\"ip_adapter\"])\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):\n        if pil_image is not None:\n            if isinstance(pil_image, Image.Image):\n                pil_image = [pil_image]\n            clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n        else:\n            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def set_scale(self, scale):\n        for attn_processor in self.pipe.unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                attn_processor.scale = 
scale\n\n    def generate(\n        self,\n        pil_image=None,\n        clip_image_embeds=None,\n        prompt=None,\n        negative_prompt=None,\n        scale=1.0,\n        num_samples=4,\n        seed=None,\n        guidance_scale=7.5,\n        num_inference_steps=30,\n        **kwargs,\n    ):\n        self.set_scale(scale)\n\n        if pil_image is not None:\n            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n        else:\n            num_prompts = clip_image_embeds.size(0)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(\n            pil_image=pil_image, clip_image_embeds=clip_image_embeds\n        )\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(\n                prompt,\n                device=self.device,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\n\nclass IPAdapterXL(IPAdapter):\n    \"\"\"SDXL\"\"\"\n\n    def generate(\n        self,\n        pil_image,\n        prompt=None,\n        negative_prompt=None,\n        scale=1.0,\n        num_samples=4,\n        seed=None,\n        num_inference_steps=30,\n        **kwargs,\n    ):\n        \"\"\"SDXL-specific generation method that also handles SDXL's pooled text embeddings\"\"\"\n        self.set_scale(scale)\n\n        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = 
image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)\n\n        self.generator = get_generator(seed, self.device)\n        \n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=self.generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\nclass IPAdapterPlus(IPAdapter):\n    \"\"\"Enhanced IP-Adapter that uses fine-grained image features\"\"\"\n\n    def init_proj(self):\n        \"\"\"Use a Resampler instead of a simple linear projection\"\"\"\n        # the Resampler is a more powerful feature-resampling model that captures finer image detail\n        image_proj_model = Resampler(\n            dim=self.pipe.unet.config.cross_attention_dim,\n            depth=4,\n            dim_head=64,\n            heads=12,\n            num_queries=self.num_tokens,\n            embedding_dim=self.image_encoder.config.hidden_size,\n            output_dim=self.pipe.unet.config.cross_attention_dim,\n            ff_mult=4,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):\n        \"\"\"Use the penultimate hidden states of the CLIP model as richer features\"\"\"\n        # use CLIP's internal features rather than its final projection for a more fine-grained view of the image\n        if isinstance(pil_image, Image.Image):\n            pil_image = [pil_image]\n        clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        clip_image = clip_image.to(self.device, dtype=torch.float16)\n        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_clip_image_embeds = self.image_encoder(\n            torch.zeros_like(clip_image), output_hidden_states=True\n        ).hidden_states[-2]\n        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n\nclass IPAdapterFull(IPAdapterPlus):\n    \"\"\"IP-Adapter with full features\"\"\"\n\n    def init_proj(self):\n        image_proj_model = MLPProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.hidden_size,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n\nclass IPAdapterPlusXL(IPAdapter):\n    \"\"\"SDXL\"\"\"\n\n    def init_proj(self):\n        image_proj_model = 
Resampler(\n            dim=1280,\n            depth=4,\n            dim_head=64,\n            heads=20,\n            num_queries=self.num_tokens,\n            embedding_dim=self.image_encoder.config.hidden_size,\n            output_dim=self.pipe.unet.config.cross_attention_dim,\n            ff_mult=4,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image):\n        if isinstance(pil_image, Image.Image):\n            pil_image = [pil_image]\n        clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        clip_image = clip_image.to(self.device, dtype=torch.float16)\n        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_clip_image_embeds = self.image_encoder(\n            torch.zeros_like(clip_image), output_hidden_states=True\n        ).hidden_states[-2]\n        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def generate(\n        self,\n        pil_image,\n        prompt=None,\n        negative_prompt=None,\n        scale=1.0,\n        num_samples=4,\n        seed=None,\n        num_inference_steps=30,\n        **kwargs,\n    ):\n        self.set_scale(scale)\n\n        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            
**kwargs,\n        ).images\n\n        return images\n"
  },
  {
    "path": "ip_adapter/resampler.py",
    "content": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom einops.layers.torch import Rearrange\n\n\n# FFN\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head**-0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latent (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass Resampler(nn.Module):\n    def __init__(\n        self,\n        dim=1024,\n        depth=8,\n        dim_head=64,\n        heads=16,\n        num_queries=8,\n        embedding_dim=768,\n        output_dim=1024,\n        ff_mult=4,\n        max_seq_len: int = 257,  # CLIP tokens + CLS token\n        apply_pos_emb: bool = False,\n        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence\n    ):\n        super().__init__()\n        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None\n\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)\n\n        self.proj_in = nn.Linear(embedding_dim, dim)\n\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.to_latents_from_mean_pooled_seq = (\n            nn.Sequential(\n                nn.LayerNorm(dim),\n                nn.Linear(dim, dim * 
num_latents_mean_pooled),\n                Rearrange(\"b (n d) -> b n d\", n=num_latents_mean_pooled),\n            )\n            if num_latents_mean_pooled > 0\n            else None\n        )\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                        FeedForward(dim=dim, mult=ff_mult),\n                    ]\n                )\n            )\n\n    def forward(self, x):\n        if self.pos_emb is not None:\n            n, device = x.shape[1], x.device\n            pos_emb = self.pos_emb(torch.arange(n, device=device))\n            x = x + pos_emb\n\n        latents = self.latents.repeat(x.size(0), 1, 1)\n\n        x = self.proj_in(x)\n\n        if self.to_latents_from_mean_pooled_seq:\n            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))\n            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)\n            latents = torch.cat((meanpooled_latents, latents), dim=-2)\n\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n\ndef masked_mean(t, *, dim, mask=None):\n    if mask is None:\n        return t.mean(dim=dim)\n\n    denom = mask.sum(dim=dim, keepdim=True)\n    mask = rearrange(mask, \"b n -> b n 1\")\n    masked_t = t.masked_fill(~mask, 0.0)\n\n    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)\n"
  },
  {
    "path": "ip_adapter/shared_models.py",
    "content": "import torch\nfrom torch import nn\nimport math\nimport os\nimport random\nimport argparse\nfrom pathlib import Path\nimport json\nimport itertools\nimport time\nimport torch.nn as nn\nfrom diffusers.models.attention_processor import Attention\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nclass Cross_Attention(nn.Module):\n    def __init__(self, \n                 query_dim,         # Q projection input dimension\n                 context_dim,       # K/V projection input dimension\n                 heads=8, \n                 head_dim=64, \n                 value_dim=None,    # V dimension after dimensionality reduction (defaults to head_dim)\n                 out_dim=None):     # Output dimension\n        super().__init__()\n        self.heads = heads\n        self.head_dim = head_dim\n        self.scale = math.sqrt(head_dim)\n        self.value_dim = value_dim if value_dim is not None else head_dim\n        self.out_dim = out_dim if out_dim is not None else heads * self.value_dim\n\n        # Linear projection layers\n        self.to_q = nn.Linear(query_dim, heads * head_dim)\n        self.to_k = nn.Linear(context_dim, heads * head_dim)\n        self.to_v = nn.Linear(context_dim, heads * self.value_dim)\n\n        # Optional output projection\n        self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)\n\n    def forward(self, query_input, context_input):\n        \"\"\"\n        query_input: [B, Q_len, query_dim]\n        context_input: [B, K_len, context_dim]\n        \"\"\"\n        B = query_input.size(0)\n\n        # Project Q, K, V\n        q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, Q_len, head_dim]\n        k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, K_len, head_dim]\n        v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2)  # [B, heads, K_len, v_dim]\n\n        # Attention weights\n        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [B, heads, Q_len, K_len]\n        attn_probs = F.softmax(attn_scores, dim=-1)\n\n        # Weighted sum\n        attn_output = torch.matmul(attn_probs, v)  # [B, heads, Q_len, v_dim]\n\n        # Concatenate heads\n        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim)  # [B, Q_len, heads * v_dim]\n\n        # Output projection\n        output = self.out_proj(attn_output)  # [B, Q_len, out_dim]\n        return output\nclass ImageProjModel(torch.nn.Module):\n    \"\"\"Projection model - converts CLIP image features into a format suitable for UNet cross-attention\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.generator = None\n        self.cross_attention_dim = cross_attention_dim  # UNet cross-attention dimension\n        self.clip_extra_context_tokens = clip_extra_context_tokens  # Number of extra context tokens\n        # Linear projection layer, converts CLIP embeddings into multiple extra context tokens\n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)  # Normalization layer\n\n    def forward(self, image_embeds):\n        # Project CLIP image embeddings into multiple context tokens\n        embeds = image_embeds\n        clip_extra_context_tokens = 
self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\n\nclass Composed_Attention(torch.nn.Module):#Number_Class_crossAttention\n    def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):\n        super().__init__()\n        \n        # Cross Attention layer\n        self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)\n        self.scale=scale\n        # self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)\n        \n        # Image projection from 1280 to 2560\n        self.fc1=nn.Linear(hidden_size, hidden_size*2)\n        \n        \n        # Layer Normalization layer\n        self.ln = nn.LayerNorm(hidden_size)\n        \n        # FC layer 1280 to 1280\n        self.fc2 = nn.Linear(hidden_size, hidden_size)\n        \n\n    def forward(self, text_embeds,image_embeds):\n\n        image_embeds=self.fc1(image_embeds) #[1, 2560]\n      \n        image_embeds=image_embeds.reshape(1,4,640)\n        output = self.cross_attention(image_embeds, text_embeds) #[1,4,320]\n        output=output.reshape(1,1280)\n        \n        # Normalize the output using Layer Normalization\n        output=self.ln(output)\n        \n        #  FC layer [1,4,2048]->[1,4,2048]\n        output=self.fc2(output)\n        \n        return output*self.scale\n    \n    def load_from_checkpoint(self, ckpt_path: str):\n        from safetensors.torch import load_file\n        from collections import OrderedDict\n        \n        # Load weights file\n        weights = load_file(ckpt_path)\n        \n        # Initialize two dictionaries to store weights for different modules separately\n        image_proj_weights = OrderedDict()\n        attn_weights = OrderedDict()\n        \n        # Separate weights into different modules\n        for k, v in weights.items():\n            # Process image_proj_model weights (match two possible key name prefixes)\n            if k.startswith(\"image_proj_model.\") or k.startswith(\"image_proj.\"):\n                new_key = k.replace(\"image_proj_model.\", \"\").replace(\"image_proj.\", \"\")\n                if hasattr(self, \"image_proj_model\") and hasattr(self.image_proj_model, new_key.split('.')[0]):\n                    image_proj_weights[new_key] = v\n            \n            # Process target attention layer weights (match two possible key name formats)\n            elif \"down_blocks.2.attentions.1\" in k:\n                # Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1\n                new_key = k.replace(\"composed_modules.\", \"\").replace(\"ip_adapter.\", \"\")\n                if hasattr(self, new_key.split('.')[0]):\n                    attn_weights[new_key] = v\n        \n        # Load image_proj_model weights (strict mode)\n        if image_proj_weights:\n            self.image_proj_model.load_state_dict(image_proj_weights, strict=True)\n            print(f\"Loaded image_proj_model weights: {len(image_proj_weights)} params\")\n        \n        # Load attention layer weights (non-strict mode)\n        if attn_weights:\n            # Create temporary ModuleDict to load weights\n            temp_dict = {k: v for k, v in self.named_modules() \n                        if \"down_blocks.2.attentions.1\" in k}\n            temp_model = 
torch.nn.ModuleDict(temp_dict)\n            \n            missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False)\n            if missing:\n                print(f\"Missing keys in attention blocks: {missing}\")\n            if unexpected:\n                print(f\"Unexpected keys in attention blocks: {unexpected}\")\n            \n            print(f\"Loaded attention weights: {len(attn_weights)} params\")\n        \n        print(f\"Successfully loaded target modules from {ckpt_path}\")\n        return self\n"
  },
  {
    "path": "ip_adapter/test_resampler.py",
    "content": "import torch\nfrom resampler import Resampler\nfrom transformers import CLIPVisionModel\n\nBATCH_SIZE = 2\nOUTPUT_DIM = 1280\nNUM_QUERIES = 8\nNUM_LATENTS_MEAN_POOLED = 4  # 0 for no mean pooling (previous behavior)\nAPPLY_POS_EMB = True  # False for no positional embeddings (previous behavior)\nIMAGE_ENCODER_NAME_OR_PATH = \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\"\n\n\ndef main():\n    image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)\n    embedding_dim = image_encoder.config.hidden_size\n    print(f\"image_encoder hidden size: \", embedding_dim)\n\n    image_proj_model = Resampler(\n        dim=1024,\n        depth=2,\n        dim_head=64,\n        heads=16,\n        num_queries=NUM_QUERIES,\n        embedding_dim=embedding_dim,\n        output_dim=OUTPUT_DIM,\n        ff_mult=2,\n        max_seq_len=257,\n        apply_pos_emb=APPLY_POS_EMB,\n        num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,\n    )\n\n    dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)\n    with torch.no_grad():\n        image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]\n    print(\"image_embds shape: \", image_embeds.shape)\n\n    with torch.no_grad():\n        ip_tokens = image_proj_model(image_embeds)\n    print(\"ip_tokens shape:\", ip_tokens.shape)\n    assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
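  {
    "path": "ip_adapter/test_shared_models.py",
    "content": "# Hypothetical shape-check sketch (not part of the original scripts): it verifies\n# the tensor shapes expected by Cross_Attention and ImageProjModel in\n# shared_models.py, using the same hyper-parameters that Composed_Attention wires\n# in (query_dim=640, context_dim=2048, heads=10, value_dim=32). The batch size,\n# sequence length and CLIP embedding dim (1280) are illustrative values only.\nimport torch\n\nfrom shared_models import Cross_Attention, ImageProjModel\n\n\ndef main():\n    # Cross_Attention: out_dim defaults to heads * value_dim = 10 * 32 = 320\n    attn = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)\n    image_blocks = torch.randn(2, 4, 640)    # [B, Q_len, query_dim]\n    text_context = torch.randn(2, 77, 2048)  # [B, K_len, context_dim]\n    out = attn(image_blocks, text_context)\n    print(\"cross-attention output:\", out.shape)\n    assert out.shape == (2, 4, 320)\n\n    # ImageProjModel: one pooled CLIP image embedding -> 4 extra context tokens\n    proj = ImageProjModel(cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4)\n    image_embeds = torch.randn(2, 1280)  # [B, clip_embeddings_dim]\n    tokens = proj(image_embeds)\n    print(\"ip tokens:\", tokens.shape)\n    assert tokens.shape == (2, 4, 2048)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },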
  {
    "path": "ip_adapter/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\n\nattn_maps = {}\ndef hook_fn(name):\n    def forward_hook(module, input, output):\n        if hasattr(module.processor, \"attn_map\"):\n            attn_maps[name] = module.processor.attn_map\n            del module.processor.attn_map\n\n    return forward_hook\n\ndef register_cross_attention_hook(unet):\n    for name, module in unet.named_modules():\n        if name.split('.')[-1].startswith('attn2'):\n            module.register_forward_hook(hook_fn(name))\n\n    return unet\n\ndef upscale(attn_map, target_size):\n    attn_map = torch.mean(attn_map, dim=0)\n    attn_map = attn_map.permute(1,0)\n    temp_size = None\n\n    for i in range(0,5):\n        scale = 2 ** i\n        if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:\n            temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))\n            break\n\n    assert temp_size is not None, \"temp_size cannot is None\"\n\n    attn_map = attn_map.view(attn_map.shape[0], *temp_size)\n\n    attn_map = F.interpolate(\n        attn_map.unsqueeze(0).to(dtype=torch.float32),\n        size=target_size,\n        mode='bilinear',\n        align_corners=False\n    )[0]\n\n    attn_map = torch.softmax(attn_map, dim=0)\n    return attn_map\ndef get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):\n\n    idx = 0 if instance_or_negative else 1\n    net_attn_maps = []\n\n    for name, attn_map in attn_maps.items():\n        attn_map = attn_map.cpu() if detach else attn_map\n        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()\n        attn_map = upscale(attn_map, image_size) \n        net_attn_maps.append(attn_map) \n\n    net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)\n\n    return net_attn_maps\n\ndef attnmaps2images(net_attn_maps):\n\n    #total_attn_scores = 0\n    images = []\n\n    for attn_map in net_attn_maps:\n        attn_map = attn_map.cpu().numpy()\n        #total_attn_scores += attn_map.mean().item()\n\n        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255\n        normalized_attn_map = normalized_attn_map.astype(np.uint8)\n        #print(\"norm: \", normalized_attn_map.shape)\n        image = Image.fromarray(normalized_attn_map)\n\n        #image = fix_save_attn_map(attn_map)\n        images.append(image)\n\n    #print(total_attn_scores)\n    return images\ndef is_torch2_available():\n    return hasattr(F, \"scaled_dot_product_attention\")\n\ndef get_generator(seed, device):\n\n    if seed is not None:\n        if isinstance(seed, list):\n            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]\n        else:\n            generator = torch.Generator(device).manual_seed(seed)\n    else:\n        generator = None\n\n    return generator"
  },
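  {
    "path": "ip_adapter/utils_seed_example.py",
    "content": "# Hypothetical usage sketch for get_generator in utils.py (not part of the\n# original repo): the helper accepts a single int seed, a list of seeds (one\n# torch.Generator per sample), or None (leave the pipeline unseeded).\nfrom utils import get_generator\n\nsingle = get_generator(42, \"cpu\")  # one torch.Generator seeded with 42\nper_sample = get_generator([1, 2, 3], \"cpu\")  # list of three generators\nunseeded = get_generator(None, \"cpu\")  # returns None\n\nprint(type(single).__name__, len(per_sample), unseeded)\n"
  },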
  {
    "path": "requirements.txt",
    "content": "absl-py==2.1.0\naccelerate==1.0.1\naiofiles==23.2.1\naiohappyeyeballs==2.4.4\naiohttp==3.10.11\naiosignal==1.3.1\nannotated-types==0.7.0\nanyio==3.7.1\nasync-timeout==5.0.1\nasyncer==0.0.2\nattrs==25.3.0\nbidict==0.23.1\nbitsandbytes==0.45.5\ncachetools==5.5.0\ncertifi==2024.8.30\nchainlit==1.1.402\ncharset-normalizer==3.4.0\nchevron==0.14.0\nclick==8.1.8\ncontourpy==1.1.1\ncycler==0.12.1\ndataclasses-json==0.5.14\ndatasets==3.1.0\nDeprecated==1.2.18\ndiffusers==0.30.0\ndill==0.3.8\ndistro==1.9.0\neinops==0.8.0\nexceptiongroup==1.2.2\nfastapi==0.110.3\nffmpy==0.5.0\nfilelock==3.16.1\nfiletype==1.2.0\nfonttools==4.57.0\nfrozenlist==1.5.0\nfsspec==2024.9.0\nftfy==6.2.3\ngoogle-auth==2.36.0\ngoogle-auth-oauthlib==1.0.0\ngoogleapis-common-protos==1.69.2\ngradio==4.44.1\ngradio_client==1.3.0\ngrpcio==1.67.1\nh11==0.14.0\nhttpcore==1.0.7\nhttpx==0.28.1\nhuggingface-hub==0.25.2\nidna==3.10\nimportlib_metadata==8.5.0\nimportlib_resources==6.4.5\nJinja2==3.1.4\njiter==0.9.0\nkiwisolver==1.4.7\nLazify==0.4.0\nliteralai==0.0.607\nloguru==0.7.3\nMarkdown==3.7\nmarkdown-it-py==3.0.0\nMarkupSafe==2.1.5\nmarshmallow==3.22.0\nmatplotlib==3.7.5\nmdurl==0.1.2\nmpmath==1.3.0\nmultidict==6.1.0\nmultiprocess==0.70.16\nmypy-extensions==1.0.0\nnest-asyncio==1.6.0\nnetworkx==3.1\nnumpy==1.24.4\nnvidia-cublas-cu12==12.1.3.1\nnvidia-cuda-cupti-cu12==12.1.105\nnvidia-cuda-nvrtc-cu12==12.1.105\nnvidia-cuda-runtime-cu12==12.1.105\nnvidia-cudnn-cu12==9.1.0.70\nnvidia-cufft-cu12==11.0.2.54\nnvidia-curand-cu12==10.3.2.106\nnvidia-cusolver-cu12==11.4.5.107\nnvidia-cusparse-cu12==12.1.0.106\nnvidia-nccl-cu12==2.20.5\nnvidia-nvjitlink-cu12==12.6.77\nnvidia-nvtx-cu12==12.1.105\noauthlib==3.2.2\nopenai==1.71.0\nopentelemetry-api==1.31.1\nopentelemetry-exporter-otlp==1.31.1\nopentelemetry-exporter-otlp-proto-common==1.31.1\nopentelemetry-exporter-otlp-proto-grpc==1.31.1\nopentelemetry-exporter-otlp-proto-http==1.31.1\nopentelemetry-instrumentation==0.52b1\nopentelemetry-proto==1.31.1\nopentelemetry-sdk==1.31.1\nopentelemetry-semantic-conventions==0.52b1\norjson==3.10.15\npackaging==23.2\npandas==2.0.3\npeft==0.13.2\npillow==10.4.0\npropcache==0.2.0\nprotobuf==5.28.3\npsutil==6.1.0\npyarrow==17.0.0\npyasn1==0.6.1\npyasn1_modules==0.4.1\npydantic==2.10.6\npydantic_core==2.27.2\npydub==0.25.1\nPygments==2.19.1\nPyJWT==2.9.0\npyparsing==3.1.4\npython-dateutil==2.9.0.post0\npython-dotenv==1.0.1\npython-engineio==4.11.2\npython-multipart==0.0.9\npython-socketio==5.12.1\npytz==2024.2\nPyYAML==6.0.2\nregex==2024.9.11\nrequests==2.32.3\nrequests-oauthlib==2.0.0\nrich==14.0.0\nrsa==4.9\nruff==0.11.8\nsafetensors==0.4.5\nscipy==1.10.1\nsemantic-version==2.10.0\nshellingham==1.5.4\nsimple-websocket==1.1.0\nsix==1.17.0\nsniffio==1.3.1\nsse-starlette==2.1.3\nstarlette==0.37.2\nsympy==1.13.3\nsyncer==2.0.3\ntensorboard==2.14.0\ntensorboard-data-server==0.7.2\ntensorboardX==2.6.2.2\ntimm==1.0.13\ntokenizers==0.20.3\ntomli==2.2.1\ntomlkit==0.12.0\ntorch==2.4.1\ntorchaudio==2.4.1\ntorchvision==0.19.1\ntqdm==4.66.5\ntransformers==4.45.0\ntriton==3.0.0\ntyper==0.15.3\ntyping-inspect==0.9.0\ntyping_extensions==4.12.2\ntzdata==2024.2\nuptrace==1.31.0\nurllib3==2.2.3\nuvicorn==0.25.0\nwatchfiles==0.20.0\nwcwidth==0.2.13\nwebsockets==12.0\nWerkzeug==3.0.6\nwrapt==1.17.2\nwsproto==1.2.0\nxformers==0.0.28.post1\nxxhash==3.5.0\nyarl==1.15.2\nzipp==3.20.2\n"
  },
  {
    "path": "run.sh",
    "content": "accelerate launch --gpu_ids 0 --num_processes 1 --mixed_precision \"fp16\" \\\n  train.py \\\n  --pretrained_model_name_or_path=\"your path\" \\\n  --pretrained_ip_adapter_path=\"your path\" \\\n  --image_encoder_path=\"your path\" \\\n  --data_root_path='your path' \\\n  --mixed_precision=\"fp16\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --dataloader_num_workers=4 \\\n  --learning_rate=2.5e-04 \\\n  --data_json_file=\"your path\" \\\n  --weight_decay=0.01 \\\n  --output_dir=\"your path\" \\\n  --save_steps=100 \\\n  --num_train_epochs 2100 \\\n  --composed_inter_dim=2560 \\\n  --composed_cross_heads=8 \\\n  --composed_reshape_blocks=8 \\\n  --composed_cross_value_dim=64\n"
  },
  {
    "path": "sdxl-fine-tuning/data/train.json",
    "content": "[\n  {\n  \"image_file\": \"your image\", \n  \"text\": \" \", \n  \"extra_text\": \"your caption\", \n  \"comments\": {\n      \"image_file\": \"five_cats.png\",\n      \"text\": \"\",\n      \"extra_text\": \"Five cats\"\n    }\n  }\n\n]\n"
  },
  {
    "path": "shared_models.py",
    "content": "import torch\nfrom torch import nn\nimport math\nimport os\nimport random\nimport argparse\nfrom pathlib import Path\nimport json\nimport itertools\nimport time\nimport torch.nn as nn\nfrom diffusers.models.attention_processor import Attention\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nclass Cross_Attention(nn.Module):\n    def __init__(self, \n                 query_dim,         # Q projection input dimension\n                 context_dim,       # K/V projection input dimension\n                 heads=8, \n                 head_dim=64, \n                 value_dim=None,    # V dimension after dimensionality reduction (defaults to head_dim)\n                 out_dim=None):     # Output dimension\n        super().__init__()\n        self.heads = heads\n        self.head_dim = head_dim\n        self.scale = math.sqrt(head_dim)\n        self.value_dim = value_dim if value_dim is not None else head_dim\n        self.out_dim = out_dim if out_dim is not None else heads * self.value_dim\n\n        # Linear projection layers\n        self.to_q = nn.Linear(query_dim, heads * head_dim)\n        self.to_k = nn.Linear(context_dim, heads * head_dim)\n        self.to_v = nn.Linear(context_dim, heads * self.value_dim)\n\n        # Optional output projection\n        self.out_proj = nn.Linear(heads * self.value_dim, self.out_dim)\n\n    def forward(self, query_input, context_input):\n        \"\"\"\n        query_input: [B, Q_len, query_dim]\n        context_input: [B, K_len, context_dim]\n        \"\"\"\n        B = query_input.size(0)\n\n        # Project Q, K, V\n        q = self.to_q(query_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, Q_len, head_dim]\n        k = self.to_k(context_input).view(B, -1, self.heads, self.head_dim).transpose(1, 2)  # [B, heads, K_len, head_dim]\n        v = self.to_v(context_input).view(B, -1, self.heads, self.value_dim).transpose(1, 2)  # [B, heads, K_len, v_dim]\n\n        # Attention weights\n        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # [B, heads, Q_len, K_len]\n        attn_probs = F.softmax(attn_scores, dim=-1)\n\n        # Weighted sum\n        attn_output = torch.matmul(attn_probs, v)  # [B, heads, Q_len, v_dim]\n\n        # Concatenate heads\n        attn_output = attn_output.transpose(1, 2).contiguous().view(B, -1, self.heads * self.value_dim)  # [B, Q_len, heads * v_dim]\n\n        # Output projection\n        output = self.out_proj(attn_output)  # [B, Q_len, out_dim]\n        return output\nclass ImageProjModel(torch.nn.Module):\n    \"\"\"Projection model - converts CLIP image features into a format suitable for UNet cross-attention\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.generator = None\n        self.cross_attention_dim = cross_attention_dim  # UNet cross-attention dimension\n        self.clip_extra_context_tokens = clip_extra_context_tokens  # Number of extra context tokens\n        # Linear projection layer, converts CLIP embeddings into multiple extra context tokens\n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)  # Normalization layer\n\n    def forward(self, image_embeds):\n        # Project CLIP image embeddings into multiple context tokens\n        embeds = image_embeds\n        clip_extra_context_tokens = 
self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\n\nclass Composed_Attention(torch.nn.Module):#Number_Class_crossAttention\n    def __init__(self, hidden_size=1280, cross_attention_dim=64, scale=1.0):\n        super().__init__()\n        \n        # Cross Attention layer\n        self.cross_attention = Cross_Attention(query_dim=640, context_dim=2048, heads=10, value_dim=32)\n        self.scale=scale\n        # self.cross_attention=Attention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)\n        \n        # Image projection from 1280 to 2560\n        self.fc1=nn.Linear(hidden_size, hidden_size*2)\n        \n        \n        # Layer Normalization layer\n        self.ln = nn.LayerNorm(hidden_size)\n        \n        # FC layer 1280 to 1280\n        self.fc2 = nn.Linear(hidden_size, hidden_size)\n        \n\n    def forward(self, text_embeds,image_embeds):\n\n        image_embeds=self.fc1(image_embeds) #[1, 2560]\n      \n        image_embeds=image_embeds.reshape(1,4,640)\n        output = self.cross_attention(image_embeds, text_embeds) #[1,4,320]\n        output=output.reshape(1,1280)\n        \n        # Normalize the output using Layer Normalization\n        output=self.ln(output)\n        \n        #  FC layer [1,4,2048]->[1,4,2048]\n        output=self.fc2(output)\n        \n        return output*self.scale\n    \n    def load_from_checkpoint(self, ckpt_path: str):\n        from safetensors.torch import load_file\n        from collections import OrderedDict\n        \n        # Load weights file\n        weights = load_file(ckpt_path)\n        \n        # Initialize two dictionaries to store weights for different modules separately\n        image_proj_weights = OrderedDict()\n        attn_weights = OrderedDict()\n        \n        # Separate weights into different modules\n        for k, v in weights.items():\n            # Process image_proj_model weights (match two possible key name prefixes)\n            if k.startswith(\"image_proj_model.\") or k.startswith(\"image_proj.\"):\n                new_key = k.replace(\"image_proj_model.\", \"\").replace(\"image_proj.\", \"\")\n                if hasattr(self, \"image_proj_model\") and hasattr(self.image_proj_model, new_key.split('.')[0]):\n                    image_proj_weights[new_key] = v\n            \n            # Process target attention layer weights (match two possible key name formats)\n            elif \"down_blocks.2.attentions.1\" in k:\n                # Convert key name format: composed_modules.down_blocks.2.attentions.1 to down_blocks.2.attentions.1\n                new_key = k.replace(\"composed_modules.\", \"\").replace(\"ip_adapter.\", \"\")\n                if hasattr(self, new_key.split('.')[0]):\n                    attn_weights[new_key] = v\n        \n        # Load image_proj_model weights (strict mode)\n        if image_proj_weights:\n            self.image_proj_model.load_state_dict(image_proj_weights, strict=True)\n            print(f\"Loaded image_proj_model weights: {len(image_proj_weights)} params\")\n        \n        # Load attention layer weights (non-strict mode)\n        if attn_weights:\n            # Create temporary ModuleDict to load weights\n            temp_dict = {k: v for k, v in self.named_modules() \n                        if \"down_blocks.2.attentions.1\" in k}\n            temp_model = 
torch.nn.ModuleDict(temp_dict)\n            \n            missing, unexpected = temp_model.load_state_dict(attn_weights, strict=False)\n            if missing:\n                print(f\"Missing keys in attention blocks: {missing}\")\n            if unexpected:\n                print(f\"Unexpected keys in attention blocks: {unexpected}\")\n            \n            print(f\"Loaded attention weights: {len(attn_weights)} params\")\n        \n        print(f\"Successfully loaded target modules from {ckpt_path}\")\n        return self\n"
  },
  {
    "path": "test.py",
    "content": "import torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom PIL import Image\nfrom ip_adapter import IPAdapterXL # Assuming IPAdapterXL is correctly defined in ip_adapter\nfrom train import HarmonyAttention # Assuming HarmonyAttention is correctly defined\nimport os\n\n\n'''\nThe following parameters must be manually adjusted to match the training parameters.\n'''\nckpt_inter_dim = 2560\nckpt_cross_heads = 8\nckpt_reshape_blocks = 8\nckpt_cross_value_dim = 64\n\n\n\n\n# Image generation function\ndef generate_image(input_path, prompt, extra_text, output_path=\"output.png\"):\n    print(f\"\\nGenerating image for: {prompt} + {extra_text}\")\n    \n    # Prepare input image\n    input_image = Image.open(input_path).resize((512, 512))\n    \n    # Generate image\n    images = ip_model.generate(\n        pil_image=input_image,\n        prompt=prompt,\n        negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n        scale=1.0,\n        guidance_scale=5.0,\n        num_samples=1,\n        num_inference_steps=30,\n        seed=42,\n        extra_text=extra_text,\n        number_class_crossattention=number_class_crossattention # This should be the HarmonyAttention instance\n    )\n    \n    # Save result\n    images[0].save(output_path)\n    print(f\"Saved generated image to: {output_path}\")\n    return images[0]\n\n\nif __name__ == \"__main__\":\n    input_image = \"your path to inputimage\" # Replace with your actual input image path\n    device = \"cuda:2\"  \n\n    # Model path configuration\n    base_model_path = \"your path\" # Replace with your actual base model path\n    image_encoder_path = \"your path\" # Replace with your actual image encoder path\n\n    fine_tuned_ckpt = \"fine_tuned model path\"  # Path to the fine-tuned weights (ip_adapter.bin or similar)\n\n    save_root = os.path.join(\n        'your path', # Replace with your desired save directory\n        fine_tuned_ckpt.split('/')[3] if len(fine_tuned_ckpt.split('/')) > 3 else \"default_folder1\", \n        fine_tuned_ckpt.split('/')[4] if len(fine_tuned_ckpt.split('/')) > 4 else \"default_folder2\", \n    )\n    # Create path (if it doesn't exist)\n    os.makedirs(save_root, exist_ok=True)\n    print(f\"Save directory created at: {save_root}\")\n    \n    # Load SDXL base model\n    print(\"Loading base SDXL model...\")\n    pipe = StableDiffusionXLPipeline.from_pretrained(\n        base_model_path,\n        torch_dtype=torch.float16,\n        add_watermarker=False,\n    )\n    pipe.enable_vae_tiling()\n    pipe.to(device)\n\n\n    # Define the fusion method to be used\n    fusion_method = \"cross_attention\"  # Options: \"cross_attention\", \"qformer\", \"mlp\"\n\n    # Initialize HarmonyAttention (this is the `number_class_crossattention` module)\n    print(\"Initializing HarmonyAttention module...\")\n    number_class_crossattention = HarmonyAttention( # Renamed for clarity, this is your custom attention module\n        image_hidden_size=1280,     # Keep unchanged or match your training\n        text_context_dim=2048,      # Keep unchanged or match your training\n        inter_dim=ckpt_inter_dim,\n        cross_heads=ckpt_cross_heads,\n        reshape_blocks=ckpt_reshape_blocks,\n        cross_value_dim=ckpt_cross_value_dim,\n        scale=1.0,                  # Keep unchanged or adjust as needed\n        fusion_method=fusion_method  # Add fusion method selection\n    ).to(device).half()\n\n    # Initialize 
IP-Adapter\n    print(\"Initializing IP-Adapter with target blocks...\")\n    ip_model = IPAdapterXL(\n        pipe, \n        image_encoder_path, \n        fine_tuned_ckpt, \n        device,\n        target_blocks=[\"down_blocks.2.attentions.1\"],  # Or your specific target blocks\n        num_tokens=4, # Or your specific number of tokens\n        inference=True,\n        number_class_crossattention=number_class_crossattention  # Pass the initialized HarmonyAttention module here\n    )\n\n\n    print(\"HarmonyAttention weights are expected to be loaded as part of the IP-Adapter checkpoint.\")\n \n\n    generate_image(\n        input_path=input_image,\n        prompt=\"lions\",\n        extra_text=\"eight sheep\", # Use the caption from training\n        output_path=os.path.join(save_root, os.path.basename(input_image)) # Safer way to construct output path\n    )\n"
  },
  {
    "path": "train.py",
    "content": "import os\nimport random\nimport argparse\nfrom pathlib import Path\nimport json\nimport itertools\nimport time\nimport torch.nn as nn\nfrom diffusers.models.attention_processor import Attention\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom collections import OrderedDict\nfrom torchvision import transforms\nfrom PIL import Image\nfrom transformers import CLIPImageProcessor\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\nfrom transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection\nfrom safetensors import safe_open\nfrom shared_models import ImageProjModel\nfrom baseline import QFormer,MLP,AttentionFusionWrapper\nfrom safetensors.torch import save_file, load_file\nfrom ip_adapter.utils import is_torch2_available\nif is_torch2_available():\n    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor\n\nelse:\n    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor\nfrom ip_adapter.attention_processor import Cross_Attention\n\n\ndef count_model_params(model):\n    return sum([p.numel() for p in model.parameters()]) / 1e6\n\n# Dataset\nclass MyDataset(torch.utils.data.Dataset):\n\n    def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=\"\"):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.tokenizer_2 = tokenizer_2\n        self.size = size\n        self.center_crop = center_crop\n        self.i_drop_rate = i_drop_rate\n        self.t_drop_rate = t_drop_rate\n        self.ti_drop_rate = ti_drop_rate\n        self.image_root_path = image_root_path\n    \n        self.data = json.load(open(json_file)) # list of dict: [{\"image_file\": \"1.png\", \"text\": \"A dog\"}]\n\n        self.transform = transforms.Compose([\n            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ])\n\n        self.clip_image_processor = CLIPImageProcessor()\n        \n    def __getitem__(self, idx):\n        item = self.data[idx] \n        text = item[\"text\"]\n        text_extra = item['extra_text']\n        image_file = item[\"image_file\"]\n        \n        # read image\n        raw_image = Image.open(os.path.join(self.image_root_path, image_file))\n        \n        # original size\n        original_width, original_height = raw_image.size\n        original_size = torch.tensor([original_height, original_width])\n        \n        image_tensor = self.transform(raw_image.convert(\"RGB\"))\n        # random crop\n        delta_h = image_tensor.shape[1] - self.size\n        delta_w = image_tensor.shape[2] - self.size\n        assert not all([delta_h, delta_w])\n        \n        if self.center_crop:\n            top = delta_h // 2\n            left = delta_w // 2\n        else:\n            top = np.random.randint(0, delta_h + 1)\n            left = np.random.randint(0, delta_w + 1)\n        image = transforms.functional.crop(\n            image_tensor, top=top, left=left, height=self.size, width=self.size\n        )\n        crop_coords_top_left = torch.tensor([top, left]) \n\n        clip_image = 
self.clip_image_processor(images=raw_image, return_tensors=\"pt\").pixel_values\n        \n        # drop\n        drop_image_embed = 0\n        rand_num = random.random()\n        if rand_num < self.i_drop_rate:\n            drop_image_embed = 1\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate):\n            text = \"\"\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):\n            text = \"\"\n            drop_image_embed = 1\n\n        # get text and tokenize\n        text_input_ids = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        \n        text_input_ids_2 = self.tokenizer_2(\n            text,\n            max_length=self.tokenizer_2.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        \n        #extra text\n        text_extra_input_ids=self.tokenizer(\n            text_extra,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        \n        text_extra_input_ids_2 = self.tokenizer_2(\n            text_extra,\n            max_length=self.tokenizer_2.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        \n        return {\n            \"image\": image,\n            \"text_input_ids\": text_input_ids,\n            \"text_input_ids_2\": text_input_ids_2,\n            \"text_extra_input_ids\":text_extra_input_ids,\n            \"text_extra_input_ids_2\":text_extra_input_ids_2,\n            \"clip_image\": clip_image,\n            \"drop_image_embed\": drop_image_embed,\n            \"original_size\": original_size,\n            \"crop_coords_top_left\": crop_coords_top_left,\n            \"target_size\": torch.tensor([self.size, self.size]),\n\n        }\n        \n    \n    def __len__(self):\n        return len(self.data)\n    \n\ndef collate_fn(data):\n    images = torch.stack([example[\"image\"] for example in data])\n    text_input_ids = torch.cat([example[\"text_input_ids\"] for example in data], dim=0)\n    text_input_ids_2 = torch.cat([example[\"text_input_ids_2\"] for example in data], dim=0)\n    clip_images = torch.cat([example[\"clip_image\"] for example in data], dim=0)\n    drop_image_embeds = [example[\"drop_image_embed\"] for example in data]\n    original_size = torch.stack([example[\"original_size\"] for example in data])\n    crop_coords_top_left = torch.stack([example[\"crop_coords_top_left\"] for example in data])\n    target_size = torch.stack([example[\"target_size\"] for example in data])\n\n    #extra text\n    text_extra_input_ids=torch.stack([example[\"text_extra_input_ids\"] for example in data])\n    text_extra_input_ids_2=torch.stack([example[\"text_extra_input_ids_2\"] for example in data])\n    \n    return {\n        \"images\": images,\n        \"text_input_ids\": text_input_ids,\n        \"text_input_ids_2\": text_input_ids_2,\n        \"clip_images\": clip_images,\n        \"drop_image_embeds\": drop_image_embeds,\n        \"original_size\": original_size,\n        \"crop_coords_top_left\": crop_coords_top_left,\n        \"target_size\": target_size,\n        \"text_extra_input_ids\":text_extra_input_ids,\n        
\"text_extra_input_ids_2\":text_extra_input_ids_2,\n    }\n\n\n\nclass HarmonyAttention(nn.Module):\n    def __init__(self,\n                 image_hidden_size=1280,     # Input image feature dimension\n                 text_context_dim=2048,      # Input text context feature dimension\n                 inter_dim=2560,             # Intermediate projection dimension\n                 cross_heads=10,             # Number of cross-attention heads\n                 reshape_blocks=8,           # Number of image feature blocks\n                 cross_value_dim=64,         # Value dimension per head after dimensionality reduction\n                 scale=1.0,                  # Output scaling factor\n                 fusion_method=\"qformer\"): # Fusion method selection: mlp, cross_attention, qformer\n        super().__init__()\n        \n        self.scale = scale\n        self.reshape_blocks = reshape_blocks\n        self.cross_query_dim = inter_dim // reshape_blocks\n        self.fusion_method = fusion_method\n        self.image_hidden_size = image_hidden_size\n        self.text_context_dim = text_context_dim\n        \n        # Image projection required by all methods\n        self.fc1 = nn.Linear(image_hidden_size, inter_dim)\n        print(fusion_method)\n        if fusion_method == \"cross_attention\":\n            # 1. Cross-attention\n            self.fusion_text_image = Cross_Attention(\n                query_dim=self.cross_query_dim,\n                context_dim=text_context_dim,\n                heads=cross_heads,\n                value_dim=cross_value_dim\n            )\n        elif fusion_method == \"qformer\":\n            # 2. Q-Former TODO\n            self.fusion_text_image = QFormer(\n                 num_queries=16,\n                 hidden_dim=self.cross_query_dim,\n                 num_layers=1,\n                 num_heads=cross_heads\n            )\n        elif fusion_method=='mlp': \n           #3. MLP TODO\n           self.fusion_text_image = MLP(\n               fused_dim=self.cross_query_dim\n            )\n        elif fusion_method=='gated-attention':\n            #4. 
Gated-attention\n            self.fusion_text_image = AttentionFusionWrapper(\n               fused_dim=self.cross_query_dim\n            )\n     \n        flattened_dim = cross_value_dim * cross_heads * reshape_blocks\n        self.ln = nn.LayerNorm(flattened_dim)\n        self.fc2 = nn.Linear(flattened_dim, image_hidden_size)\n\n\n\n    def forward(self, text_embeds, image_embeds):\n            \"\"\"\n            Args:\n                text_embeds: [B, T, text_context_dim]\n                image_embeds: [B, image_hidden_size]\n            Returns:\n                out: [B, image_hidden_size], image features fused with text conditions\n            \"\"\"\n            B = image_embeds.size(0)\n\n            # Map image features to intermediate dimension and reshape into multiple blocks\n            x = self.fc1(image_embeds)  # [B, inter_dim]\n            x = x.view(B, self.reshape_blocks, self.cross_query_dim)  # [B, N_blocks, query_dim]  \n\n            # Cross-Attention: interaction between image blocks and text\n            print(self.fusion_text_image)\n            attended = self.fusion_text_image(x, text_embeds)  # [B, N_blocks, value_dim * heads]\n            print(attended.shape)\n            # Flatten, normalize, and project back\n            attended = attended.view(B, -1)     # [B, flattened_dim]\n            out = self.ln(attended)\n            out = self.fc2(out) * self.scale\n\n            return out\n\n\n    \n\n  \n\n           \n           \nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None, \n                 inter_dim=None,\n                 cross_heads=None,\n                 reshape_blocks=None,\n                 cross_value_dim=None,\n                 fusion_method=\"cross_attention\"):  # Fusion method parameter\n        super().__init__()\n        self.unet = unet\n        self.image_proj_model = image_proj_model\n        self.adapter_modules = adapter_modules\n        \n        # Directly use the fusion_method parameter, no need to convert to a boolean flag\n        self.composed_modules = HarmonyAttention(\n            image_hidden_size=1280,     # Image feature dimension fixed\n            text_context_dim=2048,      # Text context dimension fixed\n            inter_dim=inter_dim,\n            cross_heads=cross_heads,\n            reshape_blocks=reshape_blocks,\n            cross_value_dim=cross_value_dim,\n            scale=1.0,                  # Scaling factor fixed\n            fusion_method=fusion_method  # Directly pass the fusion method parameter\n        )\n        \n        if ckpt_path is not None:\n            self.load_from_checkpoint(ckpt_path)\n    \n    def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds, text_extra_embeds):\n        composed_embeds = self.composed_modules(text_extra_embeds, image_embeds)\n        image_embeds = image_embeds + composed_embeds\n        \n        ip_tokens = self.image_proj_model(image_embeds)\n        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)\n        # Predict the noise residual\n        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample\n        return noise_pred\n\n    def load_from_checkpoint(self, ckpt_path: str):\n        # Calculate original checksums\n        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n  
      orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        if os.path.splitext(ckpt_path)[-1] == \".safetensors\":\n            state_dict = {\"image_proj\": {}, \"ip_adapter\": {}}\n            with safe_open(ckpt_path, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if key.startswith(\"image_proj.\"):\n                        state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"ip_adapter.\"):\n                        state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n        else:\n            state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n\n        # Load state dict for image_proj_model and adapter_modules\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=True)\n        self.adapter_modules.load_state_dict(state_dict[\"ip_adapter\"], strict=False)\n\n        # Calculate new checksums\n        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n        \n        # Verify if the weights have changed\n        assert orig_ip_proj_sum != new_ip_proj_sum, \"Weights of image_proj_model did not change!\"\n        assert orig_adapter_sum != new_adapter_sum, \"Weights of adapter_modules did not change!\"\n\n        print(f\"Successfully loaded weights from checkpoint {ckpt_path}\")\n    \n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_ip_adapter_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained ip adapter model. If not specified weights are initialized randomly.\",\n    )\n    parser.add_argument(\n        \"--data_json_file\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Training data\",\n    )\n    parser.add_argument(\n        \"--data_root_path\",\n        type=str,\n        default=\"\",\n        required=True,\n        help=\"Training data root path\",\n    )\n    parser.add_argument(\n        \"--image_encoder_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to CLIP image encoder\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-ip_adapter\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images\"\n        ),\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=10000)\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=8, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=None, help=\"noise offset\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=2000,\n        help=(\n            \"Save a checkpoint of the training state every X updates\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--composed_inter_dim\",\n        type=int,\n        default=None,\n        help=\"HarmonyAttention's intermediate projection dimension [e.g., 1280, 2560].\"\n    )\n    parser.add_argument(\n        \"--composed_cross_heads\",\n        type=int,\n        default=None,\n        help=\"Number of cross-attention heads in HarmonyAttention [e.g., 8, 10].\"\n    )\n    parser.add_argument(\n        \"--composed_reshape_blocks\",\n        type=int,\n        default=None,\n        help=\"Number of image feature blocks in HarmonyAttention [e.g., 4, 8].\"\n    )\n    parser.add_argument(\n        \"--composed_cross_value_dim\",\n        type=int,\n        default=None,\n        help=\"Value dimension per head after dimensionality reduction in HarmonyAttention [e.g., 32, 64].\"\n    )\n    \n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n    \n\ndef main():\n    # Parse command-line arguments\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    # Configure accelerate for mixed-precision and distributed training\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n    \n    # Create output directory\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load various pre-trained model components\n    # Noise addition control, tokenizer, text encoder, VAE, UNet, etc.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer_2\")\n    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)\n    \n    # Freeze base model parameters, only train the adapter part\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    text_encoder_2.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n    \n    # Initialize IP-Adapter model components\n    # Image projection model maps CLIP image embeddings to UNet's cross-attention dimension\n    num_tokens = 4  # Number of extra context tokens\n    image_proj_model = ImageProjModel(\n        cross_attention_dim=unet.config.cross_attention_dim,\n        clip_embeddings_dim=image_encoder.config.projection_dim,\n        
clip_extra_context_tokens=num_tokens,\n    )\n\n    # Initialize the adapter attention processors:\n    # create an appropriate attention processor for each attention block in the UNet\n    attn_procs = {}\n    unet_sd = unet.state_dict()\n    # Self-attention layers keep the default processor; cross-attention layers get an IP-Adapter processor\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n\n        if cross_attention_dim is None:\n            attn_procs[name] = AttnProcessor()\n        else:\n            layer_name = name.split(\".processor\")[0]\n            # Only this block gets active IP-Adapter cross-attention; its to_k_ip/to_v_ip weights are\n            # initialized from the frozen UNet's to_k/to_v weights\n            if 'down_blocks.2.attentions.1' in name:\n                weights = {\n                    \"to_k_ip.weight\": unet_sd[layer_name + \".to_k.weight\"],\n                    \"to_v_ip.weight\": unet_sd[layer_name + \".to_v.weight\"],\n                }\n                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,\n                                                   num_tokens=num_tokens, skip=False)\n                attn_procs[name].load_state_dict(weights, strict=False)\n            else:\n                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,\n                                                   num_tokens=num_tokens, skip=True)\n\n    # Register the attention processors on the UNet\n    unet.set_attn_processor(attn_procs)\n    # Collect the processors in a ModuleList so their parameters can be optimized and saved\n    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n\n    # Create the complete IP-Adapter model\n    ip_adapter = IPAdapter(\n        unet,\n        image_proj_model,\n        adapter_modules,\n        args.pretrained_ip_adapter_path,\n        # HarmonyAttention hyperparameters taken from the command-line arguments\n        inter_dim=args.composed_inter_dim,\n        cross_heads=args.composed_cross_heads,\n        reshape_blocks=args.composed_reshape_blocks,\n        cross_value_dim=args.composed_cross_value_dim,\n        fusion_method='cross_attention'\n    )\n\n    # Set data type for mixed-precision training\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move models to the appropriate device and convert them to the appropriate data type\n    vae.to(accelerator.device)  # keep the VAE in fp32 for numerical stability\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Set up the optimizer - only the IP-Adapter components are trained\n    
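# Note: itertools.chain joins the two parameter iterators so a single AdamW instance updates both\n    # groups - adapter_modules (the attention processors registered on the UNet above) and\n    # composed_modules (presumably the HarmonyAttention fusion layers configured via the --composed_*\n    # arguments). All other weights (UNet, VAE, text and image encoders) stay frozen.\n    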
params_to_opt = itertools.chain(ip_adapter.adapter_modules.parameters(), ip_adapter.composed_modules.parameters())\n    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)\n    accelerator.print(\"Trainable parameters: adapter_modules: {:.2f}M, composed_modules: {:.2f}M\".format(\n        count_model_params(ip_adapter.adapter_modules),\n        count_model_params(ip_adapter.composed_modules)))\n\n    # Create dataset and dataloader\n    train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, image_root_path=args.data_root_path)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Prepare model, optimizer, and dataloader with accelerator\n    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)\n\n    # Start training loop\n    global_step = 0\n    for epoch in range(0, args.num_train_epochs):\n        begin = time.perf_counter()\n        for step, batch in enumerate(train_dataloader):\n            load_data_time = time.perf_counter() - begin\n            with accelerator.accumulate(ip_adapter):\n                # Convert images to latent space using the VAE\n                with torch.no_grad():\n                    # SDXL's VAE is kept in fp32 for better numerical stability\n                    latents = vae.encode(batch[\"images\"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample()\n                    latents = latents * vae.config.scaling_factor\n                    latents = latents.to(accelerator.device, dtype=weight_dtype)\n\n                # Generate random noise to add to the latent representation\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # Noise offset helps the model learn very dark and very bright images\n                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(accelerator.device, dtype=weight_dtype)\n\n                bsz = latents.shape[0]\n                # Sample random timesteps for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise according to timesteps (forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get CLIP image embeddings\n                with torch.no_grad():\n                    image_embeds = image_encoder(batch[\"clip_images\"].to(accelerator.device, dtype=weight_dtype)).image_embeds\n\n                # Randomly drop (zero out) image embeddings so the model also learns the unconditional\n                # case needed for classifier-free guidance\n                image_embeds_ = []\n                for image_embed, drop_image_embed in zip(image_embeds, batch[\"drop_image_embeds\"]):\n                    if drop_image_embed == 1:\n                        image_embeds_.append(torch.zeros_like(image_embed))\n                    else:\n                        image_embeds_.append(image_embed)\n                image_embeds = torch.stack(image_embeds_)\n\n                # Get text embeddings (SDXL uses two text encoders)\n                with torch.no_grad():\n                    encoder_output = 
text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True)\n                    text_embeds = encoder_output.hidden_states[-2]\n                    encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True)\n                    pooled_text_embeds = encoder_output_2[0]  # pooled (projected) embedding from the second text encoder\n                    text_embeds_2 = encoder_output_2.hidden_states[-2]\n                    text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1)  # Concatenate outputs of the two text encoders\n\n                    # Encode the extra text prompt (additional text condition for the adapter)\n                    encoder_extra_output = text_encoder(batch['text_extra_input_ids'].to(accelerator.device), output_hidden_states=True)\n                    text_extra_embeds = encoder_extra_output.hidden_states[-2]\n                    encoder_extra_output_2 = text_encoder_2(batch['text_extra_input_ids_2'].to(accelerator.device), output_hidden_states=True)\n                    text_extra_embeds_2 = encoder_extra_output_2.hidden_states[-2]\n                    text_extra_embeds = torch.concat([text_extra_embeds, text_extra_embeds_2], dim=-1)\n\n                # Add extra conditions required by SDXL (image size and crop info)\n                add_time_ids = [\n                    batch[\"original_size\"].to(accelerator.device),\n                    batch[\"crop_coords_top_left\"].to(accelerator.device),\n                    batch[\"target_size\"].to(accelerator.device),\n                ]\n                add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)\n                unet_added_cond_kwargs = {\"text_embeds\": pooled_text_embeds, \"time_ids\": add_time_ids}\n\n                # Predict noise using the IP-Adapter\n                noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds, text_extra_embeds)\n\n                # Calculate the MSE loss between the predicted noise and the actual added noise\n                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the loss across processes for logging in distributed training\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()\n\n                # Backpropagation and optimizer step\n                accelerator.backward(loss)\n                optimizer.step()\n                optimizer.zero_grad()\n\n                # Print training information\n                if accelerator.is_main_process:\n                    print(\"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}\".format(\n                        epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))\n\n            global_step += 1\n\n            # Save a checkpoint periodically\n            if global_step % args.save_steps == 0:\n                save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                accelerator.save_state(save_path, safe_serialization=False)\n\n            begin = time.perf_counter()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]