Repository: instantX-research/CSGO
Branch: main
Commit: fefec09cf680
Files: 11
Total size: 26.0 MB
Directory structure:
gitextract_n_50hphl/
├── README.md
├── gradio/
│ ├── app.py
│ └── requirements.txt
├── infer/
│ ├── infer_CSGO.py
│ └── infer_csgo.ipynb
├── ip_adapter/
│ ├── __init__.py
│ ├── attention_processor.py
│ ├── ip_adapter.py
│ ├── resampler.py
│ └── utils.py
└── requirements.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
<div align="center">
<h1>CSGO: Content-Style Composition in Text-to-Image Generation</h1>
[**Peng Xing**](https://github.com/xingp-ng)<sup>1,2*</sup> · [**Haofan Wang**](https://haofanwang.github.io/)<sup>1*</sup> · [**Yanpeng Sun**](https://scholar.google.com.hk/citations?user=a3FI8c4AAAAJ&hl=zh-CN&oi=ao/)<sup>2</sup> · [**Qixun Wang**](https://github.com/wangqixun)<sup>1</sup> · [**Xu Bai**](https://huggingface.co/baymin0220)<sup>1,3</sup> · [**Hao Ai**](https://github.com/aihao2000)<sup>1,4</sup> · [**Renyuan Huang**](https://github.com/DannHuang)<sup>1,5</sup>
[**Zechao Li**](https://zechao-li.github.io/)<sup>2✉</sup>
<sup>1</sup>InstantX Team · <sup>2</sup>Nanjing University of Science and Technology · <sup>3</sup>Xiaohongshu · <sup>4</sup>Beihang University · <sup>5</sup>Peking University
<sup>*</sup>equal contributions, <sup>✉</sup>corresponding authors
<a href='https://csgo-gen.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/abs/2408.16766'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
[Model](https://huggingface.co/InstantX/CSGO)
[Demo](https://huggingface.co/spaces/xingpng/CSGO/)
[Code](https://github.com/instantX-research/CSGO)
</div>
## Updates 🔥
- **`2024/09/04`**: 🔥 We released the Gradio code. You can simply configure it and use it directly.
- **`2024/09/03`**: 🔥 We released the online demo on [Hugging Face](https://huggingface.co/spaces/xingpng/CSGO/).
- **`2024/09/03`**: 🔥 We released the [pre-trained weight](https://huggingface.co/InstantX/CSGO).
- **`2024/09/03`**: 🔥 We released the initial version of the inference code.
- **`2024/08/30`**: 🔥 We released the technical report on [arXiv](https://arxiv.org/pdf/2408.16766).
- **`2024/07/15`**: 🔥 We released the [homepage](https://csgo-gen.github.io).
## Plan 💪
- [x] technical report
- [x] inference code
- [x] pre-trained weight [4_16]
- [x] pre-trained weight [4_32]
- [x] online demo
- [ ] pre-trained weight_v2 [4_32]
- [ ] IMAGStyle dataset
- [ ] training code
- [ ] more pre-trained weight
## Introduction 📖
This repo, named **CSGO**, contains the official PyTorch implementation of our paper [CSGO: Content-Style Composition in Text-to-Image Generation](https://arxiv.org/pdf/2408.16766).
We are actively updating and improving this repository. If you find any bugs or have suggestions, feel free to open an issue or submit a pull request (PR) 💖.
## Pipeline 💻
<p align="center">
<img src="assets/image3_1.jpg">
</p>
## Capabilities 🚅
🔥 Our CSGO achieves **image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis**.
🔥 For more results, visit our <a href="https://csgo-gen.github.io"><strong>homepage</strong></a> 🔥
<p align="center">
<img src="assets/vis.jpg">
</p>
## Getting Started 🏁
### 1. Clone the code and prepare the environment
```bash
git clone https://github.com/instantX-research/CSGO
cd CSGO
# create env using conda
conda create -n CSGO python=3.9
conda activate CSGO
# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt
```
### 2. Download pretrained weights
We currently release two model weights; a third (v2) is coming soon.
| Mode | content token | style token | Other |
|:----------------:|:-----------:|:-----------:|:---------------------------------:|
| csgo.bin |4|16| - |
| csgo_4_32.bin |4|32| Deepspeed zero2 |
| csgo_4_32_v2.bin |4|32| Deepspeed zero2+more(coming soon) |
The easiest way to download the pretrained weights is from HuggingFace:
```bash
# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
git lfs install
# clone and move the weights
git clone https://huggingface.co/InstantX/CSGO
```
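If you prefer not to use git-lfs, the same weights can be fetched with `huggingface_hub` (installed alongside `diffusers`). This is a minimal sketch; the helper name is ours, while the repo id `InstantX/CSGO` is the one linked above:

```python
# Sketch: fetch the CSGO weights with huggingface_hub instead of git-lfs.
# The local directory mirrors what `git clone` would create.
def download_csgo_weights(local_dir: str = "./CSGO") -> str:
    from huggingface_hub import snapshot_download  # lazy import
    return snapshot_download(repo_id="InstantX/CSGO", local_dir=local_dir)
```

Calling `download_csgo_weights()` returns the local path containing `csgo.bin` and `csgo_4_32.bin`.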
Our method is fully compatible with [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix), [ControlNet](https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic), and [Image Encoder](https://huggingface.co/h94/IP-Adapter/tree/main/sdxl_models/image_encoder).
Please download them and place them in the ./base_models folder.
Tip: if you want to load the ControlNet weights directly with a ControlNet pipeline, as CSGO does, rename the weight file first:
```bash
git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors
```
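On Windows, where `mv` is unavailable, the rename can be done in Python instead. A small sketch (the helper name is ours; the file names are copied from the commands above):

```python
import shutil
from pathlib import Path

def rename_controlnet_weight(repo_dir: str = "TTPLanet_SDXL_Controlnet_Tile_Realistic") -> Path:
    """Rename the TTPlanet ControlNet weight file to the standard
    diffusers name so it can be loaded from the folder directly."""
    src = Path(repo_dir) / "TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors"
    dst = Path(repo_dir) / "diffusion_pytorch_model.safetensors"
    if src.exists() and not dst.exists():
        shutil.move(str(src), str(dst))
    return dst
```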
### 3. Inference 🚀
```python
import torch
from ip_adapter.utils import BLOCKS as BLOCKS
from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS
from PIL import Image
from diffusers import (
AutoencoderKL,
ControlNetModel,
StableDiffusionXLControlNetPipeline,
)
from ip_adapter import CSGO
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
base_model_path = "./base_models/stable-diffusion-xl-base-1.0"
image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
csgo_ckpt = "./CSGO/csgo_4_32.bin"  # csgo_4_32.bin uses 32 style tokens; for csgo.bin set num_style_tokens=16
pretrained_vae_name_or_path = './base_models/sdxl-vae-fp16-fix'
controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path,
controlnet=controlnet,
torch_dtype=torch.float16,
add_watermarker=False,
vae=vae
)
pipe.enable_vae_tiling()
target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']
csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device,
            num_content_tokens=4,
            num_style_tokens=32,
            target_content_blocks=target_content_blocks,
            target_style_blocks=target_style_blocks,
            controlnet=False,
            controlnet_adapter=True,
            controlnet_target_content_blocks=controlnet_target_content_blocks,
            controlnet_target_style_blocks=controlnet_target_style_blocks,
            content_model_resampler=True,
            style_model_resampler=True,
            load_controlnet=False,
            )
style_name = 'img_0.png'
content_name = 'img_0.png'
# load both images as PIL images; generate() expects PIL inputs
style_image = Image.open("../assets/{}".format(style_name)).convert('RGB')
content_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')
caption = 'a small house with a sheep statue on top of it'
num_sample = 4
# image-driven style transfer
images = csgo.generate(pil_content_image=content_image,
                       pil_style_image=style_image,
                       prompt=caption,
                       negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                       content_scale=1.0,
                       style_scale=1.0,
                       guidance_scale=10,
                       num_images_per_prompt=num_sample,
                       num_samples=1,
                       num_inference_steps=50,
                       seed=42,
                       image=content_image.convert('RGB'),
                       controlnet_conditioning_scale=0.6,
                       )

# text-driven stylized synthesis
caption = 'a cat'
images = csgo.generate(pil_content_image=content_image,
                       pil_style_image=style_image,
                       prompt=caption,
                       negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                       content_scale=1.0,
                       style_scale=1.0,
                       guidance_scale=10,
                       num_images_per_prompt=num_sample,
                       num_samples=1,
                       num_inference_steps=50,
                       seed=42,
                       image=content_image.convert('RGB'),
                       controlnet_conditioning_scale=0.01,
                       )

# text editing-driven stylized synthesis
caption = 'a small house'
images = csgo.generate(pil_content_image=content_image,
                       pil_style_image=style_image,
                       prompt=caption,
                       negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                       content_scale=1.0,
                       style_scale=1.0,
                       guidance_scale=10,
                       num_images_per_prompt=num_sample,
                       num_samples=1,
                       num_inference_steps=50,
                       seed=42,
                       image=content_image.convert('RGB'),
                       controlnet_conditioning_scale=0.4,
                       )
```
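The three calls above differ only in the caption and in `controlnet_conditioning_scale`. A minimal summary of the values used (the dictionary and helper are ours, not part of the CSGO API):

```python
# controlnet_conditioning_scale values copied from the snippet above:
# high for image-driven transfer (keeps the content layout), near zero
# for pure text-driven synthesis, intermediate for text editing.
CONTROLNET_SCALE = {
    "image-driven style transfer": 0.6,
    "text-driven stylized synthesis": 0.01,
    "text editing-driven stylized synthesis": 0.4,
}

def controlnet_scale_for(task: str) -> float:
    """Look up the suggested ControlNet conditioning scale for a task."""
    return CONTROLNET_SCALE[task.lower()]
```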
### 4. Gradio interface ⚙️
We also provide a Gradio <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a> interface for a better experience. Just run:
```bash
# For Linux and Windows users (and macOS)
python gradio/app.py
```
If you don't have the resources to run it locally, we provide an online [demo](https://huggingface.co/spaces/xingpng/CSGO/).
## Demos
<p align="center">
<br>
🔥 For more results, visit our <a href="https://csgo-gen.github.io"><strong>homepage</strong></a> 🔥
</p>
### Content-Style Composition
<p align="center">
<img src="assets/page1.png">
</p>
<p align="center">
<img src="assets/page4.png">
</p>
### Cycle Translation
<p align="center">
<img src="assets/page8.png">
</p>
### Text-Driven Style Synthesis
<p align="center">
<img src="assets/page10.png">
</p>
### Text Editing-Driven Style Synthesis
<p align="center">
<img src="assets/page11.jpg">
</p>
## Star History
[Star History Chart](https://star-history.com/#instantX-research/CSGO&Date)
## Acknowledgements
This project is developed by the InstantX Team and Xiaohongshu; all rights reserved.
Sincere thanks to Xiaohongshu for providing the computing resources.
## Citation 💖
If you find CSGO useful for your research, please 🌟 this repo and cite our work using the following BibTeX:
```bibtex
@article{xing2024csgo,
  title={CSGO: Content-Style Composition in Text-to-Image Generation},
  author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
  journal={arXiv preprint arXiv:2408.16766},
  year={2024},
}
```
================================================
FILE: gradio/app.py
================================================
import sys
# sys.path.append("../")
sys.path.append("./")
import gradio as gr
import torch
from ip_adapter.utils import BLOCKS as BLOCKS
from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS
from ip_adapter.utils import resize_content
import cv2
import numpy as np
import random
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from diffusers import (
AutoencoderKL,
ControlNetModel,
StableDiffusionXLControlNetPipeline,
)
from ip_adapter import CSGO
from transformers import BlipProcessor, BlipForConditionalGeneration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "h94/IP-Adapter/sdxl_models/image_encoder"
csgo_ckpt ='InstantX/CSGO/csgo_4_32.bin'
pretrained_vae_name_or_path ='madebyollin/sdxl-vae-fp16-fix'
controlnet_path = "TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path,
controlnet=controlnet,
torch_dtype=torch.float16,
add_watermarker=False,
vae=vae
)
pipe.enable_vae_tiling()
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']
csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32,
target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,
controlnet_adapter=True,
controlnet_target_content_blocks=controlnet_target_content_blocks,
controlnet_target_style_blocks=controlnet_target_style_blocks,
content_model_resampler=True,
style_model_resampler=True,
)
MAX_SEED = np.iinfo(np.int32).max
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
def get_example():
case = [
[
"./assets/img_0.png",
'./assets/img_1.png',
"Image-Driven Style Transfer",
"there is a small house with a sheep statue on top of it",
1.0,
0.6,
1.0,
],
[
None,
'./assets/img_1.png',
"Text-Driven Style Synthesis",
"a cat",
1.0,
0.01,
1.0
],
[
None,
'./assets/img_2.png',
"Text-Driven Style Synthesis",
"a building",
0.5,
0.01,
1.0
],
[
"./assets/img_0.png",
'./assets/img_1.png',
"Text Edit-Driven Style Synthesis",
"there is a small house",
1.0,
0.4,
1.0
],
]
return case
def run_for_examples(content_image_pil,style_image_pil,target, prompt,scale_c_controlnet, scale_c, scale_s):
return create_image(
content_image_pil=content_image_pil,
style_image_pil=style_image_pil,
prompt=prompt,
scale_c_controlnet=scale_c_controlnet,
scale_c=scale_c,
scale_s=scale_s,
guidance_scale=7.0,
num_samples=3,
num_inference_steps=50,
seed=42,
target=target,
)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def create_image(content_image_pil,
style_image_pil,
prompt,
scale_c_controlnet,
scale_c,
scale_s,
guidance_scale,
num_samples,
num_inference_steps,
seed,
target="Image-Driven Style Transfer",
):
if content_image_pil is None:
content_image_pil = Image.fromarray(
np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
if prompt is None or prompt == '':
with torch.no_grad():
inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
out = blip_model.generate(**inputs)
prompt = blip_processor.decode(out[0], skip_special_tokens=True)
width, height, content_image = resize_content(content_image_pil)
style_image = style_image_pil
neg_content_prompt='text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry'
if target =="Image-Driven Style Transfer":
with torch.no_grad():
images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
prompt=prompt,
negative_prompt=neg_content_prompt,
height=height,
width=width,
content_scale=scale_c,
style_scale=scale_s,
guidance_scale=guidance_scale,
num_images_per_prompt=num_samples,
num_inference_steps=num_inference_steps,
num_samples=1,
seed=seed,
image=content_image.convert('RGB'),
controlnet_conditioning_scale=scale_c_controlnet,
)
elif target =="Text-Driven Style Synthesis":
content_image = Image.fromarray(
np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
with torch.no_grad():
images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
prompt=prompt,
negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
height=height,
width=width,
content_scale=scale_c,
style_scale=scale_s,
guidance_scale=7,
num_images_per_prompt=num_samples,
num_inference_steps=num_inference_steps,
num_samples=1,
seed=42,
image=content_image.convert('RGB'),
controlnet_conditioning_scale=scale_c_controlnet,
)
elif target =="Text Edit-Driven Style Synthesis":
with torch.no_grad():
images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
prompt=prompt,
negative_prompt=neg_content_prompt,
height=height,
width=width,
content_scale=scale_c,
style_scale=scale_s,
guidance_scale=guidance_scale,
num_images_per_prompt=num_samples,
num_inference_steps=num_inference_steps,
num_samples=1,
seed=seed,
image=content_image.convert('RGB'),
controlnet_conditioning_scale=scale_c_controlnet,
)
return [image_grid(images, 1, num_samples)]
def pil_to_cv2(image_pil):
image_np = np.array(image_pil)
image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
return image_cv2
# Description
title = r"""
<h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1>
"""
description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
How to use:<br>
1. Upload a content image if you want to use image-driven style transfer.
2. Upload a style image.
3. Set the type of task to perform; image-driven style transfer is performed by default. Options are <b>Image-driven style transfer, Text-driven style synthesis, and Text editing-driven style synthesis</b>.
4. <b>If you choose a text-driven task, enter your desired prompt</b>.
5. If you don't provide a prompt, a caption is generated with the BLIP model by default.
6. Click the <b>Submit</b> button to begin customization.
7. Share your stylized photo with your friends and enjoy! 😊
Advanced usage:<br>
1. Click advanced options.
2. Choose different guidance and steps.
"""
article = r"""
---
📝 **Tips**
In CSGO, the more accurate the text prompts for content images, the better the content retention.
Text-driven style synthesis and text-edit-driven style synthesis are expected to be more stable in the next release.
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{xing2024csgo,
  title={CSGO: Content-Style Composition in Text-to-Image Generation},
  author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
  journal={arXiv preprint arXiv:2408.16766},
  year={2024},
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or reach out to us directly at <b>xingp_ng@njust.edu.cn</b>.
"""
block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
with block:
# description
gr.Markdown(title)
gr.Markdown(description)
with gr.Tabs():
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
content_image_pil = gr.Image(label="Content Image (optional)", type='pil')
style_image_pil = gr.Image(label="Style Image", type='pil')
target = gr.Radio(["Image-Driven Style Transfer", "Text-Driven Style Synthesis", "Text Edit-Driven Style Synthesis"],
value="Image-Driven Style Transfer",
label="task")
prompt = gr.Textbox(label="Prompt",
value="there is a small house with a sheep statue on top of it")
scale_c_controlnet = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6,
label="Content Scale for controlnet")
scale_c = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6, label="Content Scale for IPA")
scale_s = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=1.0, label="Style Scale")
with gr.Accordion(open=False, label="Advanced Options"):
guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale")
num_samples = gr.Slider(minimum=1, maximum=4.0, step=1.0, value=1.0, label="num samples")
num_inference_steps = gr.Slider(minimum=5, maximum=100.0, step=1.0, value=50,
label="num inference steps")
seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label="Seed Value")
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
generate_button = gr.Button("Generate Image")
with gr.Column():
generated_image = gr.Gallery(label="Generated Image")
generate_button.click(
fn=randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=create_image,
inputs=[content_image_pil,
style_image_pil,
prompt,
scale_c_controlnet,
scale_c,
scale_s,
guidance_scale,
num_samples,
num_inference_steps,
seed,
target,],
outputs=[generated_image])
gr.Examples(
examples=get_example(),
inputs=[content_image_pil,style_image_pil,target, prompt,scale_c_controlnet, scale_c, scale_s],
fn=run_for_examples,
outputs=[generated_image],
cache_examples=True,
)
gr.Markdown(article)
block.launch(server_name="0.0.0.0", server_port=1234)
================================================
FILE: gradio/requirements.txt
================================================
diffusers==0.25.1
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
transformers==4.40.2
accelerate
safetensors
einops
spaces==0.19.4
omegaconf
peft
huggingface-hub==0.24.5
opencv-python
insightface
gradio
controlnet_aux
gdown
================================================
FILE: infer/infer_CSGO.py
================================================
import os
os.environ['HF_ENDPOINT']='https://hf-mirror.com'
import torch
from ip_adapter.utils import BLOCKS as BLOCKS
from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS
from ip_adapter.utils import resize_content
import cv2
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from diffusers import (
AutoencoderKL,
ControlNetModel,
StableDiffusionXLControlNetPipeline,
)
from ip_adapter import CSGO
from transformers import BlipProcessor, BlipForConditionalGeneration
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
base_model_path = "../../../base_models/stable-diffusion-xl-base-1.0"
image_encoder_path = "../../../base_models/IP-Adapter/sdxl_models/image_encoder"
csgo_ckpt = "/share2/xingpeng/DATA/blora/outputs/content_style_checkpoints_2/base_train_free_controlnet_S12_alldata_C_0_S_I_zero2_style_res_32_content_res4_drop/checkpoint-220000/ip_adapter.bin"
pretrained_vae_name_or_path ='../../../base_models/sdxl-vae-fp16-fix'
controlnet_path = "../../../base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path,
controlnet=controlnet,
torch_dtype=torch.float16,
add_watermarker=False,
vae=vae
)
pipe.enable_vae_tiling()
target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']
csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4,num_style_tokens=32,
target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,controlnet_adapter=True,
controlnet_target_content_blocks=controlnet_target_content_blocks,
controlnet_target_style_blocks=controlnet_target_style_blocks,
content_model_resampler=True,
style_model_resampler=True,
)
style_name = 'img_1.png'
content_name = 'img_0.png'
style_image = Image.open("../assets/{}".format(style_name)).convert('RGB')
content_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')
with torch.no_grad():
inputs = blip_processor(content_image, return_tensors="pt").to(device)
out = blip_model.generate(**inputs)
caption = blip_processor.decode(out[0], skip_special_tokens=True)
num_sample=1
width,height,content_image = resize_content(content_image)
images = csgo.generate(pil_content_image= content_image, pil_style_image=style_image,
prompt=caption,
negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
height=height,
width=width,
content_scale=0.5,
style_scale=1.0,
guidance_scale=10,
num_images_per_prompt=num_sample,
num_samples=1,
num_inference_steps=50,
seed=42,
image=content_image.convert('RGB'),
controlnet_conditioning_scale=0.6,
)
images[0].save("../assets/content_img_0_style_imag_1.png")
================================================
FILE: infer/infer_csgo.ipynb
================================================
[File too large to display: 25.9 MB]
================================================
FILE: ip_adapter/__init__.py
================================================
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, IPAdapterXL_CS, IPAdapter_CS
from .ip_adapter import CSGO
__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "IPAdapterXL_CS",
    "IPAdapter_CS",
    "CSGO",
    "IPAdapterFull",
]
================================================
FILE: ip_adapter/attention_processor.py
================================================
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
save_in_unet='down',
atten_control=None,
):
super().__init__()
self.atten_control = atten_control
self.save_in_unet = save_in_unet
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapter.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,save_in_unet='down', atten_control=None):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.skip = skip
self.atten_control = atten_control
self.save_in_unet = save_in_unet
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
if not self.skip:
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
save_in_unet='down',
atten_control=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.atten_control = atten_control
self.save_in_unet = save_in_unet
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
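The manual attention path (`AttnProcessor`) and the fused kernel used here compute the same result; a minimal self-contained check with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim), the layout fed to SDPA above
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

# manual scaled dot-product attention: softmax(QK^T / sqrt(d)) V
scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
manual = scores.softmax(dim=-1) @ v

# fused PyTorch 2.0 kernel used by AttnProcessor2_0
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```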
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter with PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
The weight scale of the image prompt.
num_tokens (`int`, defaults to 4; for IP-Adapter Plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.skip = skip
self.atten_control = atten_control
self.save_in_unet = save_in_unet
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if not self.skip:
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
# parenthesize the matmul so softmax is applied to the attention scores,
# not to the transposed key
self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
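`IPAttnProcessor2_0` relies on the image tokens being appended after the text tokens, so the split at `end_pos` is a fixed-size slice off the tail. The arithmetic can be illustrated with plain lists (the 77-token text length is just an example):

```python
num_tokens = 4  # IP-Adapter image tokens appended after the text tokens
seq = [f"txt{i}" for i in range(77)] + [f"img{i}" for i in range(num_tokens)]

# mirror of: end_pos = encoder_hidden_states.shape[1] - self.num_tokens
end_pos = len(seq) - num_tokens
text_tokens, ip_tokens = seq[:end_pos], seq[end_pos:]

print(len(text_tokens), ip_tokens)  # 77 ['img0', 'img1', 'img2', 'img3']
```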
class IP_CS_AttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter with separate content and style image tokens (PyTorch 2.0).
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
content_scale (`float`, defaults to 1.0):
The weight scale of the content image prompt.
style_scale (`float`, defaults to 1.0):
The weight scale of the style image prompt.
num_content_tokens (`int`, defaults to 4):
The context length of the content image features.
num_style_tokens (`int`, defaults to 4):
The context length of the style image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0, style_scale=1.0, num_content_tokens=4, num_style_tokens=4,
skip=False, content=False, style=False):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("IP_CS_AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.content_scale = content_scale
self.style_scale = style_scale
self.num_content_tokens = num_content_tokens
self.num_style_tokens = num_style_tokens
self.skip = skip
self.content = content
self.style = style
if self.content or self.style:
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_k_ip_content = None
self.to_v_ip_content = None
def set_content_ipa(self, content_scale=1.0):
self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
self.content_scale = content_scale
self.content = True
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens - self.num_style_tokens
encoder_hidden_states, ip_content_hidden_states, ip_style_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:end_pos + self.num_content_tokens, :],
encoder_hidden_states[:, end_pos + self.num_content_tokens:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if not self.skip and self.content:
# for ip-content-adapter
if self.to_k_ip_content is None:
ip_content_key = self.to_k_ip(ip_content_hidden_states)
ip_content_value = self.to_v_ip(ip_content_hidden_states)
else:
ip_content_key = self.to_k_ip_content(ip_content_hidden_states)
ip_content_value = self.to_v_ip_content(ip_content_hidden_states)
ip_content_key = ip_content_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_content_value = ip_content_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_content_hidden_states = F.scaled_dot_product_attention(
query, ip_content_key, ip_content_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_content_hidden_states = ip_content_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_content_hidden_states = ip_content_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.content_scale * ip_content_hidden_states
if not self.skip and self.style:
# for ip-style-adapter
ip_style_key = self.to_k_ip(ip_style_hidden_states)
ip_style_value = self.to_v_ip(ip_style_hidden_states)
ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_style_hidden_states = F.scaled_dot_product_attention(
query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(batch_size, -1,
attn.heads * head_dim)
ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.style_scale * ip_style_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
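`IP_CS_AttnProcessor2_0` assumes the cross-attention sequence is ordered as text tokens, then content tokens, then style tokens. The three-way slice around `end_pos` can be sketched with plain lists (77 text tokens is an assumed example length):

```python
num_content_tokens, num_style_tokens = 4, 4
seq = (
    [f"txt{i}" for i in range(77)]
    + [f"content{i}" for i in range(num_content_tokens)]
    + [f"style{i}" for i in range(num_style_tokens)]
)

# mirror of the slicing in IP_CS_AttnProcessor2_0.__call__
end_pos = len(seq) - num_content_tokens - num_style_tokens
text = seq[:end_pos]
content = seq[end_pos:end_pos + num_content_tokens]
style = seq[end_pos + num_content_tokens:]

print(len(text), content[0], style[0])  # 77 content0 style0
```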
# for ControlNet
class CNAttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
self.num_tokens = num_tokens
self.atten_control = atten_control
self.save_in_unet = save_in_unet
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CNAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CNAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.num_tokens = num_tokens
self.atten_control = atten_control
self.save_in_unet = save_in_unet
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
================================================
FILE: ip_adapter/ip_adapter.py
================================================
import os
from typing import List
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from torchvision import transforms
from .utils import is_torch2_available, get_generator
# import torchvision.transforms.functional as Func
# from .clip_style_models import CSD_CLIP, convert_state_dict
if is_torch2_available():
from .attention_processor import (
AttnProcessor2_0 as AttnProcessor,
)
from .attention_processor import (
CNAttnProcessor2_0 as CNAttnProcessor,
)
from .attention_processor import (
IPAttnProcessor2_0 as IPAttnProcessor,
)
from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor
else:
# NOTE: IP_CS_AttnProcessor has no pre-2.0 fallback, so the
# content-style adapter classes below require PyTorch 2.0+.
from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler
from transformers import AutoImageProcessor, AutoModel
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class MLPProjModel(torch.nn.Module):
"""MLP projection model for mapping image embeddings to prompt tokens"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim)
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
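`ImageProjModel` maps a pooled CLIP image embedding to a fixed number of extra context tokens via a single linear layer plus reshape. A shape-only sketch with assumed dimensions (1024-d CLIP projection, 768-d cross-attention, 4 tokens):

```python
import torch

cross_attention_dim, clip_embeddings_dim, num_tokens = 768, 1024, 4
proj = torch.nn.Linear(clip_embeddings_dim, num_tokens * cross_attention_dim)
norm = torch.nn.LayerNorm(cross_attention_dim)

image_embeds = torch.randn(2, clip_embeddings_dim)  # (batch, clip_dim)
# one vector in, num_tokens context tokens out
tokens = norm(proj(image_embeds).reshape(-1, num_tokens, cross_attention_dim))
print(tokens.shape)  # torch.Size([2, 4, 768])
```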
class IPAdapter:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
self.target_blocks = target_blocks
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
# image proj model
self.image_proj_model = self.init_proj()
self.load_ip_adapter()
def init_proj(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
selected = False
for block_name in self.target_blocks:
if block_name in name:
selected = True
break
if selected:
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
else:
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=self.num_tokens,
skip=True
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if isinstance(self.pipe.controlnet, MultiControlNetModel):
for controlnet in self.pipe.controlnet.nets:
controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
else:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
if content_prompt_embeds is not None:
clip_image_embeds = clip_image_embeds - content_prompt_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
def generate(
self,
pil_image=None,
clip_image_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
neg_content_emb=None,
**kwargs,
):
self.set_scale(scale)
if pil_image is not None:
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
else:
num_prompts = clip_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
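In `generate`, classifier-free guidance still works because the image tokens are concatenated onto both the positive and the negative text embeddings along the sequence dimension, keeping both branches the same length. A shape sketch (77 text tokens, 4 image tokens, and 768-d embeddings are assumed values):

```python
import torch

prompt_embeds_ = torch.randn(1, 77, 768)            # text embeddings from encode_prompt
negative_prompt_embeds_ = torch.randn(1, 77, 768)
image_prompt_embeds = torch.randn(1, 4, 768)        # from image_proj_model
uncond_image_prompt_embeds = torch.zeros(1, 4, 768)  # projection of zeros

prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
print(prompt_embeds.shape, negative_prompt_embeds.shape)  # both have shape (1, 81, 768)
```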
class IPAdapter_CS:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,
num_style_tokens=4,
target_content_blocks=["block"], target_style_blocks=["block"], content_image_encoder_path=None,
controlnet_adapter=False,
controlnet_target_content_blocks=None,
controlnet_target_style_blocks=None,
content_model_resampler=False,
style_model_resampler=False,
):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_content_tokens = num_content_tokens
self.num_style_tokens = num_style_tokens
self.content_target_blocks = target_content_blocks
self.style_target_blocks = target_style_blocks
self.content_model_resampler = content_model_resampler
self.style_model_resampler = style_model_resampler
self.controlnet_adapter = controlnet_adapter
self.controlnet_target_content_blocks = controlnet_target_content_blocks
self.controlnet_target_style_blocks = controlnet_target_style_blocks
self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()
self.content_image_encoder_path = content_image_encoder_path
# load image encoder
if content_image_encoder_path is not None:
self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(self.device,
dtype=torch.float16)
self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path)
else:
self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
self.content_image_processor = CLIPImageProcessor()
# model.requires_grad_(False)
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
# if self.use_CSD is not None:
# self.style_image_encoder = CSD_CLIP("vit_large", "default",self.use_CSD+"/ViT-L-14.pt")
# model_path = self.use_CSD+"/checkpoint.pth"
# checkpoint = torch.load(model_path, map_location="cpu")
# state_dict = convert_state_dict(checkpoint['model_state_dict'])
# self.style_image_encoder.load_state_dict(state_dict, strict=False)
#
# normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# self.style_preprocess = transforms.Compose([
# transforms.Resize(size=224, interpolation=Func.InterpolationMode.BICUBIC),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# normalize,
# ])
self.clip_image_processor = CLIPImageProcessor()
# image proj model
self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content',
model_resampler=self.content_model_resampler)
self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',
model_resampler=self.style_model_resampler)
self.load_ip_adapter()
def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
if content_or_style_ == 'content' and self.content_image_encoder_path is not None:
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.content_image_encoder.config.projection_dim,
clip_extra_context_tokens=num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
# layername_id += 1
selected = False
for block_name in self.style_target_blocks:
if block_name in name:
selected = True
attn_procs[name] = IP_CS_AttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
style_scale=1.0,
style=True,
num_content_tokens=self.num_content_tokens,
num_style_tokens=self.num_style_tokens,
)
for block_name in self.content_target_blocks:
if block_name in name:
if selected is False:
attn_procs[name] = IP_CS_AttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
content_scale=1.0,
content=True,
num_content_tokens=self.num_content_tokens,
num_style_tokens=self.num_style_tokens,
)
# mark as selected so this content-only block is not
# overwritten by the skip processor (mirrors the ControlNet branch)
selected = True
else:
attn_procs[name].set_content_ipa(content_scale=1.0)
if selected is False:
attn_procs[name] = IP_CS_AttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
num_content_tokens=self.num_content_tokens,
num_style_tokens=self.num_style_tokens,
skip=True,
)
attn_procs[name].to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if self.controlnet_adapter is False:
if isinstance(self.pipe.controlnet, MultiControlNetModel):
for controlnet in self.pipe.controlnet.nets:
controlnet.set_attn_processor(CNAttnProcessor(
num_tokens=self.num_content_tokens + self.num_style_tokens))
else:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(
num_tokens=self.num_content_tokens + self.num_style_tokens))
else:
controlnet_attn_procs = {}
controlnet_style_target_blocks = self.controlnet_target_style_blocks
controlnet_content_target_blocks = self.controlnet_target_content_blocks
for name in self.pipe.controlnet.attn_processors.keys():
# print(name)
cross_attention_dim = None if name.endswith(
"attn1.processor") else self.pipe.controlnet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self.pipe.controlnet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.pipe.controlnet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.pipe.controlnet.config.block_out_channels[block_id]
if cross_attention_dim is None:
# layername_id += 1
controlnet_attn_procs[name] = AttnProcessor()
else:
# layername_id += 1
selected = False
for block_name in controlnet_style_target_blocks:
if block_name in name:
selected = True
# print(name)
controlnet_attn_procs[name] = IP_CS_AttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
style_scale=1.0,
style=True,
num_content_tokens=self.num_content_tokens,
num_style_tokens=self.num_style_tokens,
)
for block_name in controlnet_content_target_blocks:
if block_name in name:
if selected is False:
controlnet_attn_procs[name] = IP_CS_AttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
content_scale=1.0,
content=True,
num_content_tokens=self.num_content_tokens,
num_style_tokens=self.num_style_tokens,
)
selected = True
elif selected is True:
controlnet_attn_procs[name].set_content_ipa(content_scale=1.0)
# if args.content_image_encoder_type !='dinov2':
# weights = {
# "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"],
# "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"],
# }
# attn_procs[name].load_state_dict(weights)
if selected is False:
controlnet_attn_procs[name] = IP_CS_AttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
num_content_tokens=self.num_content_tokens,
num_style_tokens=self.num_style_tokens,
skip=True,
)
controlnet_attn_procs[name].to(self.device, dtype=torch.float16)
# layer_name = name.split(".processor")[0]
# # print(state_dict["ip_adapter"].keys())
# weights = {
# "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"],
# "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"],
# }
# attn_procs[name].load_state_dict(weights)
self.pipe.controlnet.set_attn_processor(controlnet_attn_procs)
def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("content_image_proj."):
state_dict["content_image_proj"][key.replace("content_image_proj.", "")] = f.get_tensor(key)
elif key.startswith("style_image_proj."):
state_dict["style_image_proj"][key.replace("style_image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.content_image_proj_model.load_state_dict(state_dict["content_image_proj"])
self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"])
if 'conv_in_unet_sd' in state_dict.keys():
self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True)
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
if self.controlnet_adapter is True:
print('loading controlnet_adapter')
self.pipe.controlnet.load_state_dict(state_dict["controlnet_adapter_modules"], strict=False)
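The safetensors branch of `load_ip_adapter` routes checkpoint tensors into three per-module state dicts by key prefix. A minimal sketch of that prefix routing with plain dicts; the keys and values below are illustrative, not taken from a real CSGO checkpoint:

```python
# Sketch of the prefix routing used when loading a .safetensors checkpoint:
# keys like "style_image_proj.<param>" are stripped of their prefix and
# grouped into per-module state dicts.
flat_ckpt = {
    "content_image_proj.proj.weight": "w0",
    "style_image_proj.proj.weight": "w1",
    "ip_adapter.1.to_k_ip.weight": "w2",
}

state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
for key, tensor in flat_ckpt.items():
    for prefix in state_dict:
        if key.startswith(prefix + "."):
            # drop "<prefix>." and file the tensor under its module
            state_dict[prefix][key[len(prefix) + 1:]] = tensor

print(state_dict["ip_adapter"])  # {'1.to_k_ip.weight': 'w2'}
```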
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None,
content_or_style_=''):
# if pil_image is not None:
# if isinstance(pil_image, Image.Image):
# pil_image = [pil_image]
# clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
# clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
# else:
# clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
# if content_prompt_embeds is not None:
# clip_image_embeds = clip_image_embeds - content_prompt_embeds
if content_or_style_ == 'content':
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
if self.content_image_proj_model is not None:
clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.content_image_encoder(
clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
if content_or_style_ == 'style':
if pil_image is not None:
if self.use_CSD is not None:
clip_image = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32)
clip_image_embeds = self.style_image_encoder(clip_image)
else:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, content_scale, style_scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, IP_CS_AttnProcessor):
if attn_processor.content is True:
attn_processor.content_scale = content_scale
if attn_processor.style is True:
attn_processor.style_scale = style_scale
# print('style_scale:',style_scale)
if self.controlnet_adapter is True:
for attn_processor in self.pipe.controlnet.attn_processors.values():
if isinstance(attn_processor, IP_CS_AttnProcessor):
if attn_processor.content is True:
attn_processor.content_scale = content_scale
# print(content_scale)
if attn_processor.style is True:
attn_processor.style_scale = style_scale
def generate(
self,
pil_content_image=None,
pil_style_image=None,
clip_content_image_embeds=None,
clip_style_image_embeds=None,
prompt=None,
negative_prompt=None,
content_scale=1.0,
style_scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
neg_content_emb=None,
**kwargs,
):
self.set_scale(content_scale, style_scale)
if pil_content_image is not None:
num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
else:
num_prompts = clip_content_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds
)
style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds
)
bs_embed, seq_len, _ = content_image_prompt_embeds.shape
content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,
-1)
bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape
style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,
-1)
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_,
uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
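The `generate` implementations above duplicate each image-prompt embedding `num_samples` times via the `repeat(1, num_samples, 1).view(bs * num_samples, seq_len, -1)` pattern. A numpy sketch of the same tiling (`np.tile` mirrors `Tensor.repeat` here); the shapes are arbitrary illustration values:

```python
import numpy as np

# Tiling along the sequence axis and then reshaping yields num_samples
# consecutive copies of each batch item, matching the layout expected by
# a pipeline running num_samples images per prompt.
bs, seq_len, dim, num_samples = 2, 3, 4, 2
embeds = np.arange(bs * seq_len * dim).reshape(bs, seq_len, dim)

tiled = np.tile(embeds, (1, num_samples, 1))        # (bs, num_samples*seq_len, dim)
duplicated = tiled.reshape(bs * num_samples, seq_len, dim)

# Batch item 0 occupies rows 0-1, batch item 1 occupies rows 2-3.
```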
class IPAdapterXL_CS(IPAdapter_CS):
"""SDXL"""
def generate(
self,
pil_content_image,
pil_style_image,
prompt=None,
negative_prompt=None,
content_scale=1.0,
style_scale=1.0,
num_samples=4,
seed=None,
content_image_embeds=None,
style_image_embeds=None,
num_inference_steps=30,
neg_content_emb=None,
neg_content_prompt=None,
neg_content_scale=1.0,
**kwargs,
):
self.set_scale(content_scale, style_scale)
num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(pil_content_image,
content_image_embeds,
content_or_style_='content')
style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(pil_style_image,
style_image_embeds,
content_or_style_='style')
bs_embed, seq_len, _ = content_image_prompt_embeds.shape
content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,
-1)
bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape
style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,
-1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds,
uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
dim=1)
self.generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=self.generator,
**kwargs,
).images
return images
class CSGO(IPAdapterXL_CS):
"""SDXL"""
def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
if content_or_style_ == 'content':
if model_resampler:
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=num_tokens,
embedding_dim=self.content_image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
else:
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=num_tokens,
).to(self.device, dtype=torch.float16)
if content_or_style_ == 'style':
if model_resampler:
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=num_tokens,
embedding_dim=self.content_image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
else:
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
if content_or_style_ == 'style':
if self.style_model_resampler:
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16),
output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
else:
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
else:
if self.content_image_encoder_path is not None:
clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
outputs = self.content_image_encoder(clip_image.to(self.device, dtype=torch.float16),
output_hidden_states=True)
clip_image_embeds = outputs.last_hidden_state
image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
# uncond_clip_image_embeds = self.image_encoder(
# torch.zeros_like(clip_image), output_hidden_states=True
# ).last_hidden_state
uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
else:
if self.content_model_resampler:
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
# clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
# uncond_clip_image_embeds = self.image_encoder(
# torch.zeros_like(clip_image), output_hidden_states=True
# ).hidden_states[-2]
uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
else:
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
# # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
# clip_image = clip_image.to(self.device, dtype=torch.float16)
# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
# image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
# uncond_clip_image_embeds = self.image_encoder(
# torch.zeros_like(clip_image), output_hidden_states=True
# ).hidden_states[-2]
# uncond_image_prompt_embeds = self.content_image_proj_model(uncond_clip_image_embeds)
# return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterXL(IPAdapter):
"""SDXL"""
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
neg_content_emb=None,
neg_content_prompt=None,
neg_content_scale=1.0,
**kwargs,
):
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
if neg_content_emb is None:
if neg_content_prompt is not None:
with torch.inference_mode():
(
prompt_embeds_, # torch.Size([1, 77, 2048])
negative_prompt_embeds_,
pooled_prompt_embeds_, # torch.Size([1, 1280])
negative_pooled_prompt_embeds_,
) = self.pipe.encode_prompt(
neg_content_prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
pooled_prompt_embeds_ *= neg_content_scale
else:
pooled_prompt_embeds_ = neg_content_emb
else:
pooled_prompt_embeds_ = None
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image,
content_prompt_embeds=pooled_prompt_embeds_)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
self.generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=self.generator,
**kwargs,
).images
return images
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def init_proj(self):
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter with full features"""
def init_proj(self):
image_proj_model = MLPProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.hidden_size,
).to(self.device, dtype=torch.float16)
return image_proj_model
class IPAdapterPlusXL(IPAdapter):
"""SDXL"""
def init_proj(self):
image_proj_model = Resampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
================================================
FILE: ip_adapter/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# reshape keeps (bs, n_heads, length, dim_per_head) but makes the layout explicit after the transpose
x = x.reshape(bs, heads, length, -1)
return x
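A numpy sketch of the reshape performed by `reshape_tensor`: the channel dimension is split into heads and the heads axis is moved ahead of the sequence axis. The shapes below are arbitrary illustration values:

```python
import numpy as np

# (bs, length, heads*dim_head) -> (bs, heads, length, dim_head)
bs, length, heads, dim_head = 2, 5, 4, 8
x = np.random.randn(bs, length, heads * dim_head)

out = x.reshape(bs, length, heads, dim_head).transpose(0, 2, 1, 3)

# out[b, h, l, :] is the h-th dim_head-sized slice of x[b, l, :]
```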
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
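`PerceiverAttention` scales `q` and `k` each by `1/sqrt(sqrt(dim_head))` before the matmul rather than dividing the logits by `sqrt(dim_head)` afterwards; the two are algebraically identical, but pre-scaling keeps intermediate values smaller, which is friendlier to float16. A numpy check of the equivalence:

```python
import numpy as np

dim_head = 64
rng = np.random.default_rng(0)
q = rng.standard_normal((3, dim_head))
k = rng.standard_normal((5, dim_head))

scale = 1 / np.sqrt(np.sqrt(dim_head))
pre_scaled = (q * scale) @ (k * scale).T    # scale applied to both operands
post_scaled = (q @ k.T) / np.sqrt(dim_head) # standard attention scaling

assert np.allclose(pre_scaled, post_scaled)
```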
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
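`masked_mean` zeroes masked-out positions before summing and divides by the count of valid positions, clamped to avoid division by zero when a row is fully masked. A numpy sketch of the same computation on a toy tensor:

```python
import numpy as np

t = np.array([[[1.0], [3.0], [100.0]]])   # (b=1, n=3, d=1)
mask = np.array([[True, True, False]])    # last position is ignored

denom = mask.sum(axis=1, keepdims=True)            # number of valid positions
masked_t = np.where(mask[..., None], t, 0.0)       # zero out masked entries
result = masked_t.sum(axis=1) / np.maximum(denom, 1e-5)

print(result)  # [[2.]] -- mean of 1.0 and 3.0; the 100.0 never contributes
```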
================================================
FILE: ip_adapter/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
BLOCKS = {
'content': ['down_blocks'],
'style': ["up_blocks"],
}
controlnet_BLOCKS = {
'content': [],
'style': ["down_blocks"],
}
def resize_width_height(width, height, min_short_side=512, max_long_side=1024):
if width < height:
if width < min_short_side:
scale_factor = min_short_side / width
new_width = min_short_side
new_height = int(height * scale_factor)
else:
new_width, new_height = width, height
else:
if height < min_short_side:
scale_factor = min_short_side / height
new_width = int(width * scale_factor)
new_height = min_short_side
else:
new_width, new_height = width, height
if max(new_width, new_height) > max_long_side:
scale_factor = max_long_side / max(new_width, new_height)
new_width = int(new_width * scale_factor)
new_height = int(new_height * scale_factor)
return new_width, new_height
def resize_content(content_image):
max_long_side = 1024
min_short_side = 1024
new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1],
min_short_side=min_short_side, max_long_side=max_long_side)
height = new_height // 16 * 16
width = new_width // 16 * 16
content_image = content_image.resize((width, height))
return width, height, content_image
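`resize_width_height` first upscales so the short side reaches `min_short_side`, then downscales so the long side fits within `max_long_side`; `resize_content` additionally floors both sides to multiples of 16. A compact, self-contained re-derivation of the same rule (`fit_size` is an illustrative name, not part of the repo):

```python
def fit_size(width, height, min_short_side=512, max_long_side=1024):
    # Raise the short side to min_short_side if needed...
    if min(width, height) < min_short_side:
        scale = min_short_side / min(width, height)
        if width < height:
            width, height = min_short_side, int(height * scale)
        else:
            width, height = int(width * scale), min_short_side
    # ...then cap the long side at max_long_side.
    if max(width, height) > max_long_side:
        scale = max_long_side / max(width, height)
        width, height = int(width * scale), int(height * scale)
    return width, height

print(fit_size(400, 800))    # (512, 1024): short side raised to 512
print(fit_size(2048, 1024))  # (1024, 512): long side capped at 1024

# resize_content then snaps both sides down to multiples of 16:
w, h = fit_size(400, 800)
print(w // 16 * 16, h // 16 * 16)  # 512 1024
```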
attn_maps = {}
def hook_fn(name):
def forward_hook(module, input, output):
if hasattr(module.processor, "attn_map"):
attn_maps[name] = module.processor.attn_map
del module.processor.attn_map
return forward_hook
def register_cross_attention_hook(unet):
for name, module in unet.named_modules():
if name.split('.')[-1].startswith('attn2'):
module.register_forward_hook(hook_fn(name))
return unet
def upscale(attn_map, target_size):
attn_map = torch.mean(attn_map, dim=0)
attn_map = attn_map.permute(1,0)
temp_size = None
for i in range(0,5):
scale = 2 ** i
if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
break
assert temp_size is not None, "temp_size cannot be None"
attn_map = attn_map.view(attn_map.shape[0], *temp_size)
attn_map = F.interpolate(
attn_map.unsqueeze(0).to(dtype=torch.float32),
size=target_size,
mode='bilinear',
align_corners=False
)[0]
attn_map = torch.softmax(attn_map, dim=0)
return attn_map
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
idx = 0 if instance_or_negative else 1
net_attn_maps = []
for name, attn_map in attn_maps.items():
attn_map = attn_map.cpu() if detach else attn_map
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
attn_map = upscale(attn_map, image_size)
net_attn_maps.append(attn_map)
net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
return net_attn_maps
def attnmaps2images(net_attn_maps):
#total_attn_scores = 0
images = []
for attn_map in net_attn_maps:
attn_map = attn_map.cpu().numpy()
#total_attn_scores += attn_map.mean().item()
normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
normalized_attn_map = normalized_attn_map.astype(np.uint8)
#print("norm: ", normalized_attn_map.shape)
image = Image.fromarray(normalized_attn_map)
#image = fix_save_attn_map(attn_map)
images.append(image)
#print(total_attn_scores)
return images
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
def get_generator(seed, device):
if seed is not None:
if isinstance(seed, list):
generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
else:
generator = torch.Generator(device).manual_seed(seed)
else:
generator = None
return generator
================================================
FILE: requirements.txt
================================================
diffusers==0.25.1
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
transformers==4.40.2
accelerate
safetensors
einops
spaces==0.19.4
omegaconf
peft
huggingface-hub==0.24.5
opencv-python
insightface
gradio
controlnet_aux
gdown