Repository: junjiehe96/UniPortrait
Branch: main
Commit: a4deff2b48e3
Files: 16
Total size: 113.0 KB
Directory structure:
gitextract_35y_e2yv/
├── .gitignore
├── LICENSE.txt
├── README.md
├── gradio_app.py
├── requirements.txt
└── uniportrait/
├── __init__.py
├── curricular_face/
│ ├── __init__.py
│ ├── backbone/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── model_irse.py
│ │ └── model_resnet.py
│ └── inference.py
├── inversion.py
├── resampler.py
├── uniportrait_attention_processor.py
└── uniportrait_pipeline.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea/
.DS_Store
*.dat
*.mat
training/
lightning_logs/
image_log/
*.png
*.jpg
*.jpeg
*.webp
*.pth
*.pt
*.ckpt
*.safetensors
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE.txt
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
<h1>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</h1>
<a href='https://aigcdesigngroup.github.io/UniPortrait-Page/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/abs/2408.05939'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/spaces/Junjie96/UniPortrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
</div>
<img src='assets/highlight.png'>
UniPortrait is an innovative human image personalization framework. It customizes single- and multi-ID images in a
unified manner, offering high-fidelity identity preservation, extensive facial editability, support for free-form text
descriptions, and no requirement for a predetermined layout.
---
## Release
- [2025/05/01] 🔥 We release the code and demo for the `FLUX.1-dev` version of [AnyStory](https://github.com/junjiehe96/AnyStory), a unified approach to general subject personalization.
- [2024/10/18] 🔥 We release the inference code and demo, which integrate
  [ControlNet](https://github.com/lllyasviel/ControlNet),
  [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter),
  and [StyleAligned](https://github.com/google/style-aligned). The weights for this version are consistent with those
  used in the Hugging Face Space and the experiments in the paper. We are now working on generalizing our method to more
  advanced diffusion models and more general custom concepts. Please stay tuned!
- [2024/08/12] 🔥 We release the [technical report](https://arxiv.org/abs/2408.05939)
, [project page](https://aigcdesigngroup.github.io/UniPortrait-Page/),
and [HuggingFace demo](https://huggingface.co/spaces/Junjie96/UniPortrait) 🤗!
## Quickstart
```shell
# Clone repository
git clone https://github.com/junjiehe96/UniPortrait.git
# install requirements
cd UniPortrait
pip install -r requirements.txt
# download the models
git lfs install
git clone https://huggingface.co/Junjie96/UniPortrait models
# download ip-adapter models
# Note: we recommend downloading manually, since not all IP-Adapter models are required.
git clone https://huggingface.co/h94/IP-Adapter models/IP-Adapter
# then you can use the gradio app
python gradio_app.py
```
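Once the app is running, each identity is handed to the pipeline as a small dict of aligned face crops (see `prepare_single_faceid_cond_kwargs` in `gradio_app.py`). The sketch below is a hypothetical, standalone analogue of that structure, with PIL images replaced by placeholder strings for illustration:

```python
def build_faceid_cond_kwargs(refs, mix_refs=None, mix_scales=None):
    # Mirrors prepare_single_faceid_cond_kwargs in gradio_app.py:
    # returns None when no reference crop is given, otherwise a dict
    # holding the aligned reference crops plus optional identity-mixing
    # references and their scales.
    if not refs:
        return None
    cond = {"refs": list(refs)}
    if mix_refs:
        cond["mix_refs"] = list(mix_refs)
        cond["mix_scales"] = list(mix_scales or [])
    return cond


# One dict per identity; the app collects the non-None entries into
# cond_faceids and passes that list to UniPortraitPipeline.generate.
cond_faceids = [kw for kw in (
    build_faceid_cond_kwargs(["id1_crop"]),
    build_faceid_cond_kwargs(["id2_crop"], mix_refs=["mix_crop"], mix_scales=[0.5]),
) if kw is not None]
print(len(cond_faceids))  # → 2
```

In the real app the `"refs"` entries are 224x224 face crops produced by insightface alignment, and `"mix_refs"`/`"mix_scales"` blend additional identities into one face.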
## Applications
<img src='assets/application.png'>
## **Acknowledgements**
This code is built on some excellent repos, including [diffusers](https://github.com/huggingface/diffusers), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [StyleAligned](https://github.com/google/style-aligned). Highly appreciate their great work!
## Cite
If you find UniPortrait useful for your research and applications, please cite us using this BibTeX:
```bibtex
@article{he2024uniportrait,
title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},
author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
journal={arXiv preprint arXiv:2408.05939},
year={2024}
}
```
For any questions, please feel free to open an issue or contact us at hejunjie1103@gmail.com.
================================================
FILE: gradio_app.py
================================================
import os
from io import BytesIO
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from uniportrait import inversion
from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline
port = 7860
device = "cuda"
torch_dtype = torch.float16
# base
base_model_path = "SG161222/Realistic_Vision_V5.1_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
controlnet_pose_ckpt = "lllyasviel/control_v11p_sd15_openpose"
# specific
image_encoder_path = "models/IP-Adapter/models/image_encoder"
ip_ckpt = "models/IP-Adapter/models/ip-adapter_sd15.bin"
face_backbone_ckpt = "models/glint360k_curricular_face_r101_backbone.bin"
uniportrait_faceid_ckpt = "models/uniportrait-faceid_sd15.bin"
uniportrait_router_ckpt = "models/uniportrait-router_sd15.bin"
# load controlnet
pose_controlnet = ControlNetModel.from_pretrained(controlnet_pose_ckpt, torch_dtype=torch_dtype)
# load SD pipeline
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
base_model_path,
controlnet=[pose_controlnet],
torch_dtype=torch_dtype,
scheduler=noise_scheduler,
vae=vae,
# feature_extractor=None,
# safety_checker=None,
)
# load uniportrait pipeline
uniportrait_pipeline = UniPortraitPipeline(pipe, image_encoder_path, ip_ckpt=ip_ckpt,
face_backbone_ckpt=face_backbone_ckpt,
uniportrait_faceid_ckpt=uniportrait_faceid_ckpt,
uniportrait_router_ckpt=uniportrait_router_ckpt,
device=device, torch_dtype=torch_dtype)
# load face detection assets
face_app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=["detection"])
face_app.prepare(ctx_id=0, det_size=(640, 640))
def pad_np_bgr_image(np_image, scale=1.25):
assert scale >= 1.0, "scale should be >= 1.0"
pad_scale = scale - 1.0
h, w = np_image.shape[:2]
top = bottom = int(h * pad_scale)
left = right = int(w * pad_scale)
ret = cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))
return ret, (left, top)
def process_faceid_image(pil_faceid_image):
np_faceid_image = np.array(pil_faceid_image.convert("RGB"))
img = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
faces = face_app.get(img) # bgr
if len(faces) == 0:
# padding, try again
_h, _w = img.shape[:2]
_img, left_top_coord = pad_np_bgr_image(img)
faces = face_app.get(_img)
if len(faces) == 0:
raise gr.Error("No face detected in the image. Please upload an image containing a clear face.")
min_coord = np.array([0, 0])
max_coord = np.array([_w, _h])
sub_coord = np.array([left_top_coord[0], left_top_coord[1]])
for face in faces:
face.bbox = np.minimum(np.maximum(face.bbox.reshape(-1, 2) - sub_coord, min_coord), max_coord).reshape(4)
face.kps = face.kps - sub_coord
faces = sorted(faces, key=lambda x: abs((x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])), reverse=True)
faceid_face = faces[0]
norm_face = face_align.norm_crop(img, landmark=faceid_face.kps, image_size=224)
pil_faceid_align_image = Image.fromarray(cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB))
return pil_faceid_align_image
def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_supp_images=None,
pil_faceid_mix_images=None, mix_scales=None):
pil_faceid_align_images = []
if pil_faceid_image:
pil_faceid_align_images.append(process_faceid_image(pil_faceid_image))
if pil_faceid_supp_images and len(pil_faceid_supp_images) > 0:
for pil_faceid_supp_image in pil_faceid_supp_images:
if isinstance(pil_faceid_supp_image, Image.Image):
pil_faceid_align_images.append(process_faceid_image(pil_faceid_supp_image))
else:
pil_faceid_align_images.append(
process_faceid_image(Image.open(BytesIO(pil_faceid_supp_image)))
)
mix_refs = []
mix_ref_scales = []
if pil_faceid_mix_images:
for pil_faceid_mix_image, mix_scale in zip(pil_faceid_mix_images, mix_scales):
if pil_faceid_mix_image:
mix_refs.append(process_faceid_image(pil_faceid_mix_image))
mix_ref_scales.append(mix_scale)
single_faceid_cond_kwargs = None
if len(pil_faceid_align_images) > 0:
single_faceid_cond_kwargs = {
"refs": pil_faceid_align_images
}
if len(mix_refs) > 0:
single_faceid_cond_kwargs["mix_refs"] = mix_refs
single_faceid_cond_kwargs["mix_scales"] = mix_ref_scales
return single_faceid_cond_kwargs
def text_to_single_id_generation_process(
pil_faceid_image=None, pil_faceid_supp_images=None,
pil_faceid_mix_image_1=None, mix_scale_1=0.0,
pil_faceid_mix_image_2=None, mix_scale_2=0.0,
faceid_scale=0.0, face_structure_scale=0.0,
prompt="", negative_prompt="",
num_samples=1, seed=-1,
image_resolution="512x512",
inference_steps=25,
):
if seed == -1:
seed = None
single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,
pil_faceid_supp_images,
[pil_faceid_mix_image_1, pil_faceid_mix_image_2],
[mix_scale_1, mix_scale_2])
cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []
# reset attn args
attn_args.reset()
# set faceid condition
attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora
attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora
attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
attn_args.num_faceids = len(cond_faceids)
print(attn_args)
h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
prompt = [prompt] * num_samples
negative_prompt = [negative_prompt] * num_samples
images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
seed=seed, guidance_scale=7.5,
num_inference_steps=inference_steps,
image=[torch.zeros([1, 3, h, w])],
controlnet_conditioning_scale=[0.0])
final_out = []
for pil_image in images:
final_out.append(pil_image)
for single_faceid_cond_kwargs in cond_faceids:
final_out.extend(single_faceid_cond_kwargs["refs"])
if "mix_refs" in single_faceid_cond_kwargs:
final_out.extend(single_faceid_cond_kwargs["mix_refs"])
return final_out
def text_to_multi_id_generation_process(
pil_faceid_image_1=None, pil_faceid_supp_images_1=None,
pil_faceid_mix_image_1_1=None, mix_scale_1_1=0.0,
pil_faceid_mix_image_1_2=None, mix_scale_1_2=0.0,
pil_faceid_image_2=None, pil_faceid_supp_images_2=None,
pil_faceid_mix_image_2_1=None, mix_scale_2_1=0.0,
pil_faceid_mix_image_2_2=None, mix_scale_2_2=0.0,
faceid_scale=0.0, face_structure_scale=0.0,
prompt="", negative_prompt="",
num_samples=1, seed=-1,
image_resolution="512x512",
inference_steps=25,
):
if seed == -1:
seed = None
faceid_cond_kwargs_1 = prepare_single_faceid_cond_kwargs(pil_faceid_image_1,
pil_faceid_supp_images_1,
[pil_faceid_mix_image_1_1,
pil_faceid_mix_image_1_2],
[mix_scale_1_1, mix_scale_1_2])
faceid_cond_kwargs_2 = prepare_single_faceid_cond_kwargs(pil_faceid_image_2,
pil_faceid_supp_images_2,
[pil_faceid_mix_image_2_1,
pil_faceid_mix_image_2_2],
[mix_scale_2_1, mix_scale_2_2])
cond_faceids = []
if faceid_cond_kwargs_1 is not None:
cond_faceids.append(faceid_cond_kwargs_1)
if faceid_cond_kwargs_2 is not None:
cond_faceids.append(faceid_cond_kwargs_2)
# reset attn args
attn_args.reset()
# set faceid condition
attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora
attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora
attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
attn_args.num_faceids = len(cond_faceids)
print(attn_args)
h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
prompt = [prompt] * num_samples
negative_prompt = [negative_prompt] * num_samples
images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
seed=seed, guidance_scale=7.5,
num_inference_steps=inference_steps,
image=[torch.zeros([1, 3, h, w])],
controlnet_conditioning_scale=[0.0])
final_out = []
for pil_image in images:
final_out.append(pil_image)
for single_faceid_cond_kwargs in cond_faceids:
final_out.extend(single_faceid_cond_kwargs["refs"])
if "mix_refs" in single_faceid_cond_kwargs:
final_out.extend(single_faceid_cond_kwargs["mix_refs"])
return final_out
def image_to_single_id_generation_process(
pil_faceid_image=None, pil_faceid_supp_images=None,
pil_faceid_mix_image_1=None, mix_scale_1=0.0,
pil_faceid_mix_image_2=None, mix_scale_2=0.0,
faceid_scale=0.0, face_structure_scale=0.0,
pil_ip_image=None, ip_scale=1.0,
num_samples=1, seed=-1, image_resolution="768x512",
inference_steps=25,
):
if seed == -1:
seed = None
single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,
pil_faceid_supp_images,
[pil_faceid_mix_image_1, pil_faceid_mix_image_2],
[mix_scale_1, mix_scale_2])
cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []
h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
# Image Prompt and Style Aligned
if pil_ip_image is None:
raise gr.Error("Please upload a reference image")
attn_args.reset()
pil_ip_image = pil_ip_image.convert("RGB").resize((w, h))
zts = inversion.ddim_inversion(uniportrait_pipeline.pipe, np.array(pil_ip_image), "", inference_steps, 2)
zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0)
# reset attn args
attn_args.reset()
# set ip condition
attn_args.ip_scale = ip_scale if pil_ip_image else 0.0
# set faceid condition
attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # lora for single faceid
attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # lora for >1 faceids
attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
attn_args.num_faceids = len(cond_faceids)
# set shared self-attn
attn_args.enable_share_attn = True
attn_args.shared_score_shift = -0.5
print(attn_args)
prompt = [""] * (1 + num_samples)
negative_prompt = [""] * (1 + num_samples)
images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
pil_ip_image=pil_ip_image,
cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
seed=seed, guidance_scale=7.5,
num_inference_steps=inference_steps,
image=[torch.zeros([1, 3, h, w])],
controlnet_conditioning_scale=[0.0],
zT=zT, callback_on_step_end=inversion_callback)
images = images[1:]
final_out = []
for pil_image in images:
final_out.append(pil_image)
for single_faceid_cond_kwargs in cond_faceids:
final_out.extend(single_faceid_cond_kwargs["refs"])
if "mix_refs" in single_faceid_cond_kwargs:
final_out.extend(single_faceid_cond_kwargs["mix_refs"])
return final_out
def text_to_single_id_generation_block():
gr.Markdown("## Text-to-Single-ID Generation")
gr.HTML(text_to_single_id_description)
gr.HTML(text_to_single_id_tips)
with gr.Row():
with gr.Column(scale=1, min_width=100):
prompt = gr.Textbox(value="", label='Prompt', lines=2)
negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt')
run_button = gr.Button(value="Run")
with gr.Accordion("Options", open=True):
image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
label="Image Resolution (HxW)")
seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
value=2147483647)
num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)
faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0,
step=0.01, value=0.1)
with gr.Column(scale=2, min_width=100):
with gr.Row(equal_height=False):
pil_faceid_image = gr.Image(type="pil", label="ID Image")
with gr.Accordion("ID Supplements", open=True):
with gr.Row():
pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"],
type="binary", label="Additional ID Images")
with gr.Row():
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1")
mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2")
mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
with gr.Row():
example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True,
format="png")
with gr.Row():
examples = [
[
"A young man with short black hair, wearing a black hoodie with a hood, was paired with a blue denim jacket with yellow details.",
"assets/examples/1-newton.jpg",
"assets/examples/1-output-1.png",
],
]
gr.Examples(
label="Examples",
examples=examples,
fn=lambda x, y, z: (x, y),
inputs=[prompt, pil_faceid_image, example_output],
outputs=[prompt, pil_faceid_image]
)
ips = [
pil_faceid_image, pil_faceid_supp_images,
pil_faceid_mix_image_1, mix_scale_1,
pil_faceid_mix_image_2, mix_scale_2,
faceid_scale, face_structure_scale,
prompt, negative_prompt,
num_samples, seed,
image_resolution,
inference_steps,
]
run_button.click(fn=text_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])
def text_to_multi_id_generation_block():
gr.Markdown("## Text-to-Multi-ID Generation")
gr.HTML(text_to_multi_id_description)
gr.HTML(text_to_multi_id_tips)
with gr.Row():
with gr.Column(scale=1, min_width=100):
prompt = gr.Textbox(value="", label='Prompt', lines=2)
negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt')
run_button = gr.Button(value="Run")
with gr.Accordion("Options", open=True):
image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
label="Image Resolution (HxW)")
seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
value=2147483647)
num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)
faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0,
step=0.01, value=0.3)
with gr.Column(scale=2, min_width=100):
with gr.Row(equal_height=False):
with gr.Column(scale=1, min_width=100):
pil_faceid_image_1 = gr.Image(type="pil", label="First ID")
with gr.Accordion("First ID Supplements", open=False):
with gr.Row():
pil_faceid_supp_images_1 = gr.File(file_count="multiple", file_types=["image"],
type="binary", label="Additional ID Images")
with gr.Row():
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_1_1 = gr.Image(type="pil", label="Mix ID 1")
mix_scale_1_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01,
value=0.0)
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_1_2 = gr.Image(type="pil", label="Mix ID 2")
mix_scale_1_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01,
value=0.0)
with gr.Column(scale=1, min_width=100):
pil_faceid_image_2 = gr.Image(type="pil", label="Second ID")
with gr.Accordion("Second ID Supplements", open=False):
with gr.Row():
pil_faceid_supp_images_2 = gr.File(file_count="multiple", file_types=["image"],
type="binary", label="Additional ID Images")
with gr.Row():
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_2_1 = gr.Image(type="pil", label="Mix ID 1")
mix_scale_2_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01,
value=0.0)
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_2_2 = gr.Image(type="pil", label="Mix ID 2")
mix_scale_2_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01,
value=0.0)
with gr.Row():
example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True,
format="png")
with gr.Row():
examples = [
[
"The two female models, fair-skinned, wore a white V-neck short-sleeved top with a light smile on the corners of their mouths. The background was off-white.",
"assets/examples/2-stylegan2-ffhq-0100.png",
"assets/examples/2-stylegan2-ffhq-0293.png",
"assets/examples/2-output-1.png",
],
]
gr.Examples(
label="Examples",
examples=examples,
inputs=[prompt, pil_faceid_image_1, pil_faceid_image_2, example_output],
)
ips = [
pil_faceid_image_1, pil_faceid_supp_images_1,
pil_faceid_mix_image_1_1, mix_scale_1_1,
pil_faceid_mix_image_1_2, mix_scale_1_2,
pil_faceid_image_2, pil_faceid_supp_images_2,
pil_faceid_mix_image_2_1, mix_scale_2_1,
pil_faceid_mix_image_2_2, mix_scale_2_2,
faceid_scale, face_structure_scale,
prompt, negative_prompt,
num_samples, seed,
image_resolution,
inference_steps,
]
run_button.click(fn=text_to_multi_id_generation_process, inputs=ips, outputs=[result_gallery])
def image_to_single_id_generation_block():
gr.Markdown("## Image-to-Single-ID Generation")
gr.HTML(image_to_single_id_description)
gr.HTML(image_to_single_id_tips)
with gr.Row():
with gr.Column(scale=1, min_width=100):
run_button = gr.Button(value="Run")
seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
value=2147483647)
num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
label="Image Resolution (HxW)")
inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)
ip_scale = gr.Slider(label="Reference Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, step=0.01,
value=0.3)
with gr.Column(scale=3, min_width=100):
with gr.Row(equal_height=False):
pil_ip_image = gr.Image(type="pil", label="Portrait Reference")
pil_faceid_image = gr.Image(type="pil", label="ID Image")
with gr.Accordion("ID Supplements", open=True):
with gr.Row():
pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"],
type="binary", label="Additional ID Images")
with gr.Row():
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1")
mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
with gr.Column(scale=1, min_width=100):
pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2")
mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
with gr.Row():
with gr.Column(scale=3, min_width=100):
example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4,
preview=True, format="png")
with gr.Row():
examples = [
[
"assets/examples/3-style-1.png",
"assets/examples/3-stylegan2-ffhq-0293.png",
0.7,
0.3,
"assets/examples/3-output-1.png",
],
[
"assets/examples/3-style-1.png",
"assets/examples/3-stylegan2-ffhq-0293.png",
0.6,
0.0,
"assets/examples/3-output-2.png",
],
[
"assets/examples/3-style-2.jpg",
"assets/examples/3-stylegan2-ffhq-0381.png",
0.7,
0.3,
"assets/examples/3-output-3.png",
],
[
"assets/examples/3-style-3.jpg",
"assets/examples/3-stylegan2-ffhq-0381.png",
0.6,
0.0,
"assets/examples/3-output-4.png",
],
]
gr.Examples(
label="Examples",
examples=examples,
fn=lambda x, y, z, w, v: (x, y, z, w),
inputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale, example_output],
outputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale]
)
ips = [
pil_faceid_image, pil_faceid_supp_images,
pil_faceid_mix_image_1, mix_scale_1,
pil_faceid_mix_image_2, mix_scale_2,
faceid_scale, face_structure_scale,
pil_ip_image, ip_scale,
num_samples, seed, image_resolution,
inference_steps,
]
run_button.click(fn=image_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])
if __name__ == "__main__":
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
title = r"""
<div style="text-align: center;">
<h1> UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization </h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/pdf/2408.05939"><img src="https://img.shields.io/badge/arXiv-2408.05939-red"></a>
<a href='https://aigcdesigngroup.github.io/UniPortrait-Page/'><img src='https://img.shields.io/badge/Project_Page-UniPortrait-green' alt='Project Page'></a>
<a href="https://github.com/junjiehe96/UniPortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
</div>
</br>
</div>
"""
title_description = r"""
This is the <b>official 🤗 Gradio demo</b> for <a href='https://arxiv.org/pdf/2408.05939' target='_blank'><b>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</b></a>.<br>
The demo provides three capabilities: text-to-single-ID personalization, text-to-multi-ID personalization, and image-to-single-ID personalization. All of these are based on the <b>Stable Diffusion v1-5</b> model. Feel free to give them a try! 😊
"""
text_to_single_id_description = r"""🚀🚀🚀Quick start:<br>
1. Enter a text prompt (Chinese or English), upload an image with a face, and click the <b>Run</b> button. 🤗<br>
"""
text_to_single_id_tips = r"""💡💡💡Tips:<br>
1. Avoid generating faces that are too small, as this may introduce artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)<br>
2. Uploading multiple reference photos of the same face improves prompt and ID consistency. Additional references can be uploaded under "ID Supplements".<br>
3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.<br>
"""
text_to_multi_id_description = r"""🚀🚀🚀Quick start:<br>
1. Enter a text prompt (Chinese or English), upload an image with a face in each of the "First ID" and "Second ID" blocks, and click the <b>Run</b> button. 🤗<br>
"""
text_to_multi_id_tips = r"""💡💡💡Tips:<br>
1. Avoid generating faces that are too small, as this may introduce artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)<br>
2. Uploading multiple reference photos of the same face improves prompt and ID consistency. Additional references can be uploaded under the "ID Supplements" sections.<br>
3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.3~0.7 and a "Face Structure Scale" of 0.0~0.4.<br>
"""
image_to_single_id_description = r"""🚀🚀🚀Quick start: Upload an image as the portrait reference (it can be in any style), upload a face image, and click the <b>Run</b> button. 🤗<br>"""
image_to_single_id_tips = r"""💡💡💡Tips:<br>
1. Avoid generating faces that are too small, as this may introduce artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)<br>
2. Uploading multiple reference photos of the same face improves ID consistency. Additional references can be uploaded under "ID Supplements".<br>
3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing the portrait reference and ID alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.<br>
"""
citation = r"""
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{he2024uniportrait,
title={UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization},
author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
journal={arXiv preprint arXiv:2408.05939},
year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or reach out to us directly at <b>hejunjie1103@gmail.com</b>.
"""
block = gr.Blocks(title="UniPortrait").queue()
with block:
gr.HTML(title)
gr.HTML(title_description)
with gr.TabItem("Text-to-Single-ID"):
text_to_single_id_generation_block()
with gr.TabItem("Text-to-Multi-ID"):
text_to_multi_id_generation_block()
with gr.TabItem("Image-to-Single-ID (Stylization)"):
image_to_single_id_generation_block()
gr.Markdown(citation)
block.launch(server_name='0.0.0.0', share=False, server_port=port, allowed_paths=["/"])
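For readers wiring up a similar UI: the Seed sliders above advertise a "-1 indicates random" convention. The actual handling lives in the process functions (not shown in this chunk); the sketch below, with a hypothetical `resolve_seed` helper, illustrates the assumed behavior.

```python
import random

# Hypothetical helper (not from the repo): map the slider's -1 sentinel to a
# fresh random seed in the slider's [0, 2147483647] range, pass through otherwise.
def resolve_seed(seed: int) -> int:
    return random.randint(0, 2147483647) if seed == -1 else seed

print(resolve_seed(42))  # → 42
```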
================================================
FILE: requirements.txt
================================================
diffusers
gradio
onnxruntime-gpu
insightface
torch
tqdm
transformers
================================================
FILE: uniportrait/__init__.py
================================================
================================================
FILE: uniportrait/curricular_face/__init__.py
================================================
================================================
FILE: uniportrait/curricular_face/backbone/__init__.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone
from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50,
IR_SE_101, IR_SE_152, IR_SE_200)
from .model_resnet import ResNet_50, ResNet_101, ResNet_152
_model_dict = {
'ResNet_50': ResNet_50,
'ResNet_101': ResNet_101,
'ResNet_152': ResNet_152,
'IR_18': IR_18,
'IR_34': IR_34,
'IR_50': IR_50,
'IR_101': IR_101,
'IR_152': IR_152,
'IR_200': IR_200,
'IR_SE_50': IR_SE_50,
'IR_SE_101': IR_SE_101,
'IR_SE_152': IR_SE_152,
'IR_SE_200': IR_SE_200
}
def get_model(key):
""" Get different backbone network by key,
support ResNet50, ResNet_101, ResNet_152
IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
"""
if key in _model_dict:
return _model_dict[key]
else:
raise KeyError('unsupported model {}'.format(key))
================================================
FILE: uniportrait/curricular_face/backbone/common.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py
import torch.nn as nn
from torch.nn import (Conv2d, Module, ReLU,
Sigmoid)
def initialize_weights(modules):
""" Weight initilize, conv2d and linear is initialized with kaiming_normal
"""
for m in modules:
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
class Flatten(Module):
""" Flat tensor
"""
def forward(self, input):
return input.view(input.size(0), -1)
class SEModule(Module):
""" SE block
"""
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels,
channels // reduction,
kernel_size=1,
padding=0,
bias=False)
nn.init.xavier_uniform_(self.fc1.weight.data)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
================================================
FILE: uniportrait/curricular_face/backbone/model_irse.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
from collections import namedtuple
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
MaxPool2d, Module, PReLU, Sequential)
from .common import Flatten, SEModule, initialize_weights
class BasicBlockIR(Module):
""" BasicBlock for IRNet
"""
def __init__(self, in_channel, depth, stride):
super(BasicBlockIR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth), PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class BottleneckIR(Module):
""" BasicBlock with bottleneck for IRNet
"""
def __init__(self, in_channel, depth, stride):
super(BottleneckIR, self).__init__()
reduction_channel = depth // 4
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(
in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
BatchNorm2d(reduction_channel), PReLU(reduction_channel),
Conv2d(
reduction_channel,
reduction_channel, (3, 3), (1, 1),
1,
bias=False), BatchNorm2d(reduction_channel),
PReLU(reduction_channel),
Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
BatchNorm2d(depth))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class BasicBlockIRSE(BasicBlockIR):
def __init__(self, in_channel, depth, stride):
super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
self.res_layer.add_module('se_block', SEModule(depth, 16))
class BottleneckIRSE(BottleneckIR):
def __init__(self, in_channel, depth, stride):
super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
self.res_layer.add_module('se_block', SEModule(depth, 16))
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
'''A named tuple describing a ResNet block.'''
def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + \
[Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
def get_blocks(num_layers):
if num_layers == 18:
blocks = [
get_block(in_channel=64, depth=64, num_units=2),
get_block(in_channel=64, depth=128, num_units=2),
get_block(in_channel=128, depth=256, num_units=2),
get_block(in_channel=256, depth=512, num_units=2)
]
elif num_layers == 34:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=6),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=256, num_units=3),
get_block(in_channel=256, depth=512, num_units=8),
get_block(in_channel=512, depth=1024, num_units=36),
get_block(in_channel=1024, depth=2048, num_units=3)
]
elif num_layers == 200:
blocks = [
get_block(in_channel=64, depth=256, num_units=3),
get_block(in_channel=256, depth=512, num_units=24),
get_block(in_channel=512, depth=1024, num_units=36),
get_block(in_channel=1024, depth=2048, num_units=3)
]
else:
raise ValueError('num_layers should be one of 18, 34, 50, 100, 152 or 200, got {}'.format(num_layers))
return blocks
class Backbone(Module):
def __init__(self, input_size, num_layers, mode='ir'):
""" Args:
input_size: input_size of backbone
num_layers: num_layers of backbone
mode: support ir or irse
"""
super(Backbone, self).__init__()
assert input_size[0] in [112, 224], \
'input_size should be [112, 112] or [224, 224]'
assert num_layers in [18, 34, 50, 100, 152, 200], \
'num_layers should be 18, 34, 50, 100, 152 or 200'
assert mode in ['ir', 'ir_se'], \
'mode should be ir or ir_se'
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
PReLU(64))
blocks = get_blocks(num_layers)
if num_layers <= 100:
if mode == 'ir':
unit_module = BasicBlockIR
elif mode == 'ir_se':
unit_module = BasicBlockIRSE
output_channel = 512
else:
if mode == 'ir':
unit_module = BottleneckIR
elif mode == 'ir_se':
unit_module = BottleneckIRSE
output_channel = 2048
if input_size[0] == 112:
self.output_layer = Sequential(
BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
Linear(output_channel * 7 * 7, 512),
BatchNorm1d(512, affine=False))
else:
self.output_layer = Sequential(
BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
Linear(output_channel * 14 * 14, 512),
BatchNorm1d(512, affine=False))
modules = []
mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101
for block in blocks:
if len(mid_layer_indices) == 0:
mid_layer_indices.append(len(block) - 1)
else:
mid_layer_indices.append(len(block) + mid_layer_indices[-1])
for bottleneck in block:
modules.append(
unit_module(bottleneck.in_channel, bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)
self.mid_layer_indices = mid_layer_indices[-4:]
initialize_weights(self.modules())
def forward(self, x, return_mid_feats=False):
x = self.input_layer(x)
if not return_mid_feats:
x = self.body(x)
x = self.output_layer(x)
return x
else:
out_feats = []
for idx, module in enumerate(self.body):
x = module(x)
if idx in self.mid_layer_indices:
out_feats.append(x)
x = self.output_layer(x)
return x, out_feats
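The `[2, 15, 45, 48]` comment on `mid_layer_indices` can be reproduced with a short pure-Python sketch; the per-stage unit counts `[3, 13, 30, 3]` are those of `get_blocks(100)`, i.e. the IR-101 layout.

```python
# Mirror the accumulation in Backbone.__init__: the index of the last block of
# each stage, counted over the flattened body.
stage_units = [3, 13, 30, 3]  # units per stage for get_blocks(100) / IR-101
mid_layer_indices = []
for n in stage_units:
    prev = mid_layer_indices[-1] if mid_layer_indices else -1
    mid_layer_indices.append(prev + n)
print(mid_layer_indices)  # → [2, 15, 45, 48]
```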
def IR_18(input_size):
""" Constructs a ir-18 model.
"""
model = Backbone(input_size, 18, 'ir')
return model
def IR_34(input_size):
""" Constructs a ir-34 model.
"""
model = Backbone(input_size, 34, 'ir')
return model
def IR_50(input_size):
""" Constructs a ir-50 model.
"""
model = Backbone(input_size, 50, 'ir')
return model
def IR_101(input_size):
""" Constructs a ir-101 model.
"""
model = Backbone(input_size, 100, 'ir')
return model
def IR_152(input_size):
""" Constructs a ir-152 model.
"""
model = Backbone(input_size, 152, 'ir')
return model
def IR_200(input_size):
""" Constructs a ir-200 model.
"""
model = Backbone(input_size, 200, 'ir')
return model
def IR_SE_50(input_size):
""" Constructs a ir_se-50 model.
"""
model = Backbone(input_size, 50, 'ir_se')
return model
def IR_SE_101(input_size):
""" Constructs a ir_se-101 model.
"""
model = Backbone(input_size, 100, 'ir_se')
return model
def IR_SE_152(input_size):
""" Constructs a ir_se-152 model.
"""
model = Backbone(input_size, 152, 'ir_se')
return model
def IR_SE_200(input_size):
""" Constructs a ir_se-200 model.
"""
model = Backbone(input_size, 200, 'ir_se')
return model
================================================
FILE: uniportrait/curricular_face/backbone/model_resnet.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py
import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
MaxPool2d, Module, ReLU, Sequential)
from .common import initialize_weights
def conv3x3(in_planes, out_planes, stride=1):
""" 3x3 convolution with padding
"""
return Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
def conv1x1(in_planes, out_planes, stride=1):
""" 1x1 convolution
"""
return Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class Bottleneck(Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn2 = BatchNorm2d(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = BatchNorm2d(planes * self.expansion)
self.relu = ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(Module):
""" ResNet backbone
"""
def __init__(self, input_size, block, layers, zero_init_residual=True):
""" Args:
input_size: input_size of backbone
block: block function
layers: layers in each block
"""
super(ResNet, self).__init__()
assert input_size[0] in [112, 224], \
'input_size should be [112, 112] or [224, 224]'
self.inplanes = 64
self.conv1 = Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = BatchNorm2d(64)
self.relu = ReLU(inplace=True)
self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn_o1 = BatchNorm2d(2048)
self.dropout = Dropout()
if input_size[0] == 112:
self.fc = Linear(2048 * 4 * 4, 512)
else:
self.fc = Linear(2048 * 7 * 7, 512)
self.bn_o2 = BatchNorm1d(512)
initialize_weights(self.modules())
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return Sequential(*layers)
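A quick sketch of the channel bookkeeping in `_make_layer`: with `Bottleneck.expansion = 4`, `self.inplanes` grows to `planes * 4` after each stage, giving the familiar ResNet-50 widths.

```python
# Per-stage output channels for the layer1..layer4 plane counts used above.
expansion = 4  # Bottleneck.expansion
stage_planes = [64, 128, 256, 512]
stage_out_channels = [p * expansion for p in stage_planes]
print(stage_out_channels)  # → [256, 512, 1024, 2048]
```

This is why `bn_o1` above is a `BatchNorm2d(2048)`.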
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn_o1(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
x = self.bn_o2(x)
return x
def ResNet_50(input_size, **kwargs):
""" Constructs a ResNet-50 model.
"""
model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)
return model
def ResNet_101(input_size, **kwargs):
""" Constructs a ResNet-101 model.
"""
model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)
return model
def ResNet_152(input_size, **kwargs):
""" Constructs a ResNet-152 model.
"""
model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)
return model
================================================
FILE: uniportrait/curricular_face/inference.py
================================================
import glob
import os
import cv2
import numpy as np
import torch
from tqdm.auto import tqdm
from .backbone import get_model
@torch.no_grad()
def inference(name, weight, src_norm_dir):
face_model = get_model(name)([112, 112])
face_model.load_state_dict(torch.load(weight, map_location="cpu"))
face_model = face_model.to("cpu")
face_model.eval()
id2src_norm = {}
for src_id in sorted(list(os.listdir(src_norm_dir))):
id2src_norm[src_id] = sorted(list(glob.glob(f"{os.path.join(src_norm_dir, src_id)}/*")))
total_sims = []
for id_name in tqdm(id2src_norm):
src_face_embeddings = []
for src_img_path in id2src_norm[id_name]:
src_img = cv2.imread(src_img_path)
src_img = cv2.resize(src_img, (112, 112))
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
src_img = np.transpose(src_img, (2, 0, 1))
src_img = torch.from_numpy(src_img).unsqueeze(0).float()
src_img.div_(255).sub_(0.5).div_(0.5)
embedding = face_model(src_img).detach().cpu().numpy()[0]
embedding = embedding / np.linalg.norm(embedding)
src_face_embeddings.append(embedding) # 512
num = len(src_face_embeddings)
src_face_embeddings = np.stack(src_face_embeddings) # n, 512
sim = src_face_embeddings @ src_face_embeddings.T # n, n
mean_sim = (np.sum(sim) - num * 1.0) / ((num - 1) * num)
print(f"{id_name}: {mean_sim}")
total_sims.append(mean_sim)
return np.mean(total_sims)
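The averaging step in `inference` excludes the diagonal: it subtracts the `num` self-similarities (each 1.0) and divides by the number of ordered off-diagonal pairs. A pure-Python sketch (no torch/numpy; `mean_pairwise` is an illustrative stand-in):

```python
# Off-diagonal mean of a square cosine-similarity matrix, as list-of-lists.
def mean_pairwise(sim):
    num = len(sim)
    total = sum(sum(row) for row in sim)
    # subtract num self-similarities, divide by num * (num - 1) ordered pairs
    return (total - num * 1.0) / ((num - 1) * num)

sim = [[1.0, 0.5], [0.5, 1.0]]
print(mean_pairwise(sim))  # → 0.5
```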
if __name__ == "__main__":
name = 'IR_101'
weight = "models/glint360k_curricular_face_r101_backbone.bin"
src_norm_dir = "/disk1/hejunjie.hjj/data/normface-AFD-id-20"
mean_sim = inference(name, weight, src_norm_dir)
print(f"total: {mean_sim:.4f}") # total: 0.6299
================================================
FILE: uniportrait/inversion.py
================================================
# modified from https://github.com/google/style-aligned/blob/main/inversion.py
from __future__ import annotations
from typing import Callable
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from tqdm import tqdm
T = torch.Tensor
InversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]]
def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> tuple[dict[str, T], T]:
device = model._execution_device
prompt_embeds = model._encode_prompt(
prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True,
negative_prompt="")
return prompt_embeds
def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:
model.vae.to(dtype=torch.float32)
image = torch.from_numpy(image).float() / 255.
image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
model.vae.to(dtype=torch.float16)
return latent
def _next_step(model: StableDiffusionPipeline, model_output: T, timestep: int, sample: T) -> T:
timestep, next_timestep = min(
timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
alpha_prod_t = model.scheduler.alphas_cumprod[
int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
beta_prod_t = 1 - alpha_prod_t
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
return next_sample
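The update in `_next_step` follows the standard DDIM form: predict x0 from the current sample and noise estimate, then re-noise it at the next alpha level. A scalar sketch with plain floats (hypothetical helper, same algebra as above):

```python
import math

def ddim_next_step(sample, noise_pred, alpha_t, alpha_next):
    # x0 prediction from the current sample, then deterministic re-noising
    x0_pred = (sample - math.sqrt(1.0 - alpha_t) * noise_pred) / math.sqrt(alpha_t)
    return math.sqrt(alpha_next) * x0_pred + math.sqrt(1.0 - alpha_next) * noise_pred

# with identical alpha levels the update is the identity (up to float error)
print(ddim_next_step(1.0, 0.5, 0.9, 0.9))
```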
def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, context: T, guidance_scale: float):
latents_input = torch.cat([latent] * 2)
noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# latents = next_step(model, noise_pred, t, latent)
return noise_pred
def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scale) -> T:
all_latent = [z0]
text_embedding = _encode_text_with_negative(model, prompt)
image_embedding = torch.zeros_like(text_embedding[:, :1]).repeat(1, 4, 1) # for ip embedding
text_embedding = torch.cat([text_embedding, image_embedding], dim=1)
latent = z0.clone().detach().half()
for i in tqdm(range(model.scheduler.num_inference_steps)):
t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale)
latent = _next_step(model, noise_pred, t, latent)
all_latent.append(latent)
return torch.cat(all_latent).flip(0)
def make_inversion_callback(zts, offset: int = 0) -> tuple[T, InversionCallback]:
def callback_on_step_end(pipeline: StableDiffusionPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[
str, T]:
latents = callback_kwargs['latents']
latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
return {'latents': latents}
return zts[offset], callback_on_step_end
@torch.no_grad()
def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int,
guidance_scale, ) -> T:
z0 = _encode_image(model, x0)
model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
zs = _ddim_loop(model, z0, prompt, guidance_scale)
return zs
================================================
FILE: uniportrait/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# reshape to a contiguous (bs, n_heads, length, dim_per_head) tensor
x = x.reshape(bs, heads, length, -1)
return x
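`reshape_tensor` only rearranges shapes; a pure-Python sketch of the bookkeeping (hypothetical helper, no torch):

```python
def reshaped_shape(bs, length, width, heads):
    # mirrors reshape_tensor: split width into heads, then move heads before length
    dim_per_head = width // heads
    return (bs, heads, length, dim_per_head)

print(reshaped_shape(2, 16, 768, 12))  # → (2, 12, 16, 64)
```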
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head ** -0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents, attention_mask=None):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
attention_mask (torch.Tensor): attention mask
shape (b, n1, 1)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
if attention_mask is not None:
attention_mask = attention_mask.transpose(1, 2) # (b, 1, n1)
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :, :1]).repeat(1, 1, l)],
dim=2) # b, 1, n1+n2
attention_mask = (attention_mask - 1.) * 100. # 0 means kept and -100 means dropped
attention_mask = attention_mask.unsqueeze(1)
weight = weight + attention_mask # b, h, n2, n1+n2
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
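The mask construction above maps a keep/drop mask in {1, 0} to a pre-softmax additive bias in {0, -100}, so dropped keys receive negligible attention weight. A minimal sketch of just that transform:

```python
# Additive-mask trick: 1 (keep) → 0.0 bias, 0 (drop) → -100.0 bias.
def to_additive_bias(mask_row):
    return [(m - 1.0) * 100.0 for m in mask_row]

print(to_additive_bias([1, 1, 0]))  # → [0.0, 0.0, -100.0]
```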
class UniPortraitFaceIDResampler(torch.nn.Module):
def __init__(
self,
intrinsic_id_embedding_dim=512,
structure_embedding_dim=64 + 128 + 256 + 1280,
num_tokens=16,
depth=6,
dim=768,
dim_head=64,
heads=12,
ff_mult=4,
output_dim=768,
):
super().__init__()
self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5)
self.proj_id = torch.nn.Sequential(
torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim),
)
self.proj_clip = torch.nn.Sequential(
torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(structure_embedding_dim * 2, dim),
)
self.layers = torch.nn.ModuleList([])
for _ in range(depth):
self.layers.append(
torch.nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
self.proj_out = torch.nn.Linear(dim, output_dim)
self.norm_out = torch.nn.LayerNorm(output_dim)
def forward(
self,
intrinsic_id_embeds,
structure_embeds,
structure_scale=1.0,
intrinsic_id_attention_mask=None,
structure_attention_mask=None
):
latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1)
intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds)
structure_embeds = self.proj_clip(structure_embeds)
for attn1, attn2, ff in self.layers:
latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents
latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
================================================
FILE: uniportrait/uniportrait_attention_processor.py
================================================
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.lora import LoRALinearLayer
class AttentionArgs(object):
    def __init__(self) -> None:
        self.reset()
    def reset(self):
        # ip condition
        self.ip_scale = 0.0
        self.ip_mask = None  # ip attention mask
        # faceid condition
        self.lora_scale = 0.0  # lora for single faceid
        self.multi_id_lora_scale = 0.0  # lora for multiple faceids
        self.faceid_scale = 0.0
        self.num_faceids = 0
        self.faceid_mask = None  # faceid attention mask; if not None, it will override the routing map
        # style aligned
        self.enable_share_attn: bool = False
        self.adain_queries_and_keys: bool = False
        self.shared_score_scale: float = 1.0
        self.shared_score_shift: float = 0.0
def __repr__(self):
indent_str = ' '
s = f",\n{indent_str}".join(f"{attr}={value}" for attr, value in vars(self).items())
return self.__class__.__name__ + '(' + f'\n{indent_str}' + s + ')'
attn_args = AttentionArgs()
def expand_first(feat, scale=1., ):
b = feat.shape[0]
feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
if scale == 1:
feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
else:
feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
return feat_style.reshape(*feat.shape)
def concat_first(feat, dim=2, scale=1.):
feat_style = expand_first(feat, scale=scale)
return torch.cat((feat, feat_style), dim=dim)
def calc_mean_std(feat, eps: float = 1e-5):
    feat_std = (feat.var(dim=-2, keepdim=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdim=True)
return feat_mean, feat_std
def adain(feat):
feat_mean, feat_std = calc_mean_std(feat)
feat_style_mean = expand_first(feat_mean)
feat_style_std = expand_first(feat_std)
feat = (feat - feat_mean) / feat_std
feat = feat * feat_style_std + feat_style_mean
return feat
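`adain()` above re-normalizes each feature to the style statistics broadcast by `expand_first`. A minimal 1-D pure-Python sketch of the statistic transfer, with the batch/style bookkeeping collapsed (`adain_1d` is an illustrative helper, not part of this repo):

```python
def adain_1d(content, style, eps=1e-5):
    # Replace the mean/std of `content` with those of `style`,
    # the same per-feature statistic transfer adain() applies.
    def stats(xs):
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
        return mean, (var + eps) ** 0.5
    c_mean, c_std = stats(content)
    s_mean, s_std = stats(style)
    return [(x - c_mean) / c_std * s_std + s_mean for x in content]

out = adain_1d([0.0, 2.0, 4.0], [10.0, 10.5, 11.0])
```

After the transfer, `out` carries the style's mean and (up to `eps`) its standard deviation while preserving the content's normalized shape.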
class UniPortraitLoRAAttnProcessor2_0(nn.Module):
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
rank=128,
network_alpha=None,
):
super().__init__()
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn_args.lora_scale > 0.0:
query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
elif attn_args.multi_id_lora_scale > 0.0:
query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        if attn_args.enable_share_attn:
            if attn_args.adain_queries_and_keys:
                query = adain(query)
                key = adain(key)
            key = concat_first(key, -2, scale=attn_args.shared_score_scale)
            value = concat_first(value, -2)
            if attn_args.shared_score_shift != 0:
                attention_mask = torch.zeros_like(key[:, :, :, :1]).transpose(-1, -2)  # b, h, 1, k
                attention_mask[:, :, :, query.shape[2]:] += attn_args.shared_score_shift
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
        )
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
output_hidden_states = attn.to_out[0](hidden_states)
if attn_args.lora_scale > 0.0:
output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
elif attn_args.multi_id_lora_scale > 0.0:
output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
hidden_states)
hidden_states = output_hidden_states
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
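The LoRA branches above add a scaled low-rank update on top of the frozen base projections, e.g. `query = attn.to_q(h) + attn_args.lora_scale * self.to_q_lora(h)`. A tiny pure-Python sketch of that computation with toy 2x2 weights (`matvec` and `lora_forward` are illustrative helpers, not part of this repo):

```python
def matvec(M, x):
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, scale):
    # Base projection plus a scaled low-rank update: W @ x + scale * B @ (A @ x).
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))  # rank = number of rows of A
    return [b + scale * d for b, d in zip(base, delta)]

W = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 base weight
A = [[1.0, 1.0]]              # rank-1 down-projection (1x2)
B = [[0.5], [0.5]]            # rank-1 up-projection (2x1)
y_off = lora_forward(W, A, B, [2.0, 4.0], scale=0.0)  # scale 0 -> base output only
y_on = lora_forward(W, A, B, [2.0, 4.0], scale=1.0)
```

With `scale=0.0` the adapter is inert, which is exactly how `attn_args.lora_scale = 0.0` disables the single-ID LoRA at runtime.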
class UniPortraitLoRAIPAttnProcessor2_0(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None,
num_ip_tokens=4, num_faceid_tokens=16):
super().__init__()
self.num_ip_tokens = num_ip_tokens
self.num_faceid_tokens = num_faceid_tokens
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_k_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_q_router = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 2),
nn.GELU(),
nn.Linear(hidden_size * 2, hidden_size, bias=False),
)
self.to_k_router = nn.Sequential(
nn.Linear(cross_attention_dim or hidden_size, (cross_attention_dim or hidden_size) * 2),
nn.GELU(),
nn.Linear((cross_attention_dim or hidden_size) * 2, hidden_size, bias=False),
)
self.aggr_router = nn.Linear(num_faceid_tokens, 1)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# split hidden states
faceid_end = encoder_hidden_states.shape[1]
ip_end = faceid_end - self.num_faceid_tokens * attn_args.num_faceids
text_end = ip_end - self.num_ip_tokens
prompt_hidden_states = encoder_hidden_states[:, :text_end]
ip_hidden_states = encoder_hidden_states[:, text_end: ip_end]
faceid_hidden_states = encoder_hidden_states[:, ip_end: faceid_end]
encoder_hidden_states = prompt_hidden_states
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# for router
if attn_args.num_faceids > 1:
router_query = self.to_q_router(hidden_states) # bs, s*s, dim
router_hidden_states = faceid_hidden_states.reshape(batch_size, attn_args.num_faceids,
self.num_faceid_tokens, -1) # bs, num, id_tokens, d
router_hidden_states = self.aggr_router(router_hidden_states.transpose(-1, -2)).squeeze(-1) # bs, num, d
router_key = self.to_k_router(router_hidden_states) # bs, num, dim
router_logits = torch.bmm(router_query, router_key.transpose(-1, -2)) # bs, s*s, num
index = router_logits.max(dim=-1, keepdim=True)[1]
routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0)
routing_map = routing_map.transpose(1, 2).unsqueeze(-1) # bs, num, s*s, 1
else:
routing_map = hidden_states.new_ones(size=(1, 1, hidden_states.shape[1], 1))
# for text
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn_args.lora_scale > 0.0:
query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
elif attn_args.multi_id_lora_scale > 0.0:
query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
if attn_args.ip_scale > 0.0:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
if attn_args.ip_mask is not None:
ip_mask = attn_args.ip_mask
h, w = ip_mask.shape[-2:]
ratio = (h * w / query.shape[2]) ** 0.5
ip_mask = torch.nn.functional.interpolate(ip_mask, scale_factor=1 / ratio,
mode='nearest').reshape(
[1, -1, 1])
ip_hidden_states = ip_hidden_states * ip_mask
if attn_args.enable_share_attn:
ip_hidden_states[0] = 0.
ip_hidden_states[batch_size // 2] = 0.
else:
ip_hidden_states = torch.zeros_like(hidden_states)
# for faceid-adapter
if attn_args.faceid_scale > 0.0:
faceid_key = self.to_k_faceid(faceid_hidden_states)
faceid_value = self.to_v_faceid(faceid_hidden_states)
faceid_query = query[:, None].expand(-1, attn_args.num_faceids, -1, -1,
-1) # 2*bs, num, heads, s*s, dim/heads
faceid_key = faceid_key.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,
head_dim).transpose(2, 3)
faceid_value = faceid_value.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,
head_dim).transpose(2, 3)
faceid_hidden_states = F.scaled_dot_product_attention(
faceid_query, faceid_key, faceid_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
) # 2*bs, num, heads, s*s, dim/heads
faceid_hidden_states = faceid_hidden_states.transpose(2, 3).reshape(batch_size, attn_args.num_faceids, -1,
attn.heads * head_dim)
faceid_hidden_states = faceid_hidden_states.to(query.dtype) # 2*bs, num, s*s, dim
if attn_args.faceid_mask is not None:
faceid_mask = attn_args.faceid_mask # 1, num, h, w
h, w = faceid_mask.shape[-2:]
ratio = (h * w / query.shape[2]) ** 0.5
faceid_mask = F.interpolate(faceid_mask, scale_factor=1 / ratio,
mode='bilinear').flatten(2).unsqueeze(-1) # 1, num, s*s, 1
faceid_mask = faceid_mask / faceid_mask.sum(1, keepdim=True).clip(min=1e-3) # 1, num, s*s, 1
faceid_hidden_states = (faceid_mask * faceid_hidden_states).sum(1) # 2*bs, s*s, dim
else:
faceid_hidden_states = (routing_map * faceid_hidden_states).sum(1) # 2*bs, s*s, dim
if attn_args.enable_share_attn:
faceid_hidden_states[0] = 0.
faceid_hidden_states[batch_size // 2] = 0.
else:
faceid_hidden_states = torch.zeros_like(hidden_states)
hidden_states = hidden_states + \
attn_args.ip_scale * ip_hidden_states + \
attn_args.faceid_scale * faceid_hidden_states
# linear proj
output_hidden_states = attn.to_out[0](hidden_states)
if attn_args.lora_scale > 0.0:
output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
elif attn_args.multi_id_lora_scale > 0.0:
output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
hidden_states)
hidden_states = output_hidden_states
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
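The processor above slices the conditioning sequence from the end: FaceID tokens last, IP tokens before them, and text tokens first. A pure-Python sketch of that layout arithmetic (`split_condition_tokens` is an illustrative helper; 77 is the usual CLIP text length and is an assumption here):

```python
def split_condition_tokens(seq_len, num_ip_tokens, num_faceid_tokens, num_faceids):
    # Sequence layout: [text | ip | faceid_1 ... faceid_n], sliced from
    # the end as in UniPortraitLoRAIPAttnProcessor2_0.__call__.
    faceid_end = seq_len
    ip_end = faceid_end - num_faceid_tokens * num_faceids
    text_end = ip_end - num_ip_tokens
    return (0, text_end), (text_end, ip_end), (ip_end, faceid_end)

# 77 text tokens, 4 IP tokens, 2 identities with 16 tokens each:
spans = split_condition_tokens(seq_len=77 + 4 + 16 * 2, num_ip_tokens=4,
                               num_faceid_tokens=16, num_faceids=2)
```

Slicing from the end means the text span can have any length, so the same processor works for padded or truncated prompts.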
# for controlnet
class UniPortraitCNAttnProcessor2_0:
def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):
self.num_ip_tokens = num_ip_tokens
self.num_faceid_tokens = num_faceid_tokens
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
text_end = encoder_hidden_states.shape[1] - self.num_faceid_tokens * attn_args.num_faceids \
- self.num_ip_tokens
encoder_hidden_states = encoder_hidden_states[:, :text_end] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
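The ID router in `UniPortraitLoRAIPAttnProcessor2_0` above assigns each spatial query to exactly one identity via an argmax one-hot map (`scatter_` on the top logit). An equivalent pure-Python sketch (`route_queries` is an illustrative helper, not part of this repo):

```python
def route_queries(router_logits):
    # Hard top-1 routing: each spatial query attends to exactly one
    # identity, mirroring the scatter_-based one-hot routing map.
    routing = []
    for row in router_logits:  # row: logits over identities for one query
        best = max(range(len(row)), key=lambda i: row[i])
        routing.append([1.0 if i == best else 0.0 for i in range(len(row))])
    return routing

routing_map = route_queries([[0.2, 0.9], [1.5, -0.3], [0.0, 0.0]])
```

Each row is one-hot, so summing `routing_map * faceid_hidden_states` over identities selects a single identity's features per location (unless an explicit `faceid_mask` overrides the map).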
================================================
FILE: uniportrait/uniportrait_pipeline.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .curricular_face.backbone import get_model
from .resampler import UniPortraitFaceIDResampler
from .uniportrait_attention_processor import UniPortraitCNAttnProcessor2_0 as UniPortraitCNAttnProcessor
from .uniportrait_attention_processor import UniPortraitLoRAAttnProcessor2_0 as UniPortraitLoRAAttnProcessor
from .uniportrait_attention_processor import UniPortraitLoRAIPAttnProcessor2_0 as UniPortraitLoRAIPAttnProcessor
class ImageProjModel(nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds # b, c
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens,
self.cross_attention_dim)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class UniPortraitPipeline:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_backbone_ckpt=None, uniportrait_faceid_ckpt=None,
uniportrait_router_ckpt=None, num_ip_tokens=4, num_faceid_tokens=16,
lora_rank=128, device=torch.device("cuda"), torch_dtype=torch.float16):
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.uniportrait_faceid_ckpt = uniportrait_faceid_ckpt
self.uniportrait_router_ckpt = uniportrait_router_ckpt
self.num_ip_tokens = num_ip_tokens
self.num_faceid_tokens = num_faceid_tokens
self.lora_rank = lora_rank
self.device = device
self.torch_dtype = torch_dtype
self.pipe = sd_pipe.to(self.device)
# load clip image encoder
self.clip_image_processor = CLIPImageProcessor(size={"shortest_edge": 224}, do_center_crop=False,
use_square_size=True)
self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=self.torch_dtype)
# load face backbone
self.facerecog_model = get_model("IR_101")([112, 112])
self.facerecog_model.load_state_dict(torch.load(face_backbone_ckpt, map_location="cpu"))
self.facerecog_model = self.facerecog_model.to(self.device, dtype=torch_dtype)
self.facerecog_model.eval()
# image proj model
self.image_proj_model = self.init_image_proj()
# faceid proj model
self.faceid_proj_model = self.init_faceid_proj()
# set uniportrait and ip adapter
self.set_uniportrait_and_ip_adapter()
# load uniportrait and ip adapter
self.load_uniportrait_and_ip_adapter()
def init_image_proj(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.clip_image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_ip_tokens,
).to(self.device, dtype=self.torch_dtype)
return image_proj_model
def init_faceid_proj(self):
faceid_proj_model = UniPortraitFaceIDResampler(
intrinsic_id_embedding_dim=512,
structure_embedding_dim=64 + 128 + 256 + self.clip_image_encoder.config.hidden_size,
num_tokens=16, depth=6,
dim=self.pipe.unet.config.cross_attention_dim, dim_head=64,
heads=12, ff_mult=4,
output_dim=self.pipe.unet.config.cross_attention_dim
).to(self.device, dtype=self.torch_dtype)
return faceid_proj_model
def set_uniportrait_and_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = UniPortraitLoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=self.lora_rank,
).to(self.device, dtype=self.torch_dtype).eval()
else:
attn_procs[name] = UniPortraitLoRAIPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=self.lora_rank,
num_ip_tokens=self.num_ip_tokens,
num_faceid_tokens=self.num_faceid_tokens,
).to(self.device, dtype=self.torch_dtype).eval()
unet.set_attn_processor(attn_procs)
if hasattr(self.pipe, "controlnet"):
if isinstance(self.pipe.controlnet, ControlNetModel):
self.pipe.controlnet.set_attn_processor(
UniPortraitCNAttnProcessor(
num_ip_tokens=self.num_ip_tokens,
num_faceid_tokens=self.num_faceid_tokens,
)
)
elif isinstance(self.pipe.controlnet, MultiControlNetModel):
for module in self.pipe.controlnet.nets:
module.set_attn_processor(
UniPortraitCNAttnProcessor(
num_ip_tokens=self.num_ip_tokens,
num_faceid_tokens=self.num_faceid_tokens,
)
)
else:
                raise ValueError(f"Unsupported controlnet type: {type(self.pipe.controlnet)}")
def load_uniportrait_and_ip_adapter(self):
if self.ip_ckpt:
print(f"loading from {self.ip_ckpt}...")
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
ip_layers = nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
if self.uniportrait_faceid_ckpt:
print(f"loading from {self.uniportrait_faceid_ckpt}...")
state_dict = torch.load(self.uniportrait_faceid_ckpt, map_location="cpu")
self.faceid_proj_model.load_state_dict(state_dict["faceid_proj"], strict=True)
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["faceid_adapter"], strict=False)
if self.uniportrait_router_ckpt:
print(f"loading from {self.uniportrait_router_ckpt}...")
state_dict = torch.load(self.uniportrait_router_ckpt, map_location="cpu")
router_state_dict = {}
for k, v in state_dict["faceid_adapter"].items():
if "lora." in k:
router_state_dict[k.replace("lora.", "multi_id_lora.")] = v
elif "router." in k:
router_state_dict[k] = v
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(router_state_dict, strict=False)
@torch.inference_mode()
def get_ip_embeds(self, pil_ip_image):
ip_image = self.clip_image_processor(images=pil_ip_image, return_tensors="pt").pixel_values
ip_image = ip_image.to(self.device, dtype=self.torch_dtype) # (b, 3, 224, 224), values being normalized
ip_embeds = self.clip_image_encoder(ip_image).image_embeds
ip_prompt_embeds = self.image_proj_model(ip_embeds)
uncond_ip_prompt_embeds = self.image_proj_model(torch.zeros_like(ip_embeds))
return ip_prompt_embeds, uncond_ip_prompt_embeds
@torch.inference_mode()
def get_single_faceid_embeds(self, pil_face_images, face_structure_scale):
face_clip_image = self.clip_image_processor(images=pil_face_images, return_tensors="pt").pixel_values
face_clip_image = face_clip_image.to(self.device, dtype=self.torch_dtype) # (b, 3, 224, 224)
face_clip_embeds = self.clip_image_encoder(
face_clip_image, output_hidden_states=True).hidden_states[-2][:, 1:] # b, 256, 1280
OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=self.device,
dtype=self.torch_dtype).reshape(-1, 1, 1)
OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=self.device,
dtype=self.torch_dtype).reshape(-1, 1, 1)
facerecog_image = face_clip_image * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN # [0, 1]
facerecog_image = torch.clamp((facerecog_image - 0.5) / 0.5, -1, 1) # [-1, 1]
facerecog_image = F.interpolate(facerecog_image, size=(112, 112), mode="bilinear", align_corners=False)
facerecog_embeds = self.facerecog_model(facerecog_image, return_mid_feats=True)[1]
face_intrinsic_id_embeds = facerecog_embeds[-1] # (b, 512, 7, 7)
face_intrinsic_id_embeds = face_intrinsic_id_embeds.flatten(2).permute(0, 2, 1) # b, 49, 512
facerecog_structure_embeds = facerecog_embeds[:-1] # (b, 64, 56, 56), (b, 128, 28, 28), (b, 256, 14, 14)
facerecog_structure_embeds = torch.cat([
F.interpolate(feat, size=(16, 16), mode="bilinear", align_corners=False)
for feat in facerecog_structure_embeds], dim=1) # b, 448, 16, 16
facerecog_structure_embeds = facerecog_structure_embeds.flatten(2).permute(0, 2, 1) # b, 256, 448
face_structure_embeds = torch.cat([facerecog_structure_embeds, face_clip_embeds], dim=-1) # b, 256, 1728
uncond_face_clip_embeds = self.clip_image_encoder(
torch.zeros_like(face_clip_image[:1]), output_hidden_states=True).hidden_states[-2][:, 1:] # 1, 256, 1280
uncond_face_structure_embeds = torch.cat(
[torch.zeros_like(facerecog_structure_embeds[:1]), uncond_face_clip_embeds], dim=-1) # 1, 256, 1728
faceid_prompt_embeds = self.faceid_proj_model(
face_intrinsic_id_embeds.flatten(0, 1).unsqueeze(0),
face_structure_embeds.flatten(0, 1).unsqueeze(0),
structure_scale=face_structure_scale,
        ) # [1, 16, 768]; all reference images are flattened into a single sequence
uncond_faceid_prompt_embeds = self.faceid_proj_model(
torch.zeros_like(face_intrinsic_id_embeds[:1]),
uncond_face_structure_embeds,
structure_scale=face_structure_scale,
) # [1, 16, 768]
return faceid_prompt_embeds, uncond_faceid_prompt_embeds
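`get_single_faceid_embeds` above feeds the same pixels to both CLIP and the face backbone by undoing the CLIP normalization back to [0, 1] and re-normalizing to the [-1, 1] range the recognition model expects. A per-value pure-Python sketch of that composition (`clip_to_facerecog` is an illustrative helper, not part of this repo):

```python
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

def clip_to_facerecog(pixel, channel):
    # Invert CLIP normalization back to [0, 1], then rescale to the
    # clamped [-1, 1] range used for the face recognition backbone.
    x = pixel * OPENAI_CLIP_STD[channel] + OPENAI_CLIP_MEAN[channel]
    return max(-1.0, min(1.0, (x - 0.5) / 0.5))

# A CLIP-normalized value of 0 corresponds to the channel mean:
v = clip_to_facerecog(0.0, channel=0)
```

Reusing the already-preprocessed CLIP tensor avoids decoding and resizing the reference image twice; only a resize to 112x112 remains for the backbone.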
def generate(
self,
prompt=None,
negative_prompt=None,
pil_ip_image=None,
cond_faceids=None,
face_structure_scale=0.0,
seed=-1,
guidance_scale=7.5,
num_inference_steps=30,
zT=None,
**kwargs,
):
"""
Args:
prompt:
negative_prompt:
pil_ip_image:
cond_faceids: [
{
"refs": [PIL.Image] or PIL.Image,
(Optional) "mix_refs": [PIL.Image],
(Optional) "mix_scales": [float],
},
...
]
face_structure_scale:
seed:
guidance_scale:
num_inference_steps:
zT:
**kwargs:
Returns:
"""
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
with torch.inference_mode():
prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
            prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True,
            negative_prompt=negative_prompt)
        num_prompts = prompt_embeds.shape[0]

        # Append IP-Adapter (global image) tokens to the text context; fall back
        # to zero embeddings as placeholders when no reference image is given.
        if pil_ip_image is not None:
            ip_prompt_embeds, uncond_ip_prompt_embeds = self.get_ip_embeds(pil_ip_image)
            ip_prompt_embeds = ip_prompt_embeds.repeat(num_prompts, 1, 1)
            uncond_ip_prompt_embeds = uncond_ip_prompt_embeds.repeat(num_prompts, 1, 1)
        else:
            ip_prompt_embeds = uncond_ip_prompt_embeds = \
                torch.zeros_like(prompt_embeds[:, :1]).repeat(1, self.num_ip_tokens, 1)
        prompt_embeds = torch.cat([prompt_embeds, ip_prompt_embeds], dim=1)
        negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_ip_prompt_embeds], dim=1)

        # Append face-ID tokens for each conditioned identity. An identity may
        # optionally blend several reference faces ("mix_refs") by convex
        # interpolation of their embeddings with weights "mix_scales".
        if cond_faceids and len(cond_faceids) > 0:
            all_faceid_prompt_embeds = []
            all_uncond_faceid_prompt_embeds = []
            for curr_faceid_info in cond_faceids:
                refs = curr_faceid_info["refs"]
                faceid_prompt_embeds, uncond_faceid_prompt_embeds = \
                    self.get_single_faceid_embeds(refs, face_structure_scale)
                if "mix_refs" in curr_faceid_info:
                    mix_refs = curr_faceid_info["mix_refs"]
                    mix_scales = curr_faceid_info["mix_scales"]
                    # The primary face keeps the remaining weight.
                    master_face_mix_scale = 1.0 - sum(mix_scales)
                    faceid_prompt_embeds = faceid_prompt_embeds * master_face_mix_scale
                    for mix_ref, mix_scale in zip(mix_refs, mix_scales):
                        faceid_mix_prompt_embeds, _ = self.get_single_faceid_embeds(mix_ref, face_structure_scale)
                        faceid_prompt_embeds = faceid_prompt_embeds + faceid_mix_prompt_embeds * mix_scale
                all_faceid_prompt_embeds.append(faceid_prompt_embeds)
                all_uncond_faceid_prompt_embeds.append(uncond_faceid_prompt_embeds)
            faceid_prompt_embeds = torch.cat(all_faceid_prompt_embeds, dim=1)
            uncond_faceid_prompt_embeds = torch.cat(all_uncond_faceid_prompt_embeds, dim=1)
            faceid_prompt_embeds = faceid_prompt_embeds.repeat(num_prompts, 1, 1)
            uncond_faceid_prompt_embeds = uncond_faceid_prompt_embeds.repeat(num_prompts, 1, 1)
            prompt_embeds = torch.cat([prompt_embeds, faceid_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_faceid_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        if zT is not None:
            # Seed the first sample with a provided inverted latent (e.g. from
            # DDIM inversion); the remaining samples use fresh Gaussian noise.
            h_, w_ = kwargs["image"][0].shape[-2:]
            latents = torch.randn(num_prompts, 4, h_ // 8, w_ // 8, device=self.device, generator=generator,
                                  dtype=self.pipe.unet.dtype)
            latents[0] = zT
        else:
            latents = None

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            latents=latents,
            **kwargs,
        ).images

        return images
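The identity-mixing branch of `generate` blends reference faces by convex interpolation: the primary face keeps weight `1.0 - sum(mix_scales)`, and each mix reference contributes its own scale. A minimal standalone sketch of that weighting on plain Python lists (the helper name `mix_face_embeds` is illustrative, standing in for the tensor arithmetic in the loop):

```python
def mix_face_embeds(master_embed, mix_embeds, mix_scales):
    # Convex combination of face embeddings: the master face keeps the
    # remaining weight 1 - sum(mix_scales), as in the generate() loop.
    master_scale = 1.0 - sum(mix_scales)
    mixed = [master_scale * v for v in master_embed]
    for embed, scale in zip(mix_embeds, mix_scales):
        mixed = [m + scale * v for m, v in zip(mixed, embed)]
    return mixed


# Mixing a second face at scale 0.25 keeps 0.75 of the primary identity.
print(mix_face_embeds([1.0, 0.0], [[0.0, 1.0]], [0.25]))  # [0.75, 0.25]
```

Because the weights sum to 1, the mixed embedding stays in the convex hull of the reference embeddings, which keeps its norm in the range the resampler was trained on.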
SYMBOL INDEX (99 symbols across 10 files)
FILE: gradio_app.py
function pad_np_bgr_image (line 69) | def pad_np_bgr_image(np_image, scale=1.25):
function process_faceid_image (line 79) | def process_faceid_image(pil_faceid_image):
function prepare_single_faceid_cond_kwargs (line 106) | def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_...
function text_to_single_id_generation_process (line 140) | def text_to_single_id_generation_process(
function text_to_multi_id_generation_process (line 190) | def text_to_multi_id_generation_process(
function image_to_single_id_generation_process (line 253) | def image_to_single_id_generation_process(
function text_to_single_id_generation_block (line 320) | def text_to_single_id_generation_block():
function text_to_multi_id_generation_block (line 389) | def text_to_multi_id_generation_block():
function image_to_single_id_generation_block (line 477) | def image_to_single_id_generation_block():
FILE: uniportrait/curricular_face/backbone/__init__.py
function get_model (line 24) | def get_model(key):
FILE: uniportrait/curricular_face/backbone/common.py
function initialize_weights (line 8) | def initialize_weights(modules):
class Flatten (line 27) | class Flatten(Module):
method forward (line 31) | def forward(self, input):
class SEModule (line 35) | class SEModule(Module):
method __init__ (line 39) | def __init__(self, channels, reduction):
method forward (line 61) | def forward(self, x):
FILE: uniportrait/curricular_face/backbone/model_irse.py
class BasicBlockIR (line 11) | class BasicBlockIR(Module):
method __init__ (line 15) | def __init__(self, in_channel, depth, stride):
method forward (line 30) | def forward(self, x):
class BottleneckIR (line 37) | class BottleneckIR(Module):
method __init__ (line 41) | def __init__(self, in_channel, depth, stride):
method forward (line 64) | def forward(self, x):
class BasicBlockIRSE (line 71) | class BasicBlockIRSE(BasicBlockIR):
method __init__ (line 73) | def __init__(self, in_channel, depth, stride):
class BottleneckIRSE (line 78) | class BottleneckIRSE(BottleneckIR):
method __init__ (line 80) | def __init__(self, in_channel, depth, stride):
class Bottleneck (line 85) | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
function get_block (line 89) | def get_block(in_channel, depth, num_units, stride=2):
function get_blocks (line 94) | def get_blocks(num_layers):
class Backbone (line 141) | class Backbone(Module):
method __init__ (line 143) | def __init__(self, input_size, num_layers, mode='ir'):
method forward (line 200) | def forward(self, x, return_mid_feats=False):
function IR_18 (line 216) | def IR_18(input_size):
function IR_34 (line 224) | def IR_34(input_size):
function IR_50 (line 232) | def IR_50(input_size):
function IR_101 (line 240) | def IR_101(input_size):
function IR_152 (line 248) | def IR_152(input_size):
function IR_200 (line 256) | def IR_200(input_size):
function IR_SE_50 (line 264) | def IR_SE_50(input_size):
function IR_SE_101 (line 272) | def IR_SE_101(input_size):
function IR_SE_152 (line 280) | def IR_SE_152(input_size):
function IR_SE_200 (line 288) | def IR_SE_200(input_size):
FILE: uniportrait/curricular_face/backbone/model_resnet.py
function conv3x3 (line 10) | def conv3x3(in_planes, out_planes, stride=1):
function conv1x1 (line 22) | def conv1x1(in_planes, out_planes, stride=1):
class Bottleneck (line 29) | class Bottleneck(Module):
method __init__ (line 32) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 44) | def forward(self, x):
class ResNet (line 67) | class ResNet(Module):
method __init__ (line 71) | def __init__(self, input_size, block, layers, zero_init_residual=True):
method _make_layer (line 105) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 121) | def forward(self, x):
function ResNet_50 (line 141) | def ResNet_50(input_size, **kwargs):
function ResNet_101 (line 149) | def ResNet_101(input_size, **kwargs):
function ResNet_152 (line 157) | def ResNet_152(input_size, **kwargs):
FILE: uniportrait/curricular_face/inference.py
function inference (line 13) | def inference(name, weight, src_norm_dir):
FILE: uniportrait/inversion.py
function _encode_text_with_negative (line 16) | def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: s...
function _encode_image (line 24) | def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:
function _next_step (line 33) | def _next_step(model: StableDiffusionPipeline, model_output: T, timestep...
function _get_noise_pred (line 46) | def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, con...
function _ddim_loop (line 55) | def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scal...
function make_inversion_callback (line 69) | def make_inversion_callback(zts, offset: int = 0) -> [T, InversionCallba...
function ddim_inversion (line 80) | def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, promp...
FILE: uniportrait/resampler.py
function FeedForward (line 11) | def FeedForward(dim, mult=4):
function reshape_tensor (line 21) | def reshape_tensor(x, heads):
class PerceiverAttention (line 32) | class PerceiverAttention(nn.Module):
method __init__ (line 33) | def __init__(self, *, dim, dim_head=64, heads=8):
method forward (line 47) | def forward(self, x, latents, attention_mask=None):
class UniPortraitFaceIDResampler (line 89) | class UniPortraitFaceIDResampler(torch.nn.Module):
method __init__ (line 90) | def __init__(
method forward (line 132) | def forward(
FILE: uniportrait/uniportrait_attention_processor.py
class AttentionArgs (line 8) | class AttentionArgs(object):
method __init__ (line 9) | def __init__(self) -> None:
method reset (line 27) | def reset(self):
method __repr__ (line 45) | def __repr__(self):
function expand_first (line 54) | def expand_first(feat, scale=1., ):
function concat_first (line 65) | def concat_first(feat, dim=2, scale=1.):
function calc_mean_std (line 70) | def calc_mean_std(feat, eps: float = 1e-5):
function adain (line 76) | def adain(feat):
class UniPortraitLoRAAttnProcessor2_0 (line 85) | class UniPortraitLoRAAttnProcessor2_0(nn.Module):
method __init__ (line 87) | def __init__(
method __call__ (line 106) | def __call__(
class UniPortraitLoRAIPAttnProcessor2_0 (line 206) | class UniPortraitLoRAIPAttnProcessor2_0(nn.Module):
method __init__ (line 208) | def __init__(self, hidden_size, cross_attention_dim=None, rank=128, ne...
method __call__ (line 243) | def __call__(
class UniPortraitCNAttnProcessor2_0 (line 422) | class UniPortraitCNAttnProcessor2_0:
method __init__ (line 423) | def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):
method __call__ (line 428) | def __call__(
FILE: uniportrait/uniportrait_pipeline.py
class ImageProjModel (line 15) | class ImageProjModel(nn.Module):
method __init__ (line 18) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
method forward (line 26) | def forward(self, image_embeds):
class UniPortraitPipeline (line 34) | class UniPortraitPipeline:
method __init__ (line 36) | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_bac...
method init_image_proj (line 73) | def init_image_proj(self):
method init_faceid_proj (line 81) | def init_faceid_proj(self):
method set_uniportrait_and_ip_adapter (line 92) | def set_uniportrait_and_ip_adapter(self):
method load_uniportrait_and_ip_adapter (line 139) | def load_uniportrait_and_ip_adapter(self):
method get_ip_embeds (line 167) | def get_ip_embeds(self, pil_ip_image):
method get_single_faceid_embeds (line 176) | def get_single_faceid_embeds(self, pil_face_images, face_structure_sca...
method generate (line 220) | def generate(
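The attention processors listed above slice the cross-attention context by token counts: `generate` concatenates the text embeddings with `num_ip_tokens` IP-Adapter tokens plus `num_faceid_tokens` for each conditioned identity. Assuming the standard 77-token CLIP text context and the defaults visible in `UniPortraitCNAttnProcessor2_0.__init__` (4 IP tokens, 16 face-ID tokens), the bookkeeping can be sketched as (helper name is illustrative):

```python
def total_context_length(num_ids, text_tokens=77, num_ip_tokens=4, num_faceid_tokens=16):
    # Length of the concatenated cross-attention context built by generate():
    # text tokens + IP-adapter tokens + face-ID tokens per identity.
    # The 77-token CLIP default is an assumption, not taken from this repo.
    return text_tokens + num_ip_tokens + num_faceid_tokens * num_ids


# Two identities: 77 + 4 + 2 * 16 = 113 context tokens.
print(total_context_length(2))  # 113
```

This is why the attention processors need the token counts at construction time: they must split the incoming `encoder_hidden_states` back into text, IP, and per-identity face segments.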
About this extraction
This page contains the full source code of the junjiehe96/UniPortrait GitHub repository, extracted as plain text for AI agents and large language models: 16 files (113.0 KB, approximately 28.5k tokens) plus a symbol index of 99 functions, classes, methods, constants, and types. Extracted by GitExtract, a GitHub-repo-to-text converter built by Nikandr Surkov.