[
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<h1>CSGO: Content-Style Composition in Text-to-Image Generation</h1>\n\n[**Peng Xing**](https://github.com/xingp-ng)<sup>12*</sup> · [**Haofan Wang**](https://haofanwang.github.io/)<sup>1*</sup> · [**Yanpeng Sun**](https://scholar.google.com.hk/citations?user=a3FI8c4AAAAJ&hl=zh-CN&oi=ao/)<sup>2</sup> · [**Qixun Wang**](https://github.com/wangqixun)<sup>1</sup> · [**Xu Bai**](https://huggingface.co/baymin0220)<sup>13</sup> · [**Hao Ai**](https://github.com/aihao2000)<sup>14</sup> · [**Renyuan Huang**](https://github.com/DannHuang)<sup>15</sup>\n[**Zechao Li**](https://zechao-li.github.io/)<sup>2✉</sup>\n\n<sup>1</sup>InstantX Team · <sup>2</sup>Nanjing University of Science and Technology · <sup>3</sup>Xiaohongshu  · <sup>4</sup>Beihang University · <sup>5</sup>Peking University\n\n<sup>*</sup>equal contributions, <sup>✉</sup>corresponding authors\n\n<a href='https://csgo-gen.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>\n<a href='https://arxiv.org/abs/2404.02733'><img src='https://img.shields.io/badge/Technique-Report-red'></a>\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/InstantX/CSGO)\n[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-App-red)](https://huggingface.co/spaces/xingpng/CSGO/)\n[![GitHub](https://img.shields.io/github/stars/instantX-research/CSGO?style=social)](https://github.com/instantX-research/CSGO)\n</div>\n\n\n##  Updates 🔥\n\n[//]: # (- **`2024/07/19`**: ✨ We support 🎞️ portrait video editing &#40;aka v2v&#41;! 
More to see [here]&#40;assets/docs/changelog/2024-07-19.md&#41;.)\n\n[//]: # (- **`2024/07/17`**: 🍎 We support macOS with Apple Silicon, modified from [jeethu]&#40;https://github.com/jeethu&#41;'s PR [#143]&#40;https://github.com/KwaiVGI/LivePortrait/pull/143&#41;.)\n\n[//]: # (- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here]&#40;assets/docs/changelog/2024-07-10.md&#41;.)\n\n[//]: # (- **`2024/07/09`**: 🤗 We released the [HuggingFace Space]&#40;https://huggingface.co/spaces/KwaiVGI/liveportrait&#41;, thanks to the HF team and [Gradio]&#40;https://github.com/gradio-app/gradio&#41;!)\n[//]: # (Continuous updates, stay tuned!)\n- **`2024/09/04`**: 🔥 We released the gradio code. You can simply configure it and use it directly.\n- **`2024/09/03`**: 🔥 We released the online demo on [Hugggingface](https://huggingface.co/spaces/xingpng/CSGO/).\n- **`2024/09/03`**: 🔥 We released the [pre-trained weight](https://huggingface.co/InstantX/CSGO).\n- **`2024/09/03`**: 🔥 We released the initial version of the inference code.\n- **`2024/08/30`**: 🔥 We released the technical report on [arXiv](https://arxiv.org/pdf/2408.16766)\n- **`2024/07/15`**: 🔥 We released the [homepage](https://csgo-gen.github.io).\n\n##   Plan 💪\n- [x]  technical report\n- [x]  inference code\n- [x]  pre-trained weight [4_16]\n- [x]  pre-trained weight [4_32]\n- [x]  online demo\n- [ ]  pre-trained weight_v2 [4_32]\n- [ ]  IMAGStyle dataset\n- [ ]  training code\n- [ ]  more pre-trained weight \n\n## Introduction 📖\nThis repo, named **CSGO**, contains the official PyTorch implementation of our paper [CSGO: Content-Style Composition in Text-to-Image Generation](https://arxiv.org/pdf/).\nWe are actively updating and improving this repository. 
If you find any bugs or have suggestions, please feel free to open an issue or submit a pull request (PR) 💖.\n\n## Pipeline 💻\n<p align=\"center\">\n  <img src=\"assets/image3_1.jpg\">\n</p>\n\n## Capabilities 🚅\n\n  🔥 Our CSGO achieves **image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis**.\n\n  🔥 For more results, visit our <a href=\"https://csgo-gen.github.io\"><strong>homepage</strong></a> 🔥\n\n<p align=\"center\">\n  <img src=\"assets/vis.jpg\">\n</p>\n\n## Getting Started 🏁\n### 1. Clone the code and prepare the environment\n```bash\ngit clone https://github.com/instantX-research/CSGO\ncd CSGO\n\n# create env using conda\nconda create -n CSGO python=3.9\nconda activate CSGO\n\n# install dependencies with pip\n# for Linux and Windows users\npip install -r requirements.txt\n```\n\n### 2. Download pretrained weights\n\nWe have released two model weights so far; a third is coming soon.\n\n|      Model       | Content tokens | Style tokens |                 Notes                  |\n|:----------------:|:--------------:|:------------:|:--------------------------------------:|\n|     csgo.bin     |       4        |      16      |                   -                    |\n|  csgo_4_32.bin   |       4        |      32      |            DeepSpeed ZeRO-2            |\n| csgo_4_32_v2.bin |       4        |      32      | DeepSpeed ZeRO-2 + more (coming soon)  |\n\nThe easiest way to download the pretrained weights is from Hugging Face:\n```bash\n# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage\ngit lfs install\n# clone and move the weights\ngit clone https://huggingface.co/InstantX/CSGO\n```\nOur method is fully compatible with [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix), [ControlNet](https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic), and [Image Encoder](https://huggingface.co/h94/IP-Adapter/tree/main/sdxl_models/image_encoder).\nPlease download 
them and place them in the ./base_models folder.\n\nTip: if you want to load the ControlNet weights directly with `ControlNetModel.from_pretrained` as CSGO does, rename the checkpoint first:\n```bash\ngit clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic\nmv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors\n```\n### 3. Inference 🚀\n\n```python\nimport torch\nfrom ip_adapter.utils import BLOCKS as BLOCKS\nfrom ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS\nfrom PIL import Image\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    StableDiffusionXLControlNetPipeline,\n)\nfrom ip_adapter import CSGO\n\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nbase_model_path = \"./base_models/stable-diffusion-xl-base-1.0\"\nimage_encoder_path = \"./base_models/IP-Adapter/sdxl_models/image_encoder\"\ncsgo_ckpt = \"./CSGO/csgo.bin\"\npretrained_vae_name_or_path = './base_models/sdxl-vae-fp16-fix'\ncontrolnet_path = \"./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic\"\nweight_dtype = torch.float16\n\n\nvae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_path,\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    add_watermarker=False,\n    vae=vae\n)\npipe.enable_vae_tiling()\n\n\ntarget_content_blocks = BLOCKS['content']\ntarget_style_blocks = BLOCKS['style']\ncontrolnet_target_content_blocks = controlnet_BLOCKS['content']\ncontrolnet_target_style_blocks = controlnet_BLOCKS['style']\n\ncsgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32,\n                          target_content_blocks=target_content_blocks, 
target_style_blocks=target_style_blocks, controlnet=False, controlnet_adapter=True,\n                              controlnet_target_content_blocks=controlnet_target_content_blocks,\n                              controlnet_target_style_blocks=controlnet_target_style_blocks,\n                              content_model_resampler=True,\n                              style_model_resampler=True,\n                              load_controlnet=False,\n                              )\n\nstyle_name = 'img_0.png'\ncontent_name = 'img_0.png'\nstyle_image = Image.open('../assets/{}'.format(style_name)).convert('RGB')\ncontent_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')\n\ncaption = 'a small house with a sheep statue on top of it'\n\nnum_sample = 4\n\n# image-driven style transfer\nimages = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,\n                           prompt=caption,\n                           negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n                           content_scale=1.0,\n                           style_scale=1.0,\n                           guidance_scale=10,\n                           num_images_per_prompt=num_sample,\n                           num_samples=1,\n                           num_inference_steps=50,\n                           seed=42,\n                           image=content_image.convert('RGB'),\n                           controlnet_conditioning_scale=0.6,\n                          )\n\n# text-driven stylized synthesis\ncaption = 'a cat'\nimages = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,\n                           prompt=caption,\n                           negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n                           content_scale=1.0,\n                           style_scale=1.0,\n  
                         guidance_scale=10,\n                           num_images_per_prompt=num_sample,\n                           num_samples=1,\n                           num_inference_steps=50,\n                           seed=42,\n                           image=content_image.convert('RGB'),\n                           controlnet_conditioning_scale=0.01,\n                          )\n\n# text editing-driven stylized synthesis\ncaption = 'a small house'\nimages = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,\n                           prompt=caption,\n                           negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n                           content_scale=1.0,\n                           style_scale=1.0,\n                           guidance_scale=10,\n                           num_images_per_prompt=num_sample,\n                           num_samples=1,\n                           num_inference_steps=50,\n                           seed=42,\n                           image=content_image.convert('RGB'),\n                           controlnet_conditioning_scale=0.4,\n                          )\n```\n### 4. Gradio interface ⚙️\n\nWe also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience; simply run:\n\n```bash\n# for Linux, Windows, and macOS users\npython gradio/app.py\n```\nIf you don't have the resources to run it locally, try the online [demo](https://huggingface.co/spaces/xingpng/CSGO/).\n## Demos\n<p align=\"center\">\n  <br>\n  🔥 For more results, visit our <a href=\"https://csgo-gen.github.io\"><strong>homepage</strong></a> 🔥\n</p>\n\n### Content-Style Composition\n<p align=\"center\">\n  <img src=\"assets/page1.png\">\n</p>\n\n<p align=\"center\">\n  <img src=\"assets/page4.png\">\n</p>\n\n### Cycle 
Translation\n<p align=\"center\">\n  <img src=\"assets/page8.png\">\n</p>\n\n### Text-Driven Style Synthesis\n<p align=\"center\">\n  <img src=\"assets/page10.png\">\n</p>\n\n### Text Editing-Driven Style Synthesis\n<p align=\"center\">\n  <img src=\"assets/page11.jpg\">\n</p>\n\n## Star History\n[![Star History Chart](https://api.star-history.com/svg?repos=instantX-research/CSGO&type=Date)](https://star-history.com/#instantX-research/CSGO&Date)\n\n## Acknowledgements\nThis project was developed by the InstantX Team and Xiaohongshu; all rights reserved.\nSincere thanks to Xiaohongshu for providing the computing resources.\n\n## Citation 💖\nIf you find CSGO useful for your research, please 🌟 this repo and cite our work using the following BibTeX:\n```bibtex\n@article{xing2024csgo,\n       title={CSGO: Content-Style Composition in Text-to-Image Generation},\n       author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},\n       year={2024},\n       journal={arXiv preprint arXiv:2408.16766},\n}\n```"
  },
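The README's `csgo.generate(...)` calls return a list of PIL images (`num_images_per_prompt=4`), and `gradio/app.py` tiles such a list into a single image with an `image_grid` helper before display. A standalone sketch of that tiling step, using solid-color dummy images in place of real model outputs:

```python
from PIL import Image

def image_grid(imgs, rows, cols):
    """Paste rows*cols equally sized PIL images into one grid image."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # column index selects the x offset, row index the y offset
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

# stand-ins for the four samples returned by csgo.generate(...)
samples = [Image.new("RGB", (64, 64), color=c) for c in ("red", "green", "blue", "white")]
grid = image_grid(samples, rows=2, cols=2)
# grid.save("csgo_samples.png")
```

The same helper is what the Gradio demo uses with `rows=1, cols=num_samples` to show all samples side by side.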
  {
    "path": "gradio/app.py",
    "content": "import sys\n# sys.path.append(\"../\")\nsys.path.append(\"./\")\nimport gradio as gr\nimport torch\nfrom ip_adapter.utils import BLOCKS as BLOCKS\nfrom ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS\nfrom ip_adapter.utils import resize_content\nimport cv2\nimport numpy as np\nimport random\nfrom PIL import Image\nfrom transformers import AutoImageProcessor, AutoModel\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    StableDiffusionXLControlNetPipeline,\n\n)\nfrom ip_adapter import CSGO\nfrom transformers import BlipProcessor, BlipForConditionalGeneration\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nbase_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\nimage_encoder_path = \"h94/IP-Adapter/sdxl_models/image_encoder\"\ncsgo_ckpt ='InstantX/CSGO/csgo_4_32.bin'\npretrained_vae_name_or_path ='madebyollin/sdxl-vae-fp16-fix'\ncontrolnet_path = \"TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic\"\nweight_dtype = torch.float16\n\n\n\n\nvae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_path,\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    add_watermarker=False,\n    vae=vae\n)\npipe.enable_vae_tiling()\nblip_processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-large\")\nblip_model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-large\").to(device)\n\ntarget_content_blocks = BLOCKS['content']\ntarget_style_blocks = BLOCKS['style']\ncontrolnet_target_content_blocks = controlnet_BLOCKS['content']\ncontrolnet_target_style_blocks = controlnet_BLOCKS['style']\n\ncsgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32,\n            
target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,\n            controlnet_adapter=True,\n            controlnet_target_content_blocks=controlnet_target_content_blocks,\n            controlnet_target_style_blocks=controlnet_target_style_blocks,\n            content_model_resampler=True,\n            style_model_resampler=True,\n            )\n\nMAX_SEED = np.iinfo(np.int32).max\n\ndef randomize_seed_fn(seed: int, randomize_seed: bool) -> int:\n    if randomize_seed:\n        seed = random.randint(0, MAX_SEED)\n    return seed\n\n\n\n\n\ndef get_example():\n    case = [\n        [\n            \"./assets/img_0.png\",\n            './assets/img_1.png',\n            \"Image-Driven Style Transfer\",\n            \"there is a small house with a sheep statue on top of it\",\n            1.0,\n            0.6,\n            1.0,\n        ],\n        [\n         None,\n         './assets/img_1.png',\n            \"Text-Driven Style Synthesis\",\n         \"a cat\",\n            1.0,\n         0.01,\n            1.0\n         ],\n        [\n            None,\n            './assets/img_2.png',\n            \"Text-Driven Style Synthesis\",\n            \"a building\",\n            0.5,\n            0.01,\n            1.0\n        ],\n        [\n            \"./assets/img_0.png\",\n            './assets/img_1.png',\n            \"Text Edit-Driven Style Synthesis\",\n            \"there is a small house\",\n            1.0,\n            0.4,\n            1.0\n        ],\n    ]\n    return case\n\n\ndef run_for_examples(content_image_pil,style_image_pil,target, prompt,scale_c_controlnet, scale_c, scale_s):\n    return create_image(\n        content_image_pil=content_image_pil,\n        style_image_pil=style_image_pil,\n        prompt=prompt,\n        scale_c_controlnet=scale_c_controlnet,\n        scale_c=scale_c,\n        scale_s=scale_s,\n        guidance_scale=7.0,\n        num_samples=3,\n        num_inference_steps=50,\n        seed=42,\n   
     target=target,\n    )\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new('RGB', size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\ndef create_image(content_image_pil,\n                 style_image_pil,\n                 prompt,\n                 scale_c_controlnet,\n                 scale_c,\n                 scale_s,\n                 guidance_scale,\n                 num_samples,\n                 num_inference_steps,\n                 seed,\n                 target=\"Image-Driven Style Transfer\",\n):\n\n    if content_image_pil is None:\n        content_image_pil = Image.fromarray(\n            np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')\n\n    if prompt is None or prompt == '':\n        with torch.no_grad():\n            inputs = blip_processor(content_image_pil, return_tensors=\"pt\").to(device)\n            out = blip_model.generate(**inputs)\n            prompt = blip_processor.decode(out[0], skip_special_tokens=True)\n    width, height, content_image = resize_content(content_image_pil)\n    style_image = style_image_pil\n    neg_content_prompt = 'text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry'\n    if target == \"Image-Driven Style Transfer\":\n        with torch.no_grad():\n            images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,\n                                   prompt=prompt,\n                                   negative_prompt=neg_content_prompt,\n                                   height=height,\n                                   width=width,\n                                   
content_scale=scale_c,\n                                   style_scale=scale_s,\n                                   guidance_scale=guidance_scale,\n                                   num_images_per_prompt=num_samples,\n                                   num_inference_steps=num_inference_steps,\n                                   num_samples=1,\n                                   seed=seed,\n                                   image=content_image.convert('RGB'),\n                                   controlnet_conditioning_scale=scale_c_controlnet,\n                                   )\n\n    elif target ==\"Text-Driven Style Synthesis\":\n        content_image = Image.fromarray(\n            np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')\n        with torch.no_grad():\n            images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,\n                                   prompt=prompt,\n                                   negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n                                   height=height,\n                                   width=width,\n                                   content_scale=scale_c,\n                                   style_scale=scale_s,\n                                   guidance_scale=7,\n                                   num_images_per_prompt=num_samples,\n                                   num_inference_steps=num_inference_steps,\n                                   num_samples=1,\n                                   seed=42,\n                                   image=content_image.convert('RGB'),\n                                   controlnet_conditioning_scale=scale_c_controlnet,\n                                   )\n    elif target ==\"Text Edit-Driven Style Synthesis\":\n\n        with torch.no_grad():\n            images = csgo.generate(pil_content_image=content_image, 
pil_style_image=style_image,\n                                   prompt=prompt,\n                                   negative_prompt=neg_content_prompt,\n                                   height=height,\n                                   width=width,\n                                   content_scale=scale_c,\n                                   style_scale=scale_s,\n                                   guidance_scale=guidance_scale,\n                                   num_images_per_prompt=num_samples,\n                                   num_inference_steps=num_inference_steps,\n                                   num_samples=1,\n                                   seed=seed,\n                                   image=content_image.convert('RGB'),\n                                   controlnet_conditioning_scale=scale_c_controlnet,\n                                   )\n\n    return [image_grid(images, 1, num_samples)]\n\n\ndef pil_to_cv2(image_pil):\n    image_np = np.array(image_pil)\n    image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)\n    return image_cv2\n\n\n# Description\ntitle = r\"\"\"\n<h1 align=\"center\">CSGO: Content-Style Composition in Text-to-Image Generation</h1>\n\"\"\"\n\ndescription = r\"\"\"\n<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>\nHow to use:<br>\n1. Upload a content image if you want to use image-driven style transfer.\n2. Upload a style image.\n3. Set the task type; by default, image-driven style transfer is performed. Options are <b>Image-Driven Style Transfer, Text-Driven Style Synthesis, and Text Edit-Driven Style Synthesis</b>.\n4. <b>If you choose a text-driven task, enter your desired prompt</b>.\n5. If you don't provide a prompt, a caption is generated with the BLIP model by default.\n6. Click the <b>Generate Image</b> button to begin customization.\n7. 
Share your stylized photo with your friends and enjoy! 😊\n\nAdvanced usage:<br>\n1. Click advanced options.\n2. Choose different guidance and steps.\n\"\"\"\n\narticle = r\"\"\"\n---\n📝 **Tips**\nIn CSGO, the more accurate the text prompt for the content image, the better the content retention.\nText-driven style synthesis and text-edit-driven style synthesis are expected to be more stable in the next release.\n---\n📝 **Citation**\n<br>\nIf our work is helpful for your research or applications, please cite us via:\n```bibtex\n@article{xing2024csgo,\n       title={CSGO: Content-Style Composition in Text-to-Image Generation},\n       author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},\n       year={2024},\n       journal={arXiv preprint arXiv:2408.16766},\n}\n```\n📧 **Contact**\n<br>\nIf you have any questions, please feel free to open an issue or reach out to us directly at <b>xingp_ng@njust.edu.cn</b>.\n\"\"\"\n\nblock = gr.Blocks(css=\"footer {visibility: hidden}\").queue(max_size=10, api_open=False)\nwith block:\n    # description\n    gr.Markdown(title)\n    gr.Markdown(description)\n\n    with gr.Tabs():\n        with gr.Row():\n            with gr.Column():\n                with gr.Row():\n                    with gr.Column():\n                        content_image_pil = gr.Image(label=\"Content Image (optional)\", type='pil')\n                        style_image_pil = gr.Image(label=\"Style Image\", type='pil')\n\n                target = gr.Radio([\"Image-Driven Style Transfer\", \"Text-Driven Style Synthesis\", \"Text Edit-Driven Style Synthesis\"],\n                                  value=\"Image-Driven Style Transfer\",\n                                  label=\"Task\")\n\n                prompt = gr.Textbox(label=\"Prompt\",\n                                    value=\"there is a small house with a sheep statue on top of it\")\n\n                scale_c_controlnet = gr.Slider(minimum=0, maximum=2.0, 
step=0.01, value=0.6,\n                                               label=\"Content Scale for controlnet\")\n                scale_c = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6, label=\"Content Scale for IPA\")\n\n                scale_s = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=1.0, label=\"Style Scale\")\n                with gr.Accordion(open=False, label=\"Advanced Options\"):\n\n                    guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label=\"guidance scale\")\n                    num_samples = gr.Slider(minimum=1, maximum=4.0, step=1.0, value=1.0, label=\"num samples\")\n                    num_inference_steps = gr.Slider(minimum=5, maximum=100.0, step=1.0, value=50,\n                                                    label=\"num inference steps\")\n                    seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label=\"Seed Value\")\n                    randomize_seed = gr.Checkbox(label=\"Randomize seed\", value=True)\n\n                generate_button = gr.Button(\"Generate Image\")\n\n            with gr.Column():\n                generated_image = gr.Gallery(label=\"Generated Image\")\n\n        generate_button.click(\n            fn=randomize_seed_fn,\n            inputs=[seed, randomize_seed],\n            outputs=seed,\n            queue=False,\n            api_name=False,\n        ).then(\n            fn=create_image,\n            inputs=[content_image_pil,\n                    style_image_pil,\n                    prompt,\n                    scale_c_controlnet,\n                    scale_c,\n                    scale_s,\n                    guidance_scale,\n                    num_samples,\n                    num_inference_steps,\n                    seed,\n                    target,],\n            outputs=[generated_image])\n\n    gr.Examples(\n        examples=get_example(),\n        inputs=[content_image_pil,style_image_pil,target, 
prompt,scale_c_controlnet, scale_c, scale_s],\n        fn=run_for_examples,\n        outputs=[generated_image],\n        cache_examples=True,\n    )\n\n    gr.Markdown(article)\n\nblock.launch(server_name=\"0.0.0.0\", server_port=1234)\n"
  },
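Both `gradio/app.py` and `infer/infer_CSGO.py` call `resize_content(content_image)` (from `ip_adapter.utils`) to obtain a `width`, `height`, and resized image before generation. Its implementation is not included in this snapshot; the sketch below illustrates the kind of logic such a helper needs for SDXL (the target area of 1024×1024 and the multiple-of-64 snapping are assumptions, and `resize_dims`/`snap_to_multiple` are hypothetical names, not the repo's):

```python
def snap_to_multiple(value, base=64):
    """Round a dimension down to the nearest multiple of `base` (hypothetical helper)."""
    return max(base, (value // base) * base)

def resize_dims(width, height, target_area=1024 * 1024, base=64):
    """Scale (width, height) to roughly `target_area` pixels, snapped to multiples of `base`.

    Illustrative stand-in for the dimension logic of ip_adapter.utils.resize_content;
    the actual implementation may differ.
    """
    scale = (target_area / (width * height)) ** 0.5
    return (snap_to_multiple(int(width * scale), base),
            snap_to_multiple(int(height * scale), base))
```

With logic like this, a 512×512 input maps to 1024×1024, and an arbitrary aspect ratio still yields dimensions the SDXL UNet can consume.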
  {
    "path": "gradio/requirements.txt",
    "content": "diffusers==0.25.1\ntorch==2.0.1\ntorchaudio==2.0.2\ntorchvision==0.15.2\ntransformers==4.40.2\naccelerate\nsafetensors\neinops\nspaces==0.19.4\nomegaconf\npeft\nhuggingface-hub==0.24.5\nopencv-python\ninsightface\ngradio\ncontrolnet_aux\ngdown\npeft\n"
  },
  {
    "path": "infer/infer_CSGO.py",
    "content": "import os\nos.environ['HF_ENDPOINT']='https://hf-mirror.com'\nimport torch\nfrom ip_adapter.utils import BLOCKS as BLOCKS\nfrom ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS\nfrom ip_adapter.utils import resize_content\nimport cv2\nfrom PIL import Image\nfrom transformers import AutoImageProcessor, AutoModel\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    StableDiffusionXLControlNetPipeline,\n\n)\nfrom ip_adapter import CSGO\nfrom transformers import BlipProcessor, BlipForConditionalGeneration\n\ndevice = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n\nbase_model_path =  \"../../../base_models/stable-diffusion-xl-base-1.0\"\nimage_encoder_path = \"../../../base_models/IP-Adapter/sdxl_models/image_encoder\"\ncsgo_ckpt = \"/share2/xingpeng/DATA/blora/outputs/content_style_checkpoints_2/base_train_free_controlnet_S12_alldata_C_0_S_I_zero2_style_res_32_content_res4_drop/checkpoint-220000/ip_adapter.bin\"\npretrained_vae_name_or_path ='../../../base_models/sdxl-vae-fp16-fix'\ncontrolnet_path = \"../../../base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic\"\nweight_dtype = torch.float16\n\nblip_processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-large\")\nblip_model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-large\").to(device)\n\n\nvae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_path,\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    add_watermarker=False,\n    vae=vae\n)\npipe.enable_vae_tiling()\n\n\ntarget_content_blocks = BLOCKS['content']\ntarget_style_blocks = BLOCKS['style']\ncontrolnet_target_content_blocks = controlnet_BLOCKS['content']\ncontrolnet_target_style_blocks = 
controlnet_BLOCKS['style']\n\ncsgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4,num_style_tokens=32,\n                          target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,controlnet_adapter=True,\n                              controlnet_target_content_blocks=controlnet_target_content_blocks,\n                              controlnet_target_style_blocks=controlnet_target_style_blocks,\n                              content_model_resampler=True,\n                              style_model_resampler=True,\n                              )\n\nstyle_name = 'img_1.png'\ncontent_name = 'img_0.png'\nstyle_image = Image.open(\"../assets/{}\".format(style_name)).convert('RGB')\ncontent_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')\n\n\nwith torch.no_grad():\n    inputs = blip_processor(content_image, return_tensors=\"pt\").to(device)\n    out = blip_model.generate(**inputs)\n    caption = blip_processor.decode(out[0], skip_special_tokens=True)\n\nnum_sample=1\n\nwidth,height,content_image  = resize_content(content_image)\nimages = csgo.generate(pil_content_image= content_image, pil_style_image=style_image,\n                           prompt=caption,\n                           negative_prompt= \"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n                           height=height,\n                           width=width,\n                           content_scale=0.5,\n                           style_scale=1.0,\n                           guidance_scale=10,\n                           num_images_per_prompt=num_sample,\n                           num_samples=1,\n                           num_inference_steps=50,\n                           seed=42,\n                           image=content_image.convert('RGB'),\n                           controlnet_conditioning_scale=0.6,\n                          
)\nimages[0].save(\"../assets/content_img_0_style_img_1.png\")\n"
  },
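Both inference scripts expose `content_scale` and `style_scale`. In IP-Adapter-style processors (see `ip_adapter/attention_processor.py`, where `scale` is documented as "the weight scale of image prompt"), each image branch's cross-attention output is added to the text cross-attention output weighted by its scale. A toy per-element sketch of that blending rule — the variable names and numbers are illustrative, not the repo's code:

```python
def blend(text_out, content_out, style_out, content_scale=1.0, style_scale=1.0):
    """hidden = text_attn + content_scale * content_attn + style_scale * style_attn

    Illustrative element-wise version of decoupled IP-Adapter-style blending.
    """
    return [t + content_scale * c + style_scale * s
            for t, c, s in zip(text_out, content_out, style_out)]

base = [0.5, -0.25, 1.0]   # stand-in text cross-attention output
content = [0.1, 0.2, 0.3]  # stand-in content-branch output
style = [0.4, 0.0, -0.2]   # stand-in style-branch output

# style_scale=0 removes the style branch entirely
no_style = blend(base, content, style, content_scale=1.0, style_scale=0.0)
```

This is why setting `style_scale=1.0` and lowering `controlnet_conditioning_scale` (as in the text-driven example, `0.01`) shifts the result toward the style image while relaxing the spatial constraint from the content image.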
  {
    "path": "ip_adapter/__init__.py",
    "content": "from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterXL_CS,IPAdapter_CS\nfrom .ip_adapter import CSGO\n__all__ = [\n    \"IPAdapter\",\n    \"IPAdapterPlus\",\n    \"IPAdapterPlusXL\",\n    \"IPAdapterXL\",\n    \"CSGO\"\n    \"IPAdapterFull\",\n]\n"
  },
  {
    "path": "ip_adapter/attention_processor.py",
    "content": "# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass AttnProcessor(nn.Module):\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n        save_in_unet='down',\n        atten_control=None,\n    ):\n        super().__init__()\n        self.atten_control = atten_control\n        self.save_in_unet = save_in_unet\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = 
attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states
\n\n\nclass IPAttnProcessor(nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            The weight scale of the image prompt.\n        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"
\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n        self.skip = skip\n\n        self.atten_control = atten_control\n        self.save_in_unet = save_in_unet\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n      
  hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        if not self.skip:\n            # for ip-adapter\n   
         ip_key = self.to_k_ip(ip_hidden_states)\n            ip_value = self.to_v_ip(ip_hidden_states)\n\n            ip_key = attn.head_to_batch_dim(ip_key)\n            ip_value = attn.head_to_batch_dim(ip_value)\n\n            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)\n            self.attn_map = ip_attention_probs\n            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)\n            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)\n\n            hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states
\n\n\nclass AttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n        save_in_unet='down',\n        atten_control=None,\n    ):\n        super().__init__()\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.\")\n        self.atten_control = atten_control\n        self.save_in_unet = save_in_unet
\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = 
attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = 
hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states
\n\n\nclass IPAttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter for PyTorch 2.0.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            The weight scale of the image prompt.\n        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"
\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):\n        super().__init__()\n\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"IPAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.\")\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n        self.skip = skip\n\n        self.atten_control = atten_control\n        self.save_in_unet = save_in_unet\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, 
bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        
head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)
\n\n        if not self.skip:\n            # for ip-adapter\n            ip_key = self.to_k_ip(ip_hidden_states)\n            ip_value = self.to_v_ip(ip_hidden_states)\n\n            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n            # the output of sdp = (batch, num_heads, seq_len, head_dim)\n            # TODO: add support for attn.scale when we move to Torch 2.1\n            ip_hidden_states = F.scaled_dot_product_attention(\n                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False\n            )\n            with torch.no_grad():\n                # store the image-prompt attention map; the parentheses ensure softmax is\n                # applied to the scaled score matrix, matching scaled_dot_product_attention\n                self.attn_map = (query @ ip_key.transpose(-2, -1) * head_dim**-0.5).softmax(dim=-1)\n\n            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n            ip_hidden_states = ip_hidden_states.to(query.dtype)\n\n            hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states
\n\n\nclass IP_CS_AttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter content-style composition for PyTorch 2.0.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        content_scale (`float`, defaults to 1.0):\n            The weight scale of the content image prompt.\n        style_scale (`float`, defaults to 1.0):\n            The weight scale of the style image prompt.\n        num_content_tokens (`int`, defaults to 4):\n            The context length of the content image features.\n        num_style_tokens (`int`, defaults to 4):\n            The context length of the style image features.\n    \"\"\"
\n\n    def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0, style_scale=1.0, num_content_tokens=4, num_style_tokens=4,\n                 skip=False, content=False, style=False):\n        super().__init__()\n\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"IP_CS_AttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.\")\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.content_scale = content_scale\n        self.style_scale = style_scale\n        self.num_content_tokens = num_content_tokens\n        self.num_style_tokens = num_style_tokens\n        self.skip = skip\n\n        self.content = content\n        self.style = style\n\n        if self.content or self.style:\n            self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n            self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_k_ip_content = None\n        self.to_v_ip_content = None
\n\n    def set_content_ipa(self, content_scale=1.0):\n        self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)\n        self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)\n        self.content_scale = content_scale\n        self.content = True\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # split into text tokens, content image tokens, and style image tokens\n            end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens - self.num_style_tokens\n            encoder_hidden_states, ip_content_hidden_states, ip_style_hidden_states = (\n     
           encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:end_pos + self.num_content_tokens, :],\n                encoder_hidden_states[:, end_pos + self.num_content_tokens:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if not self.skip and self.content is True:\n            # print('content#####################################################')\n            # for ip-content-adapter\n            if self.to_k_ip_content is None:\n\n                ip_content_key = self.to_k_ip(ip_content_hidden_states)\n                ip_content_value = self.to_v_ip(ip_content_hidden_states)\n            else:\n                ip_content_key = self.to_k_ip_content(ip_content_hidden_states)\n                ip_content_value = self.to_v_ip_content(ip_content_hidden_states)\n\n            ip_content_key = ip_content_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            ip_content_value = ip_content_value.view(batch_size, -1, attn.heads, 
head_dim).transpose(1, 2)\n\n            # the output of sdp = (batch, num_heads, seq_len, head_dim)\n            # TODO: add support for attn.scale when we move to Torch 2.1\n            ip_content_hidden_states = F.scaled_dot_product_attention(\n                query, ip_content_key, ip_content_value, attn_mask=None, dropout_p=0.0, is_causal=False\n            )\n\n\n            ip_content_hidden_states = ip_content_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n            ip_content_hidden_states = ip_content_hidden_states.to(query.dtype)\n\n\n            hidden_states = hidden_states + self.content_scale * ip_content_hidden_states\n\n        if not self.skip and self.style is True:\n            # for ip-style-adapter\n            ip_style_key = self.to_k_ip(ip_style_hidden_states)\n            ip_style_value = self.to_v_ip(ip_style_hidden_states)\n\n            ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n            # the output of sdp = (batch, num_heads, seq_len, head_dim)\n            # TODO: add support for attn.scale when we move to Torch 2.1\n            ip_style_hidden_states = F.scaled_dot_product_attention(\n                query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False\n            )\n\n            ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(batch_size, -1,\n                                                                                    attn.heads * head_dim)\n            ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)\n\n            hidden_states = hidden_states + self.style_scale * ip_style_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n 
           hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n## for controlnet\nclass CNAttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(self, num_tokens=4,save_in_unet='down',atten_control=None):\n        self.num_tokens = num_tokens\n        self.atten_control = atten_control\n        self.save_in_unet = save_in_unet\n\n    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = 
attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass CNAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n        self.num_tokens = num_tokens\n        self.atten_control = atten_control\n        self.save_in_unet = save_in_unet\n\n    def __call__(\n            self,\n            attn,\n            hidden_states,\n            encoder_hidden_states=None,\n            attention_mask=None,\n            temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = 
hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, 
-1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n"
  },
  {
    "path": "ip_adapter/ip_adapter.py",
    "content": "import os\nfrom typing import List\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.controlnet import MultiControlNetModel\nfrom PIL import Image\nfrom safetensors import safe_open\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\nfrom torchvision import transforms\nfrom .utils import is_torch2_available, get_generator\n\n# import torchvision.transforms.functional as Func\n\n# from .clip_style_models import CSD_CLIP, convert_state_dict\n\nif is_torch2_available():\n    from .attention_processor import (\n        AttnProcessor2_0 as AttnProcessor,\n    )\n    from .attention_processor import (\n        CNAttnProcessor2_0 as CNAttnProcessor,\n    )\n    from .attention_processor import (\n        IPAttnProcessor2_0 as IPAttnProcessor,\n    )\n    from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor\nelse:\n    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor\nfrom .resampler import Resampler\n\nfrom transformers import AutoImageProcessor, AutoModel\n\n\nclass ImageProjModel(torch.nn.Module):\n    \"\"\"Projection Model\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.generator = None\n        self.cross_attention_dim = cross_attention_dim\n        self.clip_extra_context_tokens = clip_extra_context_tokens\n        # print(clip_embeddings_dim, self.clip_extra_context_tokens, cross_attention_dim)\n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)\n\n    def forward(self, image_embeds):\n        embeds = image_embeds\n        clip_extra_context_tokens = self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = 
self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\nclass MLPProjModel(torch.nn.Module):\n    \"\"\"SD model with image prompt\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):\n        super().__init__()\n\n        self.proj = torch.nn.Sequential(\n            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),\n            torch.nn.GELU(),\n            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),\n            torch.nn.LayerNorm(cross_attention_dim)\n        )\n\n    def forward(self, image_embeds):\n        clip_extra_context_tokens = self.proj(image_embeds)\n        return clip_extra_context_tokens\n\n\nclass IPAdapter:\n    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=[\"block\"]):\n        self.device = device\n        self.image_encoder_path = image_encoder_path\n        self.ip_ckpt = ip_ckpt\n        self.num_tokens = num_tokens\n        self.target_blocks = target_blocks\n\n        self.pipe = sd_pipe.to(self.device)\n        self.set_ip_adapter()\n\n        # load image encoder\n        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(\n            self.device, dtype=torch.float16\n        )\n        self.clip_image_processor = CLIPImageProcessor()\n        # image proj model\n        self.image_proj_model = self.init_proj()\n\n        self.load_ip_adapter()\n\n    def init_proj(self):\n        image_proj_model = ImageProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.projection_dim,\n            clip_extra_context_tokens=self.num_tokens,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    def set_ip_adapter(self):\n        unet = self.pipe.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            
cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None:\n                attn_procs[name] = AttnProcessor()\n            else:\n                selected = False\n                for block_name in self.target_blocks:\n                    if block_name in name:\n                        selected = True\n                        break\n                if selected:\n                    attn_procs[name] = IPAttnProcessor(\n                        hidden_size=hidden_size,\n                        cross_attention_dim=cross_attention_dim,\n                        scale=1.0,\n                        num_tokens=self.num_tokens,\n                    ).to(self.device, dtype=torch.float16)\n                else:\n                    attn_procs[name] = IPAttnProcessor(\n                        hidden_size=hidden_size,\n                        cross_attention_dim=cross_attention_dim,\n                        scale=1.0,\n                        num_tokens=self.num_tokens,\n                        skip=True\n                    ).to(self.device, dtype=torch.float16)\n        unet.set_attn_processor(attn_procs)\n        if hasattr(self.pipe, \"controlnet\"):\n            if isinstance(self.pipe.controlnet, MultiControlNetModel):\n                for controlnet in self.pipe.controlnet.nets:\n                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))\n            else:\n      
          self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))\n\n    def load_ip_adapter(self):\n        if os.path.splitext(self.ip_ckpt)[-1] == \".safetensors\":\n            state_dict = {\"image_proj\": {}, \"ip_adapter\": {}}\n            with safe_open(self.ip_ckpt, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if key.startswith(\"image_proj.\"):\n                        state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"ip_adapter.\"):\n                        state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n        else:\n            state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"])\n        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())\n        ip_layers.load_state_dict(state_dict[\"ip_adapter\"], strict=False)\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):\n        if pil_image is not None:\n            if isinstance(pil_image, Image.Image):\n                pil_image = [pil_image]\n            clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n        else:\n            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)\n\n        if content_prompt_embeds is not None:\n            clip_image_embeds = clip_image_embeds - content_prompt_embeds\n\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def 
set_scale(self, scale):\n        for attn_processor in self.pipe.unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                attn_processor.scale = scale\n\n    def generate(\n            self,\n            pil_image=None,\n            clip_image_embeds=None,\n            prompt=None,\n            negative_prompt=None,\n            scale=1.0,\n            num_samples=4,\n            seed=None,\n            guidance_scale=7.5,\n            num_inference_steps=30,\n            neg_content_emb=None,\n            **kwargs,\n    ):\n        self.set_scale(scale)\n\n        if pil_image is not None:\n            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n        else:\n            num_prompts = clip_image_embeds.size(0)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(\n            pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb\n        )\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(\n                prompt,\n 
               device=self.device,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\nclass IPAdapter_CS:\n    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,\n                 num_style_tokens=4,\n                 target_content_blocks=[\"block\"], target_style_blocks=[\"block\"], content_image_encoder_path=None,\n                  controlnet_adapter=False,\n                 controlnet_target_content_blocks=None,\n                 controlnet_target_style_blocks=None,\n                 content_model_resampler=False,\n                 style_model_resampler=False,\n                ):\n        self.device = device\n        self.image_encoder_path = image_encoder_path\n        self.ip_ckpt = ip_ckpt\n        self.num_content_tokens = num_content_tokens\n        self.num_style_tokens = num_style_tokens\n        self.content_target_blocks = target_content_blocks\n        self.style_target_blocks = target_style_blocks\n\n        self.content_model_resampler = content_model_resampler\n        self.style_model_resampler = style_model_resampler\n\n        self.controlnet_adapter = controlnet_adapter\n        self.controlnet_target_content_blocks = controlnet_target_content_blocks\n        self.controlnet_target_style_blocks = 
controlnet_target_style_blocks\n\n        self.pipe = sd_pipe.to(self.device)\n        self.set_ip_adapter()\n        self.content_image_encoder_path = content_image_encoder_path\n\n\n        # load image encoder\n        if content_image_encoder_path is not None:\n            self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(self.device,\n                                                                                                  dtype=torch.float16)\n            self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path)\n        else:\n            self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(\n                self.device, dtype=torch.float16\n            )\n            self.content_image_processor = CLIPImageProcessor()\n        # model.requires_grad_(False)\n\n        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(\n            self.device, dtype=torch.float16\n        )\n        # if self.use_CSD is not None:\n        #     self.style_image_encoder = CSD_CLIP(\"vit_large\", \"default\",self.use_CSD+\"/ViT-L-14.pt\")\n        #     model_path = self.use_CSD+\"/checkpoint.pth\"\n        #     checkpoint = torch.load(model_path, map_location=\"cpu\")\n        #     state_dict = convert_state_dict(checkpoint['model_state_dict'])\n        #     self.style_image_encoder.load_state_dict(state_dict, strict=False)\n        #\n        #     normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n        #     self.style_preprocess = transforms.Compose([\n        #         transforms.Resize(size=224, interpolation=Func.InterpolationMode.BICUBIC),\n        #         transforms.CenterCrop(224),\n        #         transforms.ToTensor(),\n        #         normalize,\n        #     ])\n\n        self.clip_image_processor = CLIPImageProcessor()\n 
       # image proj model\n        self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content',\n                                                       model_resampler=self.content_model_resampler)\n        self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',\n                                                     model_resampler=self.style_model_resampler)\n\n        self.load_ip_adapter()\n\n    def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):\n\n        # print('@@@@',self.pipe.unet.config.cross_attention_dim,self.image_encoder.config.projection_dim)\n        if content_or_style_ == 'content' and self.content_image_encoder_path is not None:\n            image_proj_model = ImageProjModel(\n                cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n                clip_embeddings_dim=self.content_image_encoder.config.projection_dim,\n                clip_extra_context_tokens=num_tokens,\n            ).to(self.device, dtype=torch.float16)\n            return image_proj_model\n\n        image_proj_model = ImageProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.projection_dim,\n            clip_extra_context_tokens=num_tokens,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    def set_ip_adapter(self):\n        unet = self.pipe.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = 
list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None:\n                attn_procs[name] = AttnProcessor()\n            else:\n                selected = False\n                for block_name in self.style_target_blocks:\n                    if block_name in name:\n                        selected = True\n                        attn_procs[name] = IP_CS_AttnProcessor(\n                            hidden_size=hidden_size,\n                            cross_attention_dim=cross_attention_dim,\n                            style_scale=1.0,\n                            style=True,\n                            num_content_tokens=self.num_content_tokens,\n                            num_style_tokens=self.num_style_tokens,\n                        )\n                for block_name in self.content_target_blocks:\n                    if block_name in name:\n                        if selected is False:\n                            attn_procs[name] = IP_CS_AttnProcessor(\n                                hidden_size=hidden_size,\n                                cross_attention_dim=cross_attention_dim,\n                                content_scale=1.0,\n                                content=True,\n                                num_content_tokens=self.num_content_tokens,\n                                num_style_tokens=self.num_style_tokens,\n                            )\n                            # mark the block as handled so the skip processor below does not overwrite it\n                            selected = True\n                        else:\n                            attn_procs[name].set_content_ipa(content_scale=1.0)\n\n                if selected is False:\n                    attn_procs[name] = 
IP_CS_AttnProcessor(\n                        hidden_size=hidden_size,\n                        cross_attention_dim=cross_attention_dim,\n                        num_content_tokens=self.num_content_tokens,\n                        num_style_tokens=self.num_style_tokens,\n                        skip=True,\n                    )\n\n                attn_procs[name].to(self.device, dtype=torch.float16)\n        unet.set_attn_processor(attn_procs)\n        if hasattr(self.pipe, \"controlnet\"):\n            if self.controlnet_adapter is False:\n                if isinstance(self.pipe.controlnet, MultiControlNetModel):\n                    for controlnet in self.pipe.controlnet.nets:\n                        controlnet.set_attn_processor(CNAttnProcessor(\n                            num_tokens=self.num_content_tokens + self.num_style_tokens))\n                else:\n                    self.pipe.controlnet.set_attn_processor(CNAttnProcessor(\n                        num_tokens=self.num_content_tokens + self.num_style_tokens))\n\n            else:\n                controlnet_attn_procs = {}\n                controlnet_style_target_blocks = self.controlnet_target_style_blocks\n                controlnet_content_target_blocks = self.controlnet_target_content_blocks\n                for name in self.pipe.controlnet.attn_processors.keys():\n                    # print(name)\n                    cross_attention_dim = None if name.endswith(\n                        \"attn1.processor\") else self.pipe.controlnet.config.cross_attention_dim\n                    if name.startswith(\"mid_block\"):\n                        hidden_size = self.pipe.controlnet.config.block_out_channels[-1]\n                    elif name.startswith(\"up_blocks\"):\n                        block_id = int(name[len(\"up_blocks.\")])\n                        hidden_size = list(reversed(self.pipe.controlnet.config.block_out_channels))[block_id]\n                    elif name.startswith(\"down_blocks\"):\n    
                    block_id = int(name[len(\"down_blocks.\")])\n                        hidden_size = self.pipe.controlnet.config.block_out_channels[block_id]\n                    if cross_attention_dim is None:\n                        # layername_id += 1\n                        controlnet_attn_procs[name] = AttnProcessor()\n\n                    else:\n                        # layername_id += 1\n                        selected = False\n                        for block_name in controlnet_style_target_blocks:\n                            if block_name in name:\n                                selected = True\n                                # print(name)\n                                controlnet_attn_procs[name] = IP_CS_AttnProcessor(\n                                    hidden_size=hidden_size,\n                                    cross_attention_dim=cross_attention_dim,\n                                    style_scale=1.0,\n                                    style=True,\n                                    num_content_tokens=self.num_content_tokens,\n                                    num_style_tokens=self.num_style_tokens,\n                                )\n\n                        for block_name in controlnet_content_target_blocks:\n                            if block_name in name:\n                                if selected is False:\n                                    controlnet_attn_procs[name] = IP_CS_AttnProcessor(\n                                        hidden_size=hidden_size,\n                                        cross_attention_dim=cross_attention_dim,\n                                        content_scale=1.0,\n                                        content=True,\n                                        num_content_tokens=self.num_content_tokens,\n                                        num_style_tokens=self.num_style_tokens,\n                                    )\n\n                                    selected = True\n              
                  elif selected is True:\n                                    controlnet_attn_procs[name].set_content_ipa(content_scale=1.0)\n\n                                # if args.content_image_encoder_type !='dinov2':\n                                #     weights = {\n                                #         \"to_k_ip.weight\": state_dict[\"ip_adapter\"][str(layername_id) + \".to_k_ip.weight\"],\n                                #         \"to_v_ip.weight\": state_dict[\"ip_adapter\"][str(layername_id) + \".to_v_ip.weight\"],\n                                #     }\n                                #     attn_procs[name].load_state_dict(weights)\n                        if selected is False:\n                            controlnet_attn_procs[name] = IP_CS_AttnProcessor(\n                                hidden_size=hidden_size,\n                                cross_attention_dim=cross_attention_dim,\n                                num_content_tokens=self.num_content_tokens,\n                                num_style_tokens=self.num_style_tokens,\n                                skip=True,\n                            )\n                        controlnet_attn_procs[name].to(self.device, dtype=torch.float16)\n                        # layer_name = name.split(\".processor\")[0]\n                        # # print(state_dict[\"ip_adapter\"].keys())\n                        # weights = {\n                        #     \"to_k_ip.weight\": state_dict[\"ip_adapter\"][str(layername_id) + \".to_k_ip.weight\"],\n                        #     \"to_v_ip.weight\": state_dict[\"ip_adapter\"][str(layername_id) + \".to_v_ip.weight\"],\n                        # }\n                        # attn_procs[name].load_state_dict(weights)\n                self.pipe.controlnet.set_attn_processor(controlnet_attn_procs)\n\n    def load_ip_adapter(self):\n        if os.path.splitext(self.ip_ckpt)[-1] == \".safetensors\":\n            state_dict = {\"content_image_proj\": {}, 
\"style_image_proj\": {}, \"ip_adapter\": {}}\n            with safe_open(self.ip_ckpt, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if key.startswith(\"content_image_proj.\"):\n                        state_dict[\"content_image_proj\"][key.replace(\"content_image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"style_image_proj.\"):\n                        state_dict[\"style_image_proj\"][key.replace(\"style_image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"ip_adapter.\"):\n                        state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n        else:\n            state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n        self.content_image_proj_model.load_state_dict(state_dict[\"content_image_proj\"])\n        self.style_image_proj_model.load_state_dict(state_dict[\"style_image_proj\"])\n\n        if 'conv_in_unet_sd' in state_dict.keys():\n            self.pipe.unet.conv_in.load_state_dict(state_dict[\"conv_in_unet_sd\"], strict=True)\n        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())\n        ip_layers.load_state_dict(state_dict[\"ip_adapter\"], strict=False)\n\n        if self.controlnet_adapter is True:\n            print('loading controlnet_adapter')\n            self.pipe.controlnet.load_state_dict(state_dict[\"controlnet_adapter_modules\"], strict=False)\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None,\n                         content_or_style_=''):\n        # if pil_image is not None:\n        #     if isinstance(pil_image, Image.Image):\n        #         pil_image = [pil_image]\n        #     clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        #     clip_image_embeds = self.image_encoder(clip_image.to(self.device, 
dtype=torch.float16)).image_embeds\n        # else:\n        #     clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)\n\n        # if content_prompt_embeds is not None:\n        #     clip_image_embeds = clip_image_embeds - content_prompt_embeds\n\n        if content_or_style_ == 'content':\n            if pil_image is not None:\n                if isinstance(pil_image, Image.Image):\n                    pil_image = [pil_image]\n                if self.content_image_proj_model is not None:\n                    clip_image = self.content_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                    clip_image_embeds = self.content_image_encoder(\n                        clip_image.to(self.device, dtype=torch.float16)).image_embeds\n                else:\n                    clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                    clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n            else:\n                clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)\n\n            image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)\n            uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))\n            return image_prompt_embeds, uncond_image_prompt_embeds\n        if content_or_style_ == 'style':\n            if pil_image is not None:\n                # the optional CSD style encoder is only defined when its (currently commented-out) setup is enabled\n                if getattr(self, \"use_CSD\", None) is not None:\n                    clip_image = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32)\n                    clip_image_embeds = self.style_image_encoder(clip_image)\n                else:\n                    if isinstance(pil_image, Image.Image):\n                        pil_image = [pil_image]\n                    clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                    clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n            else:\n                clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)\n            image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)\n            uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))\n            return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def set_scale(self, content_scale, style_scale):\n        for attn_processor in self.pipe.unet.attn_processors.values():\n            if isinstance(attn_processor, IP_CS_AttnProcessor):\n                if attn_processor.content is True:\n                    attn_processor.content_scale = content_scale\n                if attn_processor.style is True:\n                    attn_processor.style_scale = style_scale\n        # only touch ControlNet processors when the ControlNet adapter is actually enabled\n        if self.controlnet_adapter is True:\n            for attn_processor in self.pipe.controlnet.attn_processors.values():\n                if isinstance(attn_processor, IP_CS_AttnProcessor):\n                    if attn_processor.content is True:\n                        attn_processor.content_scale = content_scale\n                    if attn_processor.style is True:\n                        attn_processor.style_scale = style_scale\n\n    def generate(\n            self,\n            pil_content_image=None,\n            pil_style_image=None,\n            clip_content_image_embeds=None,\n            clip_style_image_embeds=None,\n            prompt=None,\n            negative_prompt=None,\n            content_scale=1.0,\n            style_scale=1.0,\n            num_samples=4,\n            seed=None,\n            guidance_scale=7.5,\n            num_inference_steps=30,\n            neg_content_emb=None,\n            **kwargs,\n    ):\n        self.set_scale(content_scale, style_scale)\n\n        if pil_content_image is not None:\n            num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)\n        else:\n            num_prompts = clip_content_image_embeds.size(0)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        # get_image_embeds only handles the 'content'/'style' branches, so the mode must be passed explicitly\n        content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(\n            pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds, content_or_style_='content'\n        )\n        style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(\n            pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds, content_or_style_='style'\n        )\n\n        bs_embed, seq_len, _ = content_image_prompt_embeds.shape\n        content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)\n        content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,\n                                                                                     -1)\n\n        bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape\n        style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)\n        style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)\n        uncond_style_image_prompt_embeds = 
uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,\n                                                                                 -1)\n\n        with torch.inference_mode():\n            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(\n                prompt,\n                device=self.device,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds_,\n                                                uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],\n                                               dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\nclass IPAdapterXL_CS(IPAdapter_CS):\n    \"\"\"SDXL\"\"\"\n\n    def generate(\n            self,\n            pil_content_image,\n            pil_style_image,\n            prompt=None,\n            negative_prompt=None,\n            content_scale=1.0,\n            style_scale=1.0,\n            num_samples=4,\n            seed=None,\n            content_image_embeds=None,\n            style_image_embeds=None,\n            num_inference_steps=30,\n            neg_content_emb=None,\n            neg_content_prompt=None,\n            neg_content_scale=1.0,\n            **kwargs,\n    ):\n        
self.set_scale(content_scale, style_scale)\n\n        num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(pil_content_image,\n                                                                                                content_image_embeds,\n                                                                                                content_or_style_='content')\n\n\n\n        style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(pil_style_image,\n                                                                                            style_image_embeds,\n                                                                                            content_or_style_='style')\n\n        bs_embed, seq_len, _ = content_image_prompt_embeds.shape\n\n        content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)\n        content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,\n                                                                                     -1)\n        bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape\n        style_image_prompt_embeds = 
style_image_prompt_embeds.repeat(1, num_samples, 1)\n        style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)\n        uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,\n                                                                                 -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds,\n                                                uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],\n                                               dim=1)\n\n        self.generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=self.generator,\n            **kwargs,\n        ).images\n        return images\n\n\nclass CSGO(IPAdapterXL_CS):\n    \"\"\"SDXL\"\"\"\n\n    def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):\n        if content_or_style_ == 'content':\n            if model_resampler:\n    
            image_proj_model = Resampler(\n                    dim=self.pipe.unet.config.cross_attention_dim,\n                    depth=4,\n                    dim_head=64,\n                    heads=12,\n                    num_queries=num_tokens,\n                    embedding_dim=self.content_image_encoder.config.hidden_size,\n                    output_dim=self.pipe.unet.config.cross_attention_dim,\n                    ff_mult=4,\n                ).to(self.device, dtype=torch.float16)\n            else:\n                image_proj_model = ImageProjModel(\n                    cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n                    clip_embeddings_dim=self.image_encoder.config.projection_dim,\n                    clip_extra_context_tokens=num_tokens,\n                ).to(self.device, dtype=torch.float16)\n        if content_or_style_ == 'style':\n            if model_resampler:\n                image_proj_model = Resampler(\n                    dim=self.pipe.unet.config.cross_attention_dim,\n                    depth=4,\n                    dim_head=64,\n                    heads=12,\n                    num_queries=num_tokens,\n                    embedding_dim=self.content_image_encoder.config.hidden_size,\n                    output_dim=self.pipe.unet.config.cross_attention_dim,\n                    ff_mult=4,\n                ).to(self.device, dtype=torch.float16)\n            else:\n                image_proj_model = ImageProjModel(\n                    cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n                    clip_embeddings_dim=self.image_encoder.config.projection_dim,\n                    clip_extra_context_tokens=num_tokens,\n                ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''):\n        if isinstance(pil_image, Image.Image):\n            
pil_image = [pil_image]\n        if content_or_style_ == 'style':\n\n            if self.style_model_resampler:\n                clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16),\n                                                       output_hidden_states=True).hidden_states[-2]\n                image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)\n                uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))\n            else:\n\n\n                clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n                image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)\n                uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))\n            return image_prompt_embeds, uncond_image_prompt_embeds\n\n\n        else:\n\n            if self.content_image_encoder_path is not None:\n                clip_image = self.content_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                outputs = self.content_image_encoder(clip_image.to(self.device, dtype=torch.float16),\n                                                     output_hidden_states=True)\n                clip_image_embeds = outputs.last_hidden_state\n                image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)\n\n                # uncond_clip_image_embeds = self.image_encoder(\n                #     torch.zeros_like(clip_image), output_hidden_states=True\n                # ).last_hidden_state\n                uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))\n                return 
image_prompt_embeds, uncond_image_prompt_embeds\n\n            else:\n                if self.content_model_resampler:\n                    clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                    clip_image = clip_image.to(self.device, dtype=torch.float16)\n                    clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n                    image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)\n                    uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))\n                else:\n                    clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n                    clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds\n                    image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)\n                    uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))\n\n                return image_prompt_embeds, uncond_image_prompt_embeds\n\n\nclass IPAdapterXL(IPAdapter):\n    \"\"\"SDXL\"\"\"\n\n    def generate(\n            self,\n            pil_image,\n            prompt=None,\n            negative_prompt=None,\n            scale=1.0,\n            num_samples=4,\n            seed=None,\n            num_inference_steps=30,\n            neg_content_emb=None,\n            neg_content_prompt=None,\n            neg_content_scale=1.0,\n            **kwargs,\n    ):\n        self.set_scale(scale)\n\n        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        if neg_content_emb is None:\n            if neg_content_prompt is not None:\n                with torch.inference_mode():\n                    (\n                        prompt_embeds_,  # torch.Size([1, 77, 2048])\n                        negative_prompt_embeds_,\n                        pooled_prompt_embeds_,  # torch.Size([1, 1280])\n                        negative_pooled_prompt_embeds_,\n                    ) = self.pipe.encode_prompt(\n                        neg_content_prompt,\n                        num_images_per_prompt=num_samples,\n                        do_classifier_free_guidance=True,\n                        negative_prompt=negative_prompt,\n                    )\n                    pooled_prompt_embeds_ *= neg_content_scale\n            else:\n                
pooled_prompt_embeds_ = None\n        else:\n            # use the caller-provided negative content embedding directly\n            pooled_prompt_embeds_ = neg_content_emb\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image,\n                                                                                content_prompt_embeds=pooled_prompt_embeds_)\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)\n\n        self.generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=self.generator,\n            **kwargs,\n        ).images\n\n        return images\n\n\nclass IPAdapterPlus(IPAdapter):\n    \"\"\"IP-Adapter with fine-grained features\"\"\"\n\n    def init_proj(self):\n        
image_proj_model = Resampler(\n            dim=self.pipe.unet.config.cross_attention_dim,\n            depth=4,\n            dim_head=64,\n            heads=12,\n            num_queries=self.num_tokens,\n            embedding_dim=self.image_encoder.config.hidden_size,\n            output_dim=self.pipe.unet.config.cross_attention_dim,\n            ff_mult=4,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):\n        if isinstance(pil_image, Image.Image):\n            pil_image = [pil_image]\n        clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        clip_image = clip_image.to(self.device, dtype=torch.float16)\n        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_clip_image_embeds = self.image_encoder(\n            torch.zeros_like(clip_image), output_hidden_states=True\n        ).hidden_states[-2]\n        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n\nclass IPAdapterFull(IPAdapterPlus):\n    \"\"\"IP-Adapter with full features\"\"\"\n\n    def init_proj(self):\n        image_proj_model = MLPProjModel(\n            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.image_encoder.config.hidden_size,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n\nclass IPAdapterPlusXL(IPAdapter):\n    \"\"\"SDXL\"\"\"\n\n    def init_proj(self):\n        image_proj_model = Resampler(\n            dim=1280,\n            depth=4,\n            dim_head=64,\n            heads=20,\n            num_queries=self.num_tokens,\n            
embedding_dim=self.image_encoder.config.hidden_size,\n            output_dim=self.pipe.unet.config.cross_attention_dim,\n            ff_mult=4,\n        ).to(self.device, dtype=torch.float16)\n        return image_proj_model\n\n    @torch.inference_mode()\n    def get_image_embeds(self, pil_image):\n        if isinstance(pil_image, Image.Image):\n            pil_image = [pil_image]\n        clip_image = self.clip_image_processor(images=pil_image, return_tensors=\"pt\").pixel_values\n        clip_image = clip_image.to(self.device, dtype=torch.float16)\n        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]\n        image_prompt_embeds = self.image_proj_model(clip_image_embeds)\n        uncond_clip_image_embeds = self.image_encoder(\n            torch.zeros_like(clip_image), output_hidden_states=True\n        ).hidden_states[-2]\n        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)\n        return image_prompt_embeds, uncond_image_prompt_embeds\n\n    def generate(\n            self,\n            pil_image,\n            prompt=None,\n            negative_prompt=None,\n            scale=1.0,\n            num_samples=4,\n            seed=None,\n            num_inference_steps=30,\n            **kwargs,\n    ):\n        self.set_scale(scale)\n\n        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)\n\n        if prompt is None:\n            prompt = \"best quality, high quality\"\n        if negative_prompt is None:\n            negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality\"\n\n        if not isinstance(prompt, List):\n            prompt = [prompt] * num_prompts\n        if not isinstance(negative_prompt, List):\n            negative_prompt = [negative_prompt] * num_prompts\n\n        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)\n        bs_embed, seq_len, _ = image_prompt_embeds.shape\n      
  image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)\n        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)\n        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)\n\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.pipe.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n            )\n            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)\n\n        generator = get_generator(seed, self.device)\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            **kwargs,\n        ).images\n\n        return images\n"
  },
  {
    "path": "ip_adapter/resampler.py",
"content": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom einops.layers.torch import Rearrange\n\n\n# FFN\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # make contiguous; shape remains (bs, n_heads, length, dim_per_head)\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head**-0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latents (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = 
self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass Resampler(nn.Module):\n    def __init__(\n        self,\n        dim=1024,\n        depth=8,\n        dim_head=64,\n        heads=16,\n        num_queries=8,\n        embedding_dim=768,\n        output_dim=1024,\n        ff_mult=4,\n        max_seq_len: int = 257,  # CLIP tokens + CLS token\n        apply_pos_emb: bool = False,\n        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence\n    ):\n        super().__init__()\n        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None\n\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)\n\n        self.proj_in = nn.Linear(embedding_dim, dim)\n\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.to_latents_from_mean_pooled_seq = (\n            nn.Sequential(\n                nn.LayerNorm(dim),\n                nn.Linear(dim, dim * num_latents_mean_pooled),\n                Rearrange(\"b (n d) -> b n d\", n=num_latents_mean_pooled),\n            )\n            if num_latents_mean_pooled > 0\n            else None\n        )\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n            
            FeedForward(dim=dim, mult=ff_mult),\n                    ]\n                )\n            )\n\n    def forward(self, x):\n        if self.pos_emb is not None:\n            n, device = x.shape[1], x.device\n            pos_emb = self.pos_emb(torch.arange(n, device=device))\n            x = x + pos_emb\n\n        latents = self.latents.repeat(x.size(0), 1, 1)\n\n        x = self.proj_in(x)\n\n        if self.to_latents_from_mean_pooled_seq:\n            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))\n            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)\n            latents = torch.cat((meanpooled_latents, latents), dim=-2)\n\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n\ndef masked_mean(t, *, dim, mask=None):\n    if mask is None:\n        return t.mean(dim=dim)\n\n    denom = mask.sum(dim=dim, keepdim=True)\n    mask = rearrange(mask, \"b n -> b n 1\")\n    masked_t = t.masked_fill(~mask, 0.0)\n\n    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)\n"
  },
  {
    "path": "ip_adapter/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\n\nBLOCKS = {\n    'content': ['down_blocks'],\n    'style': [\"up_blocks\"],\n\n}\n\ncontrolnet_BLOCKS = {\n    'content': [],\n    'style': [\"down_blocks\"],\n}\n\n\ndef resize_width_height(width, height, min_short_side=512, max_long_side=1024):\n\n    if width < height:\n\n        if width < min_short_side:\n            scale_factor = min_short_side / width\n            new_width = min_short_side\n            new_height = int(height * scale_factor)\n        else:\n            new_width, new_height = width, height\n    else:\n\n        if height < min_short_side:\n            scale_factor = min_short_side / height\n            new_width = int(width * scale_factor)\n            new_height = min_short_side\n        else:\n            new_width, new_height = width, height\n\n    if max(new_width, new_height) > max_long_side:\n        scale_factor = max_long_side / max(new_width, new_height)\n        new_width = int(new_width * scale_factor)\n        new_height = int(new_height * scale_factor)\n    return new_width, new_height\n\ndef resize_content(content_image):\n    max_long_side = 1024\n    min_short_side = 1024\n\n    new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1],\n                                                min_short_side=min_short_side, max_long_side=max_long_side)\n    height = new_height // 16 * 16\n    width = new_width // 16 * 16\n    content_image = content_image.resize((width, height))\n\n    return width,height,content_image\n\nattn_maps = {}\ndef hook_fn(name):\n    def forward_hook(module, input, output):\n        if hasattr(module.processor, \"attn_map\"):\n            attn_maps[name] = module.processor.attn_map\n            del module.processor.attn_map\n\n    return forward_hook\n\ndef register_cross_attention_hook(unet):\n    for name, module in unet.named_modules():\n        if 
name.split('.')[-1].startswith('attn2'):\n            module.register_forward_hook(hook_fn(name))\n\n    return unet\n\ndef upscale(attn_map, target_size):\n    attn_map = torch.mean(attn_map, dim=0)\n    attn_map = attn_map.permute(1, 0)\n    temp_size = None\n\n    for i in range(5):\n        scale = 2 ** i\n        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:\n            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))\n            break\n\n    assert temp_size is not None, \"temp_size cannot be None\"\n\n    attn_map = attn_map.view(attn_map.shape[0], *temp_size)\n\n    attn_map = F.interpolate(\n        attn_map.unsqueeze(0).to(dtype=torch.float32),\n        size=target_size,\n        mode='bilinear',\n        align_corners=False\n    )[0]\n\n    attn_map = torch.softmax(attn_map, dim=0)\n    return attn_map\n\ndef get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):\n\n    idx = 0 if instance_or_negative else 1\n    net_attn_maps = []\n\n    for name, attn_map in attn_maps.items():\n        attn_map = attn_map.cpu() if detach else attn_map\n        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()\n        attn_map = upscale(attn_map, image_size)\n        net_attn_maps.append(attn_map)\n\n    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)\n\n    return net_attn_maps\n\ndef attnmaps2images(net_attn_maps):\n\n    images = []\n\n    for attn_map in net_attn_maps:\n        attn_map = attn_map.cpu().numpy()\n        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255\n        normalized_attn_map = normalized_attn_map.astype(np.uint8)\n        image = Image.fromarray(normalized_attn_map)\n        images.append(image)\n\n    return images\n\ndef is_torch2_available():\n    return hasattr(F, \"scaled_dot_product_attention\")\n\ndef get_generator(seed, device):\n\n    if seed is not None:\n        if isinstance(seed, list):\n            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]\n        else:\n            generator = torch.Generator(device).manual_seed(seed)\n    else:\n        generator = None\n\n    return generator"
  },
  {
    "path": "requirements.txt",
"content": "diffusers==0.25.1\ntorch==2.0.1\ntorchaudio==2.0.2\ntorchvision==0.15.2\ntransformers==4.40.2\naccelerate\nsafetensors\neinops\nspaces==0.19.4\nomegaconf\npeft\nhuggingface-hub==0.24.5\nopencv-python\ninsightface\ngradio\ncontrolnet_aux\ngdown\n"
  }
]