Repository: junjiehe96/UniPortrait Branch: main Commit: a4deff2b48e3 Files: 16 Total size: 113.0 KB Directory structure: gitextract_35y_e2yv/ ├── .gitignore ├── LICENSE.txt ├── README.md ├── gradio_app.py ├── requirements.txt └── uniportrait/ ├── __init__.py ├── curricular_face/ │ ├── __init__.py │ ├── backbone/ │ │ ├── __init__.py │ │ ├── common.py │ │ ├── model_irse.py │ │ └── model_resnet.py │ └── inference.py ├── inversion.py ├── resampler.py ├── uniportrait_attention_processor.py └── uniportrait_pipeline.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea/ .DS_Store *.dat *.mat training/ lightning_logs/ image_log/ *.png *.jpg *.jpeg *.webp *.pth *.pt *.ckpt *.safetensors # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: LICENSE.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

# UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization

UniPortrait is an innovative human image personalization framework. It customizes single- and multi-ID images in a unified manner, providing high-fidelity identity preservation, extensive facial editability, free-form text description, and no requirement for a predetermined layout.

---

## Release

- [2025/05/01] 🔥 We release the code and demo for the `FLUX.1-dev` version of [AnyStory](https://github.com/junjiehe96/AnyStory), a unified approach to general subject personalization.
- [2024/10/18] 🔥 We release the inference code and demo, which integrate [ControlNet](https://github.com/lllyasviel/ControlNet), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), and [StyleAligned](https://github.com/google/style-aligned). The weights for this version are consistent with the Hugging Face space and the experiments in the paper. We are now working on generalizing our method to more advanced diffusion models and more general custom concepts. Please stay tuned!
- [2024/08/12] 🔥 We release the [technical report](https://arxiv.org/abs/2408.05939), [project page](https://aigcdesigngroup.github.io/UniPortrait-Page/), and [Hugging Face demo](https://huggingface.co/spaces/Junjie96/UniPortrait) 🤗!

## Quickstart

```shell
# Clone the repository
git clone https://github.com/junjiehe96/UniPortrait.git

# Install requirements
cd UniPortrait
pip install -r requirements.txt

# Download the models
git lfs install
git clone https://huggingface.co/Junjie96/UniPortrait models

# Download the IP-Adapter models
# Note: we recommend downloading manually, since not all IP-Adapter models are required.
git clone https://huggingface.co/h94/IP-Adapter models/IP-Adapter

# Then you can launch the gradio app
python gradio_app.py
```

## Applications

## Acknowledgements

This code is built on some excellent repos, including [diffusers](https://github.com/huggingface/diffusers), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), and [StyleAligned](https://github.com/google/style-aligned). We highly appreciate their great work!

## Cite

If you find UniPortrait useful for your research and applications, please cite us using this BibTeX:

```bibtex
@article{he2024uniportrait,
  title={UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization},
  author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
  journal={arXiv preprint arXiv:2408.05939},
  year={2024}
}
```

For any questions, please feel free to open an issue or contact us via hejunjie1103@gmail.com.
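## Programmatic usage

If you prefer to skip the Gradio UI, the sketch below distills the setup in `gradio_app.py` into a minimal text-to-single-ID script. This is a sketch, not an official API reference: it assumes the weights were downloaded to `models/` as in Quickstart, `face.jpg` is a hypothetical input photo, and, unlike the demo (which first detects and aligns the face with insightface), the reference image is passed as-is for brevity.

```python
# Minimal text-to-single-ID sketch distilled from gradio_app.py (not an official API).
import torch
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline

from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

torch_dtype = torch.float16

# Assemble the base SD 1.5 pipeline exactly as the demo does.
noise_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
                                beta_schedule="scaled_linear", clip_sample=False,
                                set_alpha_to_one=False, steps_offset=1)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    controlnet=[ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose",
                                                torch_dtype=torch_dtype)],
    scheduler=noise_scheduler,
    vae=AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch_dtype),
    torch_dtype=torch_dtype,
)
uniportrait = UniPortraitPipeline(
    pipe, "models/IP-Adapter/models/image_encoder",
    ip_ckpt="models/IP-Adapter/models/ip-adapter_sd15.bin",
    face_backbone_ckpt="models/glint360k_curricular_face_r101_backbone.bin",
    uniportrait_faceid_ckpt="models/uniportrait-faceid_sd15.bin",
    uniportrait_router_ckpt="models/uniportrait-router_sd15.bin",
    device="cuda", torch_dtype=torch_dtype,
)

# One identity, one reference. The demo crops and aligns references to 224x224 with
# insightface first (see process_faceid_image in gradio_app.py); a raw photo is used here.
cond_faceids = [{"refs": [Image.open("face.jpg")]}]  # hypothetical input image

attn_args.reset()
attn_args.lora_scale = 1.0    # single-ID LoRA (use multi_id_lora_scale for >1 IDs)
attn_args.faceid_scale = 0.7  # recommended range: 0.5~0.7
attn_args.num_faceids = len(cond_faceids)

h = w = 512
images = uniportrait.generate(
    prompt=["a man wearing a suit, outdoors"], negative_prompt=["nsfw"],
    cond_faceids=cond_faceids, face_structure_scale=0.1,
    seed=42, guidance_scale=7.5, num_inference_steps=25,
    image=[torch.zeros([1, 3, h, w])],    # blank pose map: ControlNet is effectively
    controlnet_conditioning_scale=[0.0],  # disabled with a zero conditioning scale
)
images[0].save("result.png")
```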
================================================ FILE: gradio_app.py ================================================ import os from io import BytesIO import cv2 import gradio as gr import numpy as np import torch from PIL import Image from diffusers import DDIMScheduler, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline from insightface.app import FaceAnalysis from insightface.utils import face_align from uniportrait import inversion from uniportrait.uniportrait_attention_processor import attn_args from uniportrait.uniportrait_pipeline import UniPortraitPipeline port = 7860 device = "cuda" torch_dtype = torch.float16 # base base_model_path = "SG161222/Realistic_Vision_V5.1_noVAE" vae_model_path = "stabilityai/sd-vae-ft-mse" controlnet_pose_ckpt = "lllyasviel/control_v11p_sd15_openpose" # specific image_encoder_path = "models/IP-Adapter/models/image_encoder" ip_ckpt = "models/IP-Adapter/models/ip-adapter_sd15.bin" face_backbone_ckpt = "models/glint360k_curricular_face_r101_backbone.bin" uniportrait_faceid_ckpt = "models/uniportrait-faceid_sd15.bin" uniportrait_router_ckpt = "models/uniportrait-router_sd15.bin" # load controlnet pose_controlnet = ControlNetModel.from_pretrained(controlnet_pose_ckpt, torch_dtype=torch_dtype) # load SD pipeline noise_scheduler = DDIMScheduler( num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1, ) vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch_dtype) pipe = StableDiffusionControlNetPipeline.from_pretrained( base_model_path, controlnet=[pose_controlnet], torch_dtype=torch_dtype, scheduler=noise_scheduler, vae=vae, # feature_extractor=None, # safety_checker=None, ) # load uniportrait pipeline uniportrait_pipeline = UniPortraitPipeline(pipe, image_encoder_path, ip_ckpt=ip_ckpt, face_backbone_ckpt=face_backbone_ckpt, uniportrait_faceid_ckpt=uniportrait_faceid_ckpt, uniportrait_router_ckpt=uniportrait_router_ckpt, device=device, torch_dtype=torch_dtype) # load face detection assets face_app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=["detection"]) face_app.prepare(ctx_id=0, det_size=(640, 640)) def pad_np_bgr_image(np_image, scale=1.25): assert scale >= 1.0, "scale should be >= 1.0" pad_scale = scale - 1.0 h, w = np_image.shape[:2] top = bottom = int(h * pad_scale) left = right = int(w * pad_scale) ret = cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)) return ret, (left, top) def process_faceid_image(pil_faceid_image): np_faceid_image = np.array(pil_faceid_image.convert("RGB")) img = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR) faces = face_app.get(img) # bgr if len(faces) == 0: # padding, try again _h, _w = img.shape[:2] _img, left_top_coord = pad_np_bgr_image(img) faces = face_app.get(_img) if len(faces) == 0: gr.Info("Warning: No face detected in the image. 
Continue processing...") min_coord = np.array([0, 0]) max_coord = np.array([_w, _h]) sub_coord = np.array([left_top_coord[0], left_top_coord[1]]) for face in faces: face.bbox = np.minimum(np.maximum(face.bbox.reshape(-1, 2) - sub_coord, min_coord), max_coord).reshape(4) face.kps = face.kps - sub_coord faces = sorted(faces, key=lambda x: abs((x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])), reverse=True) faceid_face = faces[0] norm_face = face_align.norm_crop(img, landmark=faceid_face.kps, image_size=224) pil_faceid_align_image = Image.fromarray(cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)) return pil_faceid_align_image def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_supp_images=None, pil_faceid_mix_images=None, mix_scales=None): pil_faceid_align_images = [] if pil_faceid_image: pil_faceid_align_images.append(process_faceid_image(pil_faceid_image)) if pil_faceid_supp_images and len(pil_faceid_supp_images) > 0: for pil_faceid_supp_image in pil_faceid_supp_images: if isinstance(pil_faceid_supp_image, Image.Image): pil_faceid_align_images.append(process_faceid_image(pil_faceid_supp_image)) else: pil_faceid_align_images.append( process_faceid_image(Image.open(BytesIO(pil_faceid_supp_image))) ) mix_refs = [] mix_ref_scales = [] if pil_faceid_mix_images: for pil_faceid_mix_image, mix_scale in zip(pil_faceid_mix_images, mix_scales): if pil_faceid_mix_image: mix_refs.append(process_faceid_image(pil_faceid_mix_image)) mix_ref_scales.append(mix_scale) single_faceid_cond_kwargs = None if len(pil_faceid_align_images) > 0: single_faceid_cond_kwargs = { "refs": pil_faceid_align_images } if len(mix_refs) > 0: single_faceid_cond_kwargs["mix_refs"] = mix_refs single_faceid_cond_kwargs["mix_scales"] = mix_ref_scales return single_faceid_cond_kwargs def text_to_single_id_generation_process( pil_faceid_image=None, pil_faceid_supp_images=None, pil_faceid_mix_image_1=None, mix_scale_1=0.0, pil_faceid_mix_image_2=None, mix_scale_2=0.0, faceid_scale=0.0, face_structure_scale=0.0, prompt="", negative_prompt="", num_samples=1, seed=-1, image_resolution="512x512", inference_steps=25, ): if seed == -1: seed = None single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image, pil_faceid_supp_images, [pil_faceid_mix_image_1, pil_faceid_mix_image_2], [mix_scale_1, mix_scale_2]) cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else [] # reset attn args attn_args.reset() # set faceid condition attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0 attn_args.num_faceids = len(cond_faceids) print(attn_args) h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1]) prompt = [prompt] * num_samples negative_prompt = [negative_prompt] * num_samples images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt, cond_faceids=cond_faceids, face_structure_scale=face_structure_scale, seed=seed, guidance_scale=7.5, num_inference_steps=inference_steps, image=[torch.zeros([1, 3, h, w])], controlnet_conditioning_scale=[0.0]) final_out = [] for pil_image in images: final_out.append(pil_image) for single_faceid_cond_kwargs in cond_faceids: final_out.extend(single_faceid_cond_kwargs["refs"]) if "mix_refs" in single_faceid_cond_kwargs: final_out.extend(single_faceid_cond_kwargs["mix_refs"]) return final_out def 
text_to_multi_id_generation_process( pil_faceid_image_1=None, pil_faceid_supp_images_1=None, pil_faceid_mix_image_1_1=None, mix_scale_1_1=0.0, pil_faceid_mix_image_1_2=None, mix_scale_1_2=0.0, pil_faceid_image_2=None, pil_faceid_supp_images_2=None, pil_faceid_mix_image_2_1=None, mix_scale_2_1=0.0, pil_faceid_mix_image_2_2=None, mix_scale_2_2=0.0, faceid_scale=0.0, face_structure_scale=0.0, prompt="", negative_prompt="", num_samples=1, seed=-1, image_resolution="512x512", inference_steps=25, ): if seed == -1: seed = None faceid_cond_kwargs_1 = prepare_single_faceid_cond_kwargs(pil_faceid_image_1, pil_faceid_supp_images_1, [pil_faceid_mix_image_1_1, pil_faceid_mix_image_1_2], [mix_scale_1_1, mix_scale_1_2]) faceid_cond_kwargs_2 = prepare_single_faceid_cond_kwargs(pil_faceid_image_2, pil_faceid_supp_images_2, [pil_faceid_mix_image_2_1, pil_faceid_mix_image_2_2], [mix_scale_2_1, mix_scale_2_2]) cond_faceids = [] if faceid_cond_kwargs_1 is not None: cond_faceids.append(faceid_cond_kwargs_1) if faceid_cond_kwargs_2 is not None: cond_faceids.append(faceid_cond_kwargs_2) # reset attn args attn_args.reset() # set faceid condition attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # single-faceid lora attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # multi-faceid lora attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0 attn_args.num_faceids = len(cond_faceids) print(attn_args) h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1]) prompt = [prompt] * num_samples negative_prompt = [negative_prompt] * num_samples images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt, cond_faceids=cond_faceids, face_structure_scale=face_structure_scale, seed=seed, guidance_scale=7.5, num_inference_steps=inference_steps, image=[torch.zeros([1, 3, h, w])], controlnet_conditioning_scale=[0.0]) final_out = [] for pil_image in images: final_out.append(pil_image) for single_faceid_cond_kwargs in cond_faceids: final_out.extend(single_faceid_cond_kwargs["refs"]) if "mix_refs" in single_faceid_cond_kwargs: final_out.extend(single_faceid_cond_kwargs["mix_refs"]) return final_out def image_to_single_id_generation_process( pil_faceid_image=None, pil_faceid_supp_images=None, pil_faceid_mix_image_1=None, mix_scale_1=0.0, pil_faceid_mix_image_2=None, mix_scale_2=0.0, faceid_scale=0.0, face_structure_scale=0.0, pil_ip_image=None, ip_scale=1.0, num_samples=1, seed=-1, image_resolution="768x512", inference_steps=25, ): if seed == -1: seed = None single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image, pil_faceid_supp_images, [pil_faceid_mix_image_1, pil_faceid_mix_image_2], [mix_scale_1, mix_scale_2]) cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else [] h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1]) # Image Prompt and Style Aligned if pil_ip_image is None: gr.Error("Please upload a reference image") attn_args.reset() pil_ip_image = pil_ip_image.convert("RGB").resize((w, h)) zts = inversion.ddim_inversion(uniportrait_pipeline.pipe, np.array(pil_ip_image), "", inference_steps, 2) zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0) # reset attn args attn_args.reset() # set ip condition attn_args.ip_scale = ip_scale if pil_ip_image else 0.0 # set faceid condition attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0 # lora for single faceid attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0 # lora for 
>1 faceids attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0 attn_args.num_faceids = len(cond_faceids) # set shared self-attn attn_args.enable_share_attn = True attn_args.shared_score_shift = -0.5 print(attn_args) prompt = [""] * (1 + num_samples) negative_prompt = [""] * (1 + num_samples) images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt, pil_ip_image=pil_ip_image, cond_faceids=cond_faceids, face_structure_scale=face_structure_scale, seed=seed, guidance_scale=7.5, num_inference_steps=inference_steps, image=[torch.zeros([1, 3, h, w])], controlnet_conditioning_scale=[0.0], zT=zT, callback_on_step_end=inversion_callback) images = images[1:] final_out = [] for pil_image in images: final_out.append(pil_image) for single_faceid_cond_kwargs in cond_faceids: final_out.extend(single_faceid_cond_kwargs["refs"]) if "mix_refs" in single_faceid_cond_kwargs: final_out.extend(single_faceid_cond_kwargs["mix_refs"]) return final_out def text_to_single_id_generation_block(): gr.Markdown("## Text-to-Single-ID Generation") gr.HTML(text_to_single_id_description) gr.HTML(text_to_single_id_tips) with gr.Row(): with gr.Column(scale=1, min_width=100): prompt = gr.Textbox(value="", label='Prompt', lines=2) negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt') run_button = gr.Button(value="Run") with gr.Accordion("Options", open=True): image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512", label="Image Resolution (HxW)") seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1, value=2147483647) num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1) inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False) faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.1) with gr.Column(scale=2, min_width=100): with gr.Row(equal_height=False): pil_faceid_image = gr.Image(type="pil", label="ID Image") with gr.Accordion("ID Supplements", open=True): with gr.Row(): pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"], type="binary", label="Additional ID Images") with gr.Row(): with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1") mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2") mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Row(): example_output = gr.Image(type="pil", label="(Example Output)", visible=False) result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True, format="png") with gr.Row(): examples = [ [ "A young man with short black hair, wearing a black hoodie with a hood, was paired with a blue denim jacket with yellow details.", "assets/examples/1-newton.jpg", "assets/examples/1-output-1.png", ], ] gr.Examples( label="Examples", examples=examples, fn=lambda x, y, z: (x, y), inputs=[prompt, pil_faceid_image, example_output], outputs=[prompt, pil_faceid_image] ) ips = [ pil_faceid_image, pil_faceid_supp_images, pil_faceid_mix_image_1, mix_scale_1, pil_faceid_mix_image_2, mix_scale_2, faceid_scale, face_structure_scale, prompt, negative_prompt, 
num_samples, seed, image_resolution, inference_steps, ] run_button.click(fn=text_to_single_id_generation_process, inputs=ips, outputs=[result_gallery]) def text_to_multi_id_generation_block(): gr.Markdown("## Text-to-Multi-ID Generation") gr.HTML(text_to_multi_id_description) gr.HTML(text_to_multi_id_tips) with gr.Row(): with gr.Column(scale=1, min_width=100): prompt = gr.Textbox(value="", label='Prompt', lines=2) negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt') run_button = gr.Button(value="Run") with gr.Accordion("Options", open=True): image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512", label="Image Resolution (HxW)") seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1, value=2147483647) num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1) inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False) faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.3) with gr.Column(scale=2, min_width=100): with gr.Row(equal_height=False): with gr.Column(scale=1, min_width=100): pil_faceid_image_1 = gr.Image(type="pil", label="First ID") with gr.Accordion("First ID Supplements", open=False): with gr.Row(): pil_faceid_supp_images_1 = gr.File(file_count="multiple", file_types=["image"], type="binary", label="Additional ID Images") with gr.Row(): with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_1_1 = gr.Image(type="pil", label="Mix ID 1") mix_scale_1_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_1_2 = gr.Image(type="pil", label="Mix ID 2") mix_scale_1_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Column(scale=1, min_width=100): pil_faceid_image_2 = gr.Image(type="pil", label="Second ID") with gr.Accordion("Second ID Supplements", open=False): with gr.Row(): pil_faceid_supp_images_2 = gr.File(file_count="multiple", file_types=["image"], type="binary", label="Additional ID Images") with gr.Row(): with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_2_1 = gr.Image(type="pil", label="Mix ID 1") mix_scale_2_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_2_2 = gr.Image(type="pil", label="Mix ID 2") mix_scale_2_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Row(): example_output = gr.Image(type="pil", label="(Example Output)", visible=False) result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True, format="png") with gr.Row(): examples = [ [ "The two female models, fair-skinned, wore a white V-neck short-sleeved top with a light smile on the corners of their mouths. 
The background was off-white.", "assets/examples/2-stylegan2-ffhq-0100.png", "assets/examples/2-stylegan2-ffhq-0293.png", "assets/examples/2-output-1.png", ], ] gr.Examples( label="Examples", examples=examples, inputs=[prompt, pil_faceid_image_1, pil_faceid_image_2, example_output], ) ips = [ pil_faceid_image_1, pil_faceid_supp_images_1, pil_faceid_mix_image_1_1, mix_scale_1_1, pil_faceid_mix_image_1_2, mix_scale_1_2, pil_faceid_image_2, pil_faceid_supp_images_2, pil_faceid_mix_image_2_1, mix_scale_2_1, pil_faceid_mix_image_2_2, mix_scale_2_2, faceid_scale, face_structure_scale, prompt, negative_prompt, num_samples, seed, image_resolution, inference_steps, ] run_button.click(fn=text_to_multi_id_generation_process, inputs=ips, outputs=[result_gallery]) def image_to_single_id_generation_block(): gr.Markdown("## Image-to-Single-ID Generation") gr.HTML(image_to_single_id_description) gr.HTML(image_to_single_id_tips) with gr.Row(): with gr.Column(scale=1, min_width=100): run_button = gr.Button(value="Run") seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1, value=2147483647) num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1) image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512", label="Image Resolution (HxW)") inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False) ip_scale = gr.Slider(label="Reference Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7) face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.3) with gr.Column(scale=3, min_width=100): with gr.Row(equal_height=False): pil_ip_image = gr.Image(type="pil", label="Portrait Reference") pil_faceid_image = gr.Image(type="pil", label="ID Image") with gr.Accordion("ID Supplements", open=True): with gr.Row(): pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"], type="binary", label="Additional ID Images") with gr.Row(): with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1") mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Column(scale=1, min_width=100): pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2") mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0) with gr.Row(): with gr.Column(scale=3, min_width=100): example_output = gr.Image(type="pil", label="(Example Output)", visible=False) result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True, format="png") with gr.Row(): examples = [ [ "assets/examples/3-style-1.png", "assets/examples/3-stylegan2-ffhq-0293.png", 0.7, 0.3, "assets/examples/3-output-1.png", ], [ "assets/examples/3-style-1.png", "assets/examples/3-stylegan2-ffhq-0293.png", 0.6, 0.0, "assets/examples/3-output-2.png", ], [ "assets/examples/3-style-2.jpg", "assets/examples/3-stylegan2-ffhq-0381.png", 0.7, 0.3, "assets/examples/3-output-3.png", ], [ "assets/examples/3-style-3.jpg", "assets/examples/3-stylegan2-ffhq-0381.png", 0.6, 0.0, "assets/examples/3-output-4.png", ], ] gr.Examples( label="Examples", examples=examples, fn=lambda x, y, z, w, v: (x, y, z, w), inputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale, example_output], outputs=[pil_ip_image, pil_faceid_image, faceid_scale, 
face_structure_scale] ) ips = [ pil_faceid_image, pil_faceid_supp_images, pil_faceid_mix_image_1, mix_scale_1, pil_faceid_mix_image_2, mix_scale_2, faceid_scale, face_structure_scale, pil_ip_image, ip_scale, num_samples, seed, image_resolution, inference_steps, ] run_button.click(fn=image_to_single_id_generation_process, inputs=ips, outputs=[result_gallery]) if __name__ == "__main__": os.environ["no_proxy"] = "localhost,127.0.0.1,::1" title = r"""

UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization

  Project Page  

""" title_description = r""" This is the official 🤗 Gradio demo for UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization.
The demo provides three capabilities: text-to-single-ID personalization, text-to-multi-ID personalization, and image-to-single-ID personalization. All of these are based on the Stable Diffusion v1-5 model. Feel free to give them a try! 😊 """ text_to_single_id_description = r"""🚀🚀🚀Quick start:
1. Enter a text prompt (Chinese or English), upload an image with a face, and click the Run button. 🤗
""" text_to_single_id_tips = r"""💡💡💡Tips:
1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)
2. It's a good idea to upload multiple reference photos of your face to improve prompt and ID consistency. Additional references can be uploaded under "ID Supplements".
3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID and text alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.
""" text_to_multi_id_description = r"""🚀🚀🚀Quick start:
1. Enter a text prompt (Chinese or English), upload an image with a face in each of the "First ID" and "Second ID" blocks, and click the Run button. 🤗
""" text_to_multi_id_tips = r"""💡💡💡Tips:
1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)
2. It's a good idea to upload multiple reference photos of your face to improve prompt and ID consistency. Additional references can be uploaded under "ID Supplements".
3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID and text alignment. We recommend a "Face ID Scale" of 0.3~0.7 and a "Face Structure Scale" of 0.0~0.4.
""" image_to_single_id_description = r"""🚀🚀🚀Quick start: Upload an image as the portrait reference (can be any style), Upload a face image, and Click the Run button. 🤗
""" image_to_single_id_tips = r"""💡💡💡Tips:
1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512 pixels.)
2. It's a good idea to upload multiple reference photos of your face to improve ID consistency. Additional references can be uploaded under "ID Supplements".
3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing the portrait reference and ID alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.
""" citation = r""" --- 📝 **Citation**
If our work is helpful for your research or applications, please cite us via: ```bibtex @article{he2024uniportrait, title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization}, author={He, Junjie and Geng, Yifeng and Bo, Liefeng}, journal={arXiv preprint arXiv:2408.05939}, year={2024} } ``` 📧 **Contact**
If you have any questions, please feel free to open an issue or reach out to us directly at hejunjie1103@gmail.com. """ block = gr.Blocks(title="UniPortrait").queue() with block: gr.HTML(title) gr.HTML(title_description) with gr.TabItem("Text-to-Single-ID"): text_to_single_id_generation_block() with gr.TabItem("Text-to-Multi-ID"): text_to_multi_id_generation_block() with gr.TabItem("Image-to-Single-ID (Stylization)"): image_to_single_id_generation_block() gr.Markdown(citation) block.launch(server_name='0.0.0.0', share=False, server_port=port, allowed_paths=["/"]) ================================================ FILE: requirements.txt ================================================ diffusers gradio onnxruntime-gpu insightface torch tqdm transformers ================================================ FILE: uniportrait/__init__.py ================================================ ================================================ FILE: uniportrait/curricular_face/__init__.py ================================================ ================================================ FILE: uniportrait/curricular_face/backbone/__init__.py ================================================ # The implementation is adopted from TFace, made publicly available under the Apache-2.0 license at # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200) from .model_resnet import ResNet_50, ResNet_101, ResNet_152 _model_dict = { 'ResNet_50': ResNet_50, 'ResNet_101': ResNet_101, 'ResNet_152': ResNet_152, 'IR_18': IR_18, 'IR_34': IR_34, 'IR_50': IR_50, 'IR_101': IR_101, 'IR_152': IR_152, 'IR_200': IR_200, 'IR_SE_50': IR_SE_50, 'IR_SE_101': IR_SE_101, 'IR_SE_152': IR_SE_152, 'IR_SE_200': IR_SE_200 } def get_model(key): """ Get a backbone network by key. Supported keys: ResNet_50, ResNet_101, ResNet_152, IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. """ if key in _model_dict: return _model_dict[key] else: raise KeyError('unsupported model {}'.format(key))
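# Usage sketch (an assumption-laden example, not part of the original file): this mirrors how
# uniportrait/curricular_face/inference.py builds the recognition backbone; the checkpoint path
# below matches gradio_app.py and is assumed to have been downloaded already.
#
#   import torch
#   from uniportrait.curricular_face.backbone import get_model
#
#   face_model = get_model('IR_101')([112, 112])  # look up the constructor, build for 112x112 input
#   face_model.load_state_dict(torch.load("models/glint360k_curricular_face_r101_backbone.bin",
#                                         map_location="cpu"))
#   face_model.eval()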
""" if key in _model_dict.keys(): return _model_dict[key] else: raise KeyError('not support model {}'.format(key)) ================================================ FILE: uniportrait/curricular_face/backbone/common.py ================================================ # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py import torch.nn as nn from torch.nn import (Conv2d, Module, ReLU, Sigmoid) def initialize_weights(modules): """ Weight initilize, conv2d and linear is initialized with kaiming_normal """ for m in modules: if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): nn.init.kaiming_normal_( m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: m.bias.data.zero_() class Flatten(Module): """ Flat tensor """ def forward(self, input): return input.view(input.size(0), -1) class SEModule(Module): """ SE block """ def __init__(self, channels, reduction): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = Conv2d( channels, channels // reduction, kernel_size=1, padding=0, bias=False) nn.init.xavier_uniform_(self.fc1.weight.data) self.relu = ReLU(inplace=True) self.fc2 = Conv2d( channels // reduction, channels, kernel_size=1, padding=0, bias=False) self.sigmoid = Sigmoid() def forward(self, x): module_input = x x = self.avg_pool(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) x = self.sigmoid(x) return module_input * x ================================================ FILE: uniportrait/curricular_face/backbone/model_irse.py ================================================ # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py from collections import namedtuple from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential) from .common import Flatten, SEModule, initialize_weights class BasicBlockIR(Module): """ BasicBlock for IRNet """ def __init__(self, in_channel, depth, stride): super(BasicBlockIR, self).__init__() if in_channel == depth: self.shortcut_layer = MaxPool2d(1, stride) else: self.shortcut_layer = Sequential( Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) self.res_layer = Sequential( BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), BatchNorm2d(depth), PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)) def forward(self, x): shortcut = self.shortcut_layer(x) res = self.res_layer(x) return res + shortcut class BottleneckIR(Module): """ BasicBlock with bottleneck for IRNet """ def __init__(self, in_channel, depth, stride): super(BottleneckIR, self).__init__() reduction_channel = depth // 4 if in_channel == depth: self.shortcut_layer = MaxPool2d(1, stride) else: self.shortcut_layer = Sequential( Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) self.res_layer = Sequential( BatchNorm2d(in_channel), Conv2d( in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), BatchNorm2d(reduction_channel), PReLU(reduction_channel), Conv2d( reduction_channel, reduction_channel, (3, 3), (1, 1), 1, bias=False), 
BatchNorm2d(reduction_channel), PReLU(reduction_channel), Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), BatchNorm2d(depth)) def forward(self, x): shortcut = self.shortcut_layer(x) res = self.res_layer(x) return res + shortcut class BasicBlockIRSE(BasicBlockIR): def __init__(self, in_channel, depth, stride): super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) self.res_layer.add_module('se_block', SEModule(depth, 16)) class BottleneckIRSE(BottleneckIR): def __init__(self, in_channel, depth, stride): super(BottleneckIRSE, self).__init__(in_channel, depth, stride) self.res_layer.add_module('se_block', SEModule(depth, 16)) class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): '''A named tuple describing a ResNet block.''' def get_block(in_channel, depth, num_units, stride=2): return [Bottleneck(in_channel, depth, stride)] + \ [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] def get_blocks(num_layers): if num_layers == 18: blocks = [ get_block(in_channel=64, depth=64, num_units=2), get_block(in_channel=64, depth=128, num_units=2), get_block(in_channel=128, depth=256, num_units=2), get_block(in_channel=256, depth=512, num_units=2) ] elif num_layers == 34: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=4), get_block(in_channel=128, depth=256, num_units=6), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 50: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=4), get_block(in_channel=128, depth=256, num_units=14), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 100: blocks = [ get_block(in_channel=64, depth=64, num_units=3), get_block(in_channel=64, depth=128, num_units=13), get_block(in_channel=128, depth=256, num_units=30), get_block(in_channel=256, depth=512, num_units=3) ] elif num_layers == 152: blocks = [ get_block(in_channel=64, depth=256, num_units=3), get_block(in_channel=256, depth=512, num_units=8), get_block(in_channel=512, depth=1024, num_units=36), get_block(in_channel=1024, depth=2048, num_units=3) ] elif num_layers == 200: blocks = [ get_block(in_channel=64, depth=256, num_units=3), get_block(in_channel=256, depth=512, num_units=24), get_block(in_channel=512, depth=1024, num_units=36), get_block(in_channel=1024, depth=2048, num_units=3) ] return blocks class Backbone(Module): def __init__(self, input_size, num_layers, mode='ir'): """ Args: input_size: input_size of backbone num_layers: num_layers of backbone mode: support ir or irse """ super(Backbone, self).__init__() assert input_size[0] in [112, 224], \ 'input_size should be [112, 112] or [224, 224]' assert num_layers in [18, 34, 50, 100, 152, 200], \ 'num_layers should be 18, 34, 50, 100 or 152' assert mode in ['ir', 'ir_se'], \ 'mode should be ir or ir_se' self.input_layer = Sequential( Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) blocks = get_blocks(num_layers) if num_layers <= 100: if mode == 'ir': unit_module = BasicBlockIR elif mode == 'ir_se': unit_module = BasicBlockIRSE output_channel = 512 else: if mode == 'ir': unit_module = BottleneckIR elif mode == 'ir_se': unit_module = BottleneckIRSE output_channel = 2048 if input_size[0] == 112: self.output_layer = Sequential( BatchNorm2d(output_channel), Dropout(0.4), Flatten(), Linear(output_channel * 7 * 7, 512), BatchNorm1d(512, affine=False)) else: self.output_layer = Sequential( BatchNorm2d(output_channel), Dropout(0.4), Flatten(), 
Linear(output_channel * 14 * 14, 512), BatchNorm1d(512, affine=False)) modules = [] mid_layer_indices = [] # [2, 15, 45, 48], total 49 layers for IR101 for block in blocks: if len(mid_layer_indices) == 0: mid_layer_indices.append(len(block) - 1) else: mid_layer_indices.append(len(block) + mid_layer_indices[-1]) for bottleneck in block: modules.append( unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.mid_layer_indices = mid_layer_indices[-4:] initialize_weights(self.modules()) def forward(self, x, return_mid_feats=False): x = self.input_layer(x) if not return_mid_feats: x = self.body(x) x = self.output_layer(x) return x else: out_feats = [] for idx, module in enumerate(self.body): x = module(x) if idx in self.mid_layer_indices: out_feats.append(x) x = self.output_layer(x) return x, out_feats def IR_18(input_size): """ Constructs a ir-18 model. """ model = Backbone(input_size, 18, 'ir') return model def IR_34(input_size): """ Constructs a ir-34 model. """ model = Backbone(input_size, 34, 'ir') return model def IR_50(input_size): """ Constructs a ir-50 model. """ model = Backbone(input_size, 50, 'ir') return model def IR_101(input_size): """ Constructs a ir-101 model. """ model = Backbone(input_size, 100, 'ir') return model def IR_152(input_size): """ Constructs a ir-152 model. """ model = Backbone(input_size, 152, 'ir') return model def IR_200(input_size): """ Constructs a ir-200 model. """ model = Backbone(input_size, 200, 'ir') return model def IR_SE_50(input_size): """ Constructs a ir_se-50 model. """ model = Backbone(input_size, 50, 'ir_se') return model def IR_SE_101(input_size): """ Constructs a ir_se-101 model. """ model = Backbone(input_size, 100, 'ir_se') return model def IR_SE_152(input_size): """ Constructs a ir_se-152 model. """ model = Backbone(input_size, 152, 'ir_se') return model def IR_SE_200(input_size): """ Constructs a ir_se-200 model. 
""" model = Backbone(input_size, 200, 'ir_se') return model ================================================ FILE: uniportrait/curricular_face/backbone/model_resnet.py ================================================ # The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at # https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py import torch.nn as nn from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential) from .common import initialize_weights def conv3x3(in_planes, out_planes, stride=1): """ 3x3 convolution with padding """ return Conv2d( in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x1(in_planes, out_planes, stride=1): """ 1x1 convolution """ return Conv2d( in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class Bottleneck(Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = conv1x1(inplanes, planes) self.bn1 = BatchNorm2d(planes) self.conv2 = conv3x3(planes, planes, stride) self.bn2 = BatchNorm2d(planes) self.conv3 = conv1x1(planes, planes * self.expansion) self.bn3 = BatchNorm2d(planes * self.expansion) self.relu = ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(Module): """ ResNet backbone """ def __init__(self, input_size, block, layers, zero_init_residual=True): """ Args: input_size: input_size of backbone block: block function layers: layers in each block """ super(ResNet, self).__init__() assert input_size[0] in [112, 224], \ 'input_size should be [112, 112] or [224, 224]' self.inplanes = 64 self.conv1 = Conv2d( 3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = BatchNorm2d(64) self.relu = ReLU(inplace=True) self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) self.bn_o1 = BatchNorm2d(2048) self.dropout = Dropout() if input_size[0] == 112: self.fc = Linear(2048 * 4 * 4, 512) else: self.fc = Linear(2048 * 7 * 7, 512) self.bn_o2 = BatchNorm1d(512) initialize_weights(self.modules) if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes)) return Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.bn_o1(x) x = self.dropout(x) x = x.view(x.size(0), -1) x = self.fc(x) x = self.bn_o2(x) 
return x def ResNet_50(input_size, **kwargs): """ Constructs a ResNet-50 model. """ model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) return model def ResNet_101(input_size, **kwargs): """ Constructs a ResNet-101 model. """ model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) return model def ResNet_152(input_size, **kwargs): """ Constructs a ResNet-152 model. """ model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) return model ================================================ FILE: uniportrait/curricular_face/inference.py ================================================ import glob import os import cv2 import numpy as np import torch from tqdm.auto import tqdm from .backbone import get_model @torch.no_grad() def inference(name, weight, src_norm_dir): face_model = get_model(name)([112, 112]) face_model.load_state_dict(torch.load(weight, map_location="cpu")) face_model = face_model.to("cpu") face_model.eval() id2src_norm = {} for src_id in sorted(list(os.listdir(src_norm_dir))): id2src_norm[src_id] = sorted(list(glob.glob(f"{os.path.join(src_norm_dir, src_id)}/*"))) total_sims = [] for id_name in tqdm(id2src_norm): src_face_embeddings = [] for src_img_path in id2src_norm[id_name]: src_img = cv2.imread(src_img_path) src_img = cv2.resize(src_img, (112, 112)) src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) src_img = np.transpose(src_img, (2, 0, 1)) src_img = torch.from_numpy(src_img).unsqueeze(0).float() src_img.div_(255).sub_(0.5).div_(0.5) embedding = face_model(src_img).detach().cpu().numpy()[0] embedding = embedding / np.linalg.norm(embedding) src_face_embeddings.append(embedding) # 512 num = len(src_face_embeddings) src_face_embeddings = np.stack(src_face_embeddings) # n, 512 sim = src_face_embeddings @ src_face_embeddings.T # n, n mean_sim = (np.sum(sim) - num * 1.0) / ((num - 1) * num) print(f"{id_name}: {mean_sim}") total_sims.append(mean_sim) return np.mean(total_sims) if __name__ == "__main__": name = 'IR_101' weight = "models/glint360k_curricular_face_r101_backbone.bin" src_norm_dir = "/disk1/hejunjie.hjj/data/normface-AFD-id-20" mean_sim = inference(name, weight, src_norm_dir) print(f"total: {mean_sim:.4f}") # total: 0.6299 ================================================ FILE: uniportrait/inversion.py ================================================ # modified from https://github.com/google/style-aligned/blob/main/inversion.py from __future__ import annotations from typing import Callable import numpy as np import torch from diffusers import StableDiffusionPipeline from tqdm import tqdm T = torch.Tensor InversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]] def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> tuple[dict[str, T], T]: device = model._execution_device prompt_embeds = model._encode_prompt( prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt="") return prompt_embeds def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T: model.vae.to(dtype=torch.float32) image = torch.from_numpy(image).float() / 255. 

================================================
FILE: uniportrait/inversion.py
================================================
# modified from https://github.com/google/style-aligned/blob/main/inversion.py
from __future__ import annotations

from typing import Callable

import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

T = torch.Tensor
InversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]]


def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> T:
    device = model._execution_device
    prompt_embeds = model._encode_prompt(
        prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt="")
    return prompt_embeds


def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:
    model.vae.to(dtype=torch.float32)
    image = torch.from_numpy(image).float() / 255.
    image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
    latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
    model.vae.to(dtype=torch.float16)
    return latent


def _next_step(model: StableDiffusionPipeline, model_output: T, timestep: int, sample: T) -> T:
    timestep, next_timestep = min(
        timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = model.scheduler.alphas_cumprod[
        int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
    alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, context: T, guidance_scale: float):
    latents_input = torch.cat([latent] * 2)
    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # latents = next_step(model, noise_pred, t, latent)
    return noise_pred


def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scale) -> T:
    all_latent = [z0]
    text_embedding = _encode_text_with_negative(model, prompt)
    image_embedding = torch.zeros_like(text_embedding[:, :1]).repeat(1, 4, 1)  # for ip embedding
    text_embedding = torch.cat([text_embedding, image_embedding], dim=1)
    latent = z0.clone().detach().half()
    for i in tqdm(range(model.scheduler.num_inference_steps)):
        t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
        noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale)
        latent = _next_step(model, noise_pred, t, latent)
        all_latent.append(latent)
    return torch.cat(all_latent).flip(0)


def make_inversion_callback(zts, offset: int = 0) -> tuple[T, InversionCallback]:
    def callback_on_step_end(pipeline: StableDiffusionPipeline, i: int, t: T,
                             callback_kwargs: dict[str, T]) -> dict[str, T]:
        latents = callback_kwargs['latents']
        latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
        return {'latents': latents}

    return zts[offset], callback_on_step_end


@torch.no_grad()
def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int,
                   guidance_scale, ) -> T:
    z0 = _encode_image(model, x0)
    model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
    zs = _ddim_loop(model, z0, prompt, guidance_scale)
    return zs
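A usage sketch under stated assumptions: `_ddim_loop` appends 4 zero "IP tokens" to the text embedding, so the pipeline's attention processors must already expect them (i.e. the UniPortrait processors are installed); `reference.jpg` is a hypothetical input path.

import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image

from uniportrait.inversion import ddim_inversion, make_inversion_callback

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

x0 = np.array(Image.open("reference.jpg").convert("RGB").resize((512, 512)))
zts = ddim_inversion(pipe, x0, "a photo of a person", num_inference_steps=50, guidance_scale=1.0)
# zT seeds the first latent; the callback re-injects the inverted trajectory
# into latent 0 at every denoising step (pass it as callback_on_step_end).
zT, callback = make_inversion_callback(zts, offset=5)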

================================================
FILE: uniportrait/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # reshape keeps (bs, n_heads, length, dim_per_head) but makes the tensor contiguous
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, attention_mask=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
            attention_mask (torch.Tensor): attention mask
                shape (b, n1, 1)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        if attention_mask is not None:
            attention_mask = attention_mask.transpose(1, 2)  # (b, 1, n1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :, :1]).repeat(1, 1, l)],
                                       dim=2)  # b, 1, n1+n2
            attention_mask = (attention_mask - 1.) * 100.  # 0 means kept and -100 means dropped
            attention_mask = attention_mask.unsqueeze(1)
            weight = weight + attention_mask  # b, h, n2, n1+n2
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class UniPortraitFaceIDResampler(torch.nn.Module):

    def __init__(
            self,
            intrinsic_id_embedding_dim=512,
            structure_embedding_dim=64 + 128 + 256 + 1280,
            num_tokens=16,
            depth=6,
            dim=768,
            dim_head=64,
            heads=12,
            ff_mult=4,
            output_dim=768,
    ):
        super().__init__()

        self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5)

        self.proj_id = torch.nn.Sequential(
            torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim),
        )
        self.proj_clip = torch.nn.Sequential(
            torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(structure_embedding_dim * 2, dim),
        )

        self.layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)

    def forward(
            self,
            intrinsic_id_embeds,
            structure_embeds,
            structure_scale=1.0,
            intrinsic_id_attention_mask=None,
            structure_attention_mask=None
    ):
        latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1)

        intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds)
        structure_embeds = self.proj_clip(structure_embeds)

        for attn1, attn2, ff in self.layers:
            latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents
            latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
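A shape-level sketch of the resampler with its default dimensions (torch-only; the inputs stand in for the 7x7 ArcFace feature map and the 16x16 multi-scale + CLIP structure features):

import torch

from uniportrait.resampler import UniPortraitFaceIDResampler

resampler = UniPortraitFaceIDResampler()  # defaults: 16 tokens, dim 768
intrinsic_id = torch.randn(1, 49, 512)                    # 7x7 identity feature map, flattened
structure = torch.randn(1, 256, 64 + 128 + 256 + 1280)    # 16x16 structure features, flattened
with torch.no_grad():
    tokens = resampler(intrinsic_id, structure, structure_scale=0.4)
print(tokens.shape)  # torch.Size([1, 16, 768])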

================================================
FILE: uniportrait/uniportrait_attention_processor.py
================================================
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.lora import LoRALinearLayer


class AttentionArgs(object):

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        # ip condition
        self.ip_scale = 0.0
        self.ip_mask = None  # ip attention mask

        # faceid condition
        self.lora_scale = 0.0  # lora for single faceid
        self.multi_id_lora_scale = 0.0  # lora for multiple faceids
        self.faceid_scale = 0.0
        self.num_faceids = 0
        self.faceid_mask = None  # faceid attention mask; if not None, it will override the routing map

        # style aligned
        self.enable_share_attn: bool = False
        self.adain_queries_and_keys: bool = False
        self.shared_score_scale: float = 1.0
        self.shared_score_shift: float = 0.0

    def __repr__(self):
        indent_str = ' '
        s = f",\n{indent_str}".join(f"{attr}={value}" for attr, value in vars(self).items())
        return self.__class__.__name__ + '(' + f'\n{indent_str}' + s + ')'


attn_args = AttentionArgs()


def expand_first(feat, scale=1., ):
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)


def concat_first(feat, dim=2, scale=1.):
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)


def calc_mean_std(feat, eps: float = 1e-5):
    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdims=True)
    return feat_mean, feat_std


def adain(feat):
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    return feat
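The style-aligned helpers above re-center every batch element on the statistics of the reference element in its CFG half (index 0 and index b//2). A small sketch of that behavior, assuming a CFG batch of 4:

import torch

from uniportrait.uniportrait_attention_processor import adain

q = torch.randn(4, 8, 64, 40)  # (batch, heads, tokens, dim_per_head)
q_aligned = adain(q)
# after AdaIN, every element shares the per-head token statistics of its
# half's first element; the reference element itself is unchanged
ref_mean = q_aligned[0].mean(dim=-2)
print(torch.allclose(q_aligned[1].mean(dim=-2), ref_mean, atol=1e-4))  # True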

class UniPortraitLoRAAttnProcessor2_0(nn.Module):

    def __init__(
            self,
            hidden_size=None,
            cross_attention_dim=None,
            rank=128,
            network_alpha=None,
    ):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn_args.lora_scale > 0.0:
            query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
            key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
            value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
            key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
            value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn_args.enable_share_attn:
            if attn_args.adain_queries_and_keys:
                query = adain(query)
                key = adain(key)
            key = concat_first(key, -2, scale=attn_args.shared_score_scale)
            value = concat_first(value, -2)
            if attn_args.shared_score_shift != 0:
                attention_mask = torch.zeros_like(key[:, :, :, :1]).transpose(-1, -2)  # b, h, 1, k
                attention_mask[:, :, :, query.shape[2]:] += attn_args.shared_score_shift
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
                )
            else:
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
                )
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        output_hidden_states = attn.to_out[0](hidden_states)
        if attn_args.lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
                hidden_states)
        hidden_states = output_hidden_states
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
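The processor above adds LoRA deltas on top of the frozen projections at call time rather than merging weights. A minimal torch-only sketch of that pattern (hypothetical sizes; not the diffusers `LoRALinearLayer` implementation):

import torch
import torch.nn as nn

hidden, rank, scale = 320, 128, 1.0
to_q = nn.Linear(hidden, hidden)                 # frozen base projection
lora_down = nn.Linear(hidden, rank, bias=False)  # low-rank factors
lora_up = nn.Linear(rank, hidden, bias=False)
nn.init.zeros_(lora_up.weight)                   # delta starts at zero

x = torch.randn(2, 64, hidden)
query = to_q(x) + scale * lora_up(lora_down(x))  # base + low-rank delta
print(torch.allclose(query, to_q(x)))  # True while lora_up is zero-initialized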

class UniPortraitLoRAIPAttnProcessor2_0(nn.Module):

    def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None,
                 num_ip_tokens=4, num_faceid_tokens=16):
        super().__init__()

        self.num_ip_tokens = num_ip_tokens
        self.num_faceid_tokens = num_faceid_tokens

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.to_k_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.to_q_router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.GELU(),
            nn.Linear(hidden_size * 2, hidden_size, bias=False),
        )
        self.to_k_router = nn.Sequential(
            nn.Linear(cross_attention_dim or hidden_size, (cross_attention_dim or hidden_size) * 2),
            nn.GELU(),
            nn.Linear((cross_attention_dim or hidden_size) * 2, hidden_size, bias=False),
        )
        self.aggr_router = nn.Linear(num_faceid_tokens, 1)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split hidden states: [text tokens | ip tokens | faceid tokens]
            faceid_end = encoder_hidden_states.shape[1]
            ip_end = faceid_end - self.num_faceid_tokens * attn_args.num_faceids
            text_end = ip_end - self.num_ip_tokens
            prompt_hidden_states = encoder_hidden_states[:, :text_end]
            ip_hidden_states = encoder_hidden_states[:, text_end: ip_end]
            faceid_hidden_states = encoder_hidden_states[:, ip_end: faceid_end]
            encoder_hidden_states = prompt_hidden_states
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # for router
        if attn_args.num_faceids > 1:
            router_query = self.to_q_router(hidden_states)  # bs, s*s, dim
            router_hidden_states = faceid_hidden_states.reshape(batch_size, attn_args.num_faceids,
                                                                self.num_faceid_tokens, -1)  # bs, num, id_tokens, d
            router_hidden_states = self.aggr_router(router_hidden_states.transpose(-1, -2)).squeeze(-1)  # bs, num, d
            router_key = self.to_k_router(router_hidden_states)  # bs, num, dim
            router_logits = torch.bmm(router_query, router_key.transpose(-1, -2))  # bs, s*s, num
            index = router_logits.max(dim=-1, keepdim=True)[1]
            routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0)
            routing_map = routing_map.transpose(1, 2).unsqueeze(-1)  # bs, num, s*s, 1
        else:
            routing_map = hidden_states.new_ones(size=(1, 1, hidden_states.shape[1], 1))

        # for text
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn_args.lora_scale > 0.0:
            query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
            key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
            value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
            key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
            value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        if attn_args.ip_scale > 0.0:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            if attn_args.ip_mask is not None:
                ip_mask = attn_args.ip_mask
                h, w = ip_mask.shape[-2:]
                ratio = (h * w / query.shape[2]) ** 0.5
                ip_mask = torch.nn.functional.interpolate(ip_mask, scale_factor=1 / ratio, mode='nearest').reshape(
                    [1, -1, 1])
                ip_hidden_states = ip_hidden_states * ip_mask

            if attn_args.enable_share_attn:
                ip_hidden_states[0] = 0.
                ip_hidden_states[batch_size // 2] = 0.
        else:
            ip_hidden_states = torch.zeros_like(hidden_states)

        # for faceid-adapter
        if attn_args.faceid_scale > 0.0:
            faceid_key = self.to_k_faceid(faceid_hidden_states)
            faceid_value = self.to_v_faceid(faceid_hidden_states)

            faceid_query = query[:, None].expand(-1, attn_args.num_faceids, -1, -1, -1)  # 2*bs, num, heads, s*s, dim/heads
            faceid_key = faceid_key.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens,
                                         attn.heads, head_dim).transpose(2, 3)
            faceid_value = faceid_value.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens,
                                             attn.heads, head_dim).transpose(2, 3)

            faceid_hidden_states = F.scaled_dot_product_attention(
                faceid_query, faceid_key, faceid_value, attn_mask=None, dropout_p=0.0, is_causal=False,
                scale=attn.scale
            )  # 2*bs, num, heads, s*s, dim/heads

            faceid_hidden_states = faceid_hidden_states.transpose(2, 3).reshape(batch_size, attn_args.num_faceids,
                                                                                -1, attn.heads * head_dim)
            faceid_hidden_states = faceid_hidden_states.to(query.dtype)  # 2*bs, num, s*s, dim

            if attn_args.faceid_mask is not None:
                faceid_mask = attn_args.faceid_mask  # 1, num, h, w
                h, w = faceid_mask.shape[-2:]
                ratio = (h * w / query.shape[2]) ** 0.5
                faceid_mask = F.interpolate(faceid_mask, scale_factor=1 / ratio,
                                            mode='bilinear').flatten(2).unsqueeze(-1)  # 1, num, s*s, 1
                faceid_mask = faceid_mask / faceid_mask.sum(1, keepdim=True).clip(min=1e-3)  # 1, num, s*s, 1
                faceid_hidden_states = (faceid_mask * faceid_hidden_states).sum(1)  # 2*bs, s*s, dim
            else:
                faceid_hidden_states = (routing_map * faceid_hidden_states).sum(1)  # 2*bs, s*s, dim

            if attn_args.enable_share_attn:
                faceid_hidden_states[0] = 0.
                faceid_hidden_states[batch_size // 2] = 0.
        else:
            faceid_hidden_states = torch.zeros_like(hidden_states)

        hidden_states = hidden_states + \
                        attn_args.ip_scale * ip_hidden_states + \
                        attn_args.faceid_scale * faceid_hidden_states

        # linear proj
        output_hidden_states = attn.to_out[0](hidden_states)
        if attn_args.lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
                hidden_states)
        hidden_states = output_hidden_states
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
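The router above hard-assigns each spatial query position to exactly one identity via an argmax one-hot; the scatter trick in isolation:

import torch

router_logits = torch.randn(1, 9, 2)  # (batch, spatial positions, num identities)
index = router_logits.max(dim=-1, keepdim=True)[1]
routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0)
routing_map = routing_map.transpose(1, 2).unsqueeze(-1)  # (batch, num, positions, 1)
print(routing_map.sum(1).squeeze(-1))  # all ones: each position routed to exactly one identity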

# for controlnet
class UniPortraitCNAttnProcessor2_0:

    def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):
        self.num_ip_tokens = num_ip_tokens
        self.num_faceid_tokens = num_faceid_tokens

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            text_end = encoder_hidden_states.shape[1] - self.num_faceid_tokens * attn_args.num_faceids \
                       - self.num_ip_tokens
            encoder_hidden_states = encoder_hidden_states[:, :text_end]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
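The ControlNet processor attends to text tokens only, stripping the adapter tokens from the tail of the sequence. The slice arithmetic with assumed sizes (77 text tokens, 4 IP tokens, 16 faceid tokens per identity, 2 identities):

num_ip_tokens, num_faceid_tokens, num_faceids = 4, 16, 2
seq_len = 77 + num_ip_tokens + num_faceid_tokens * num_faceids  # 113
text_end = seq_len - num_faceid_tokens * num_faceids - num_ip_tokens
print(text_end)  # 77 -> encoder_hidden_states[:, :77] keeps only the text tokens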

================================================
FILE: uniportrait/uniportrait_pipeline.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from .curricular_face.backbone import get_model
from .resampler import UniPortraitFaceIDResampler
from .uniportrait_attention_processor import UniPortraitCNAttnProcessor2_0 as UniPortraitCNAttnProcessor
from .uniportrait_attention_processor import UniPortraitLoRAAttnProcessor2_0 as UniPortraitLoRAAttnProcessor
from .uniportrait_attention_processor import UniPortraitLoRAIPAttnProcessor2_0 as UniPortraitLoRAIPAttnProcessor


class ImageProjModel(nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds  # b, c
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
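A shape-level sketch of the projection (torch-only, illustrative dimensions): one global CLIP embedding per image is expanded into a fixed number of cross-attention tokens.

import torch

from uniportrait.uniportrait_pipeline import ImageProjModel

proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
tokens = proj(torch.randn(2, 1024))  # (b, clip_embeddings_dim)
print(tokens.shape)  # torch.Size([2, 4, 768]) -- 4 IP tokens per image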

class UniPortraitPipeline:

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_backbone_ckpt=None,
                 uniportrait_faceid_ckpt=None, uniportrait_router_ckpt=None,
                 num_ip_tokens=4, num_faceid_tokens=16, lora_rank=128,
                 device=torch.device("cuda"), torch_dtype=torch.float16):

        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.uniportrait_faceid_ckpt = uniportrait_faceid_ckpt
        self.uniportrait_router_ckpt = uniportrait_router_ckpt

        self.num_ip_tokens = num_ip_tokens
        self.num_faceid_tokens = num_faceid_tokens
        self.lora_rank = lora_rank

        self.device = device
        self.torch_dtype = torch_dtype

        self.pipe = sd_pipe.to(self.device)

        # load clip image encoder
        self.clip_image_processor = CLIPImageProcessor(size={"shortest_edge": 224}, do_center_crop=False,
                                                       use_square_size=True)
        self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=self.torch_dtype)

        # load face backbone
        self.facerecog_model = get_model("IR_101")([112, 112])
        self.facerecog_model.load_state_dict(torch.load(face_backbone_ckpt, map_location="cpu"))
        self.facerecog_model = self.facerecog_model.to(self.device, dtype=torch_dtype)
        self.facerecog_model.eval()

        # image proj model
        self.image_proj_model = self.init_image_proj()
        # faceid proj model
        self.faceid_proj_model = self.init_faceid_proj()

        # set uniportrait and ip adapter
        self.set_uniportrait_and_ip_adapter()
        # load uniportrait and ip adapter
        self.load_uniportrait_and_ip_adapter()

    def init_image_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.clip_image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_ip_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        return image_proj_model

    def init_faceid_proj(self):
        faceid_proj_model = UniPortraitFaceIDResampler(
            intrinsic_id_embedding_dim=512,
            structure_embedding_dim=64 + 128 + 256 + self.clip_image_encoder.config.hidden_size,
            num_tokens=16, depth=6,
            dim=self.pipe.unet.config.cross_attention_dim,
            dim_head=64, heads=12, ff_mult=4,
            output_dim=self.pipe.unet.config.cross_attention_dim
        ).to(self.device, dtype=self.torch_dtype)
        return faceid_proj_model
    def set_uniportrait_and_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = UniPortraitLoRAAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
                ).to(self.device, dtype=self.torch_dtype).eval()
            else:
                attn_procs[name] = UniPortraitLoRAIPAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
                    num_ip_tokens=self.num_ip_tokens, num_faceid_tokens=self.num_faceid_tokens,
                ).to(self.device, dtype=self.torch_dtype).eval()
        unet.set_attn_processor(attn_procs)

        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, ControlNetModel):
                self.pipe.controlnet.set_attn_processor(
                    UniPortraitCNAttnProcessor(
                        num_ip_tokens=self.num_ip_tokens,
                        num_faceid_tokens=self.num_faceid_tokens,
                    )
                )
            elif isinstance(self.pipe.controlnet, MultiControlNetModel):
                for module in self.pipe.controlnet.nets:
                    module.set_attn_processor(
                        UniPortraitCNAttnProcessor(
                            num_ip_tokens=self.num_ip_tokens,
                            num_faceid_tokens=self.num_faceid_tokens,
                        )
                    )
            else:
                raise ValueError(f"Unsupported controlnet type: {type(self.pipe.controlnet)}")

    def load_uniportrait_and_ip_adapter(self):
        if self.ip_ckpt:
            print(f"loading from {self.ip_ckpt}...")
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
            ip_layers = nn.ModuleList(self.pipe.unet.attn_processors.values())
            ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

        if self.uniportrait_faceid_ckpt:
            print(f"loading from {self.uniportrait_faceid_ckpt}...")
            state_dict = torch.load(self.uniportrait_faceid_ckpt, map_location="cpu")
            self.faceid_proj_model.load_state_dict(state_dict["faceid_proj"], strict=True)
            ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
            ip_layers.load_state_dict(state_dict["faceid_adapter"], strict=False)

        if self.uniportrait_router_ckpt:
            print(f"loading from {self.uniportrait_router_ckpt}...")
            state_dict = torch.load(self.uniportrait_router_ckpt, map_location="cpu")
            router_state_dict = {}
            for k, v in state_dict["faceid_adapter"].items():
                if "lora." in k:
                    router_state_dict[k.replace("lora.", "multi_id_lora.")] = v
                elif "router." in k:
                    router_state_dict[k] = v
            ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
            ip_layers.load_state_dict(router_state_dict, strict=False)
    @torch.inference_mode()
    def get_ip_embeds(self, pil_ip_image):
        ip_image = self.clip_image_processor(images=pil_ip_image, return_tensors="pt").pixel_values
        ip_image = ip_image.to(self.device, dtype=self.torch_dtype)  # (b, 3, 224, 224), values being normalized
        ip_embeds = self.clip_image_encoder(ip_image).image_embeds
        ip_prompt_embeds = self.image_proj_model(ip_embeds)
        uncond_ip_prompt_embeds = self.image_proj_model(torch.zeros_like(ip_embeds))
        return ip_prompt_embeds, uncond_ip_prompt_embeds

    @torch.inference_mode()
    def get_single_faceid_embeds(self, pil_face_images, face_structure_scale):
        face_clip_image = self.clip_image_processor(images=pil_face_images, return_tensors="pt").pixel_values
        face_clip_image = face_clip_image.to(self.device, dtype=self.torch_dtype)  # (b, 3, 224, 224)
        face_clip_embeds = self.clip_image_encoder(
            face_clip_image, output_hidden_states=True).hidden_states[-2][:, 1:]  # b, 256, 1280

        OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073],
                                        device=self.device, dtype=self.torch_dtype).reshape(-1, 1, 1)
        OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711],
                                       device=self.device, dtype=self.torch_dtype).reshape(-1, 1, 1)
        facerecog_image = face_clip_image * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN  # [0, 1]
        facerecog_image = torch.clamp((facerecog_image - 0.5) / 0.5, -1, 1)  # [-1, 1]
        facerecog_image = F.interpolate(facerecog_image, size=(112, 112), mode="bilinear", align_corners=False)
        facerecog_embeds = self.facerecog_model(facerecog_image, return_mid_feats=True)[1]

        face_intrinsic_id_embeds = facerecog_embeds[-1]  # (b, 512, 7, 7)
        face_intrinsic_id_embeds = face_intrinsic_id_embeds.flatten(2).permute(0, 2, 1)  # b, 49, 512

        facerecog_structure_embeds = facerecog_embeds[:-1]  # (b, 64, 56, 56), (b, 128, 28, 28), (b, 256, 14, 14)
        facerecog_structure_embeds = torch.cat([
            F.interpolate(feat, size=(16, 16), mode="bilinear", align_corners=False)
            for feat in facerecog_structure_embeds], dim=1)  # b, 448, 16, 16
        facerecog_structure_embeds = facerecog_structure_embeds.flatten(2).permute(0, 2, 1)  # b, 256, 448
        face_structure_embeds = torch.cat([facerecog_structure_embeds, face_clip_embeds], dim=-1)  # b, 256, 1728

        uncond_face_clip_embeds = self.clip_image_encoder(
            torch.zeros_like(face_clip_image[:1]), output_hidden_states=True).hidden_states[-2][:, 1:]  # 1, 256, 1280
        uncond_face_structure_embeds = torch.cat(
            [torch.zeros_like(facerecog_structure_embeds[:1]), uncond_face_clip_embeds], dim=-1)  # 1, 256, 1728

        faceid_prompt_embeds = self.faceid_proj_model(
            face_intrinsic_id_embeds.flatten(0, 1).unsqueeze(0),
            face_structure_embeds.flatten(0, 1).unsqueeze(0),
            structure_scale=face_structure_scale,
        )  # [b, 16, 768]
        uncond_faceid_prompt_embeds = self.faceid_proj_model(
            torch.zeros_like(face_intrinsic_id_embeds[:1]),
            uncond_face_structure_embeds,
            structure_scale=face_structure_scale,
        )  # [1, 16, 768]

        return faceid_prompt_embeds, uncond_faceid_prompt_embeds
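`get_single_faceid_embeds` feeds one pixel tensor to both encoders by undoing CLIP's normalization before re-normalizing for the recognition backbone. The arithmetic in isolation, as a sketch with random pixels standing in for a CLIP-preprocessed crop:

import torch
import torch.nn.functional as F

OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(-1, 1, 1)
OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1)

pixels01 = torch.rand(1, 3, 224, 224)                         # stand-in for a [0, 1] face crop
clip_input = (pixels01 - OPENAI_CLIP_MEAN) / OPENAI_CLIP_STD  # what CLIPImageProcessor produces
recovered = clip_input * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN   # back to [0, 1]
arcface_input = F.interpolate(torch.clamp((recovered - 0.5) / 0.5, -1, 1),
                              size=(112, 112), mode="bilinear", align_corners=False)
print(torch.allclose(recovered, pixels01, atol=1e-6))  # True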
    def generate(
            self,
            prompt=None,
            negative_prompt=None,
            pil_ip_image=None,
            cond_faceids=None,
            face_structure_scale=0.0,
            seed=-1,
            guidance_scale=7.5,
            num_inference_steps=30,
            zT=None,
            **kwargs,
    ):
        """
        Args:
            prompt: text prompt (str or list of str).
            negative_prompt: negative text prompt.
            pil_ip_image: optional reference image for the IP-Adapter branch.
            cond_faceids: identity conditions, one dict per face:
                [
                    {
                        "refs": [PIL.Image] or PIL.Image,
                        (Optional) "mix_refs": [PIL.Image],
                        (Optional) "mix_scales": [float],
                    },
                    ...
                ]
            face_structure_scale: weight of the structure branch in the faceid resampler.
            seed: random seed; also seeds the latent generator.
            guidance_scale: classifier-free guidance scale.
            num_inference_steps: number of denoising steps.
            zT: optional inverted initial latent (e.g. from ddim_inversion).
            **kwargs: forwarded to the underlying diffusers pipeline.
        Returns:
            list of generated PIL images.
        """
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt, device=self.device, num_images_per_prompt=1,
                do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            num_prompts = prompt_embeds.shape[0]

            if pil_ip_image is not None:
                ip_prompt_embeds, uncond_ip_prompt_embeds = self.get_ip_embeds(pil_ip_image)
                ip_prompt_embeds = ip_prompt_embeds.repeat(num_prompts, 1, 1)
                uncond_ip_prompt_embeds = uncond_ip_prompt_embeds.repeat(num_prompts, 1, 1)
            else:
                ip_prompt_embeds = uncond_ip_prompt_embeds = \
                    torch.zeros_like(prompt_embeds[:, :1]).repeat(1, self.num_ip_tokens, 1)
            prompt_embeds = torch.cat([prompt_embeds, ip_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_ip_prompt_embeds], dim=1)

            if cond_faceids and len(cond_faceids) > 0:
                all_faceid_prompt_embeds = []
                all_uncond_faceid_prompt_embeds = []
                for curr_faceid_info in cond_faceids:
                    refs = curr_faceid_info["refs"]
                    faceid_prompt_embeds, uncond_faceid_prompt_embeds = \
                        self.get_single_faceid_embeds(refs, face_structure_scale)
                    if "mix_refs" in curr_faceid_info:
                        mix_refs = curr_faceid_info["mix_refs"]
                        mix_scales = curr_faceid_info["mix_scales"]
                        master_face_mix_scale = 1.0 - sum(mix_scales)
                        faceid_prompt_embeds = faceid_prompt_embeds * master_face_mix_scale
                        for mix_ref, mix_scale in zip(mix_refs, mix_scales):
                            faceid_mix_prompt_embeds, _ = self.get_single_faceid_embeds(mix_ref, face_structure_scale)
                            faceid_prompt_embeds = faceid_prompt_embeds + faceid_mix_prompt_embeds * mix_scale
                    all_faceid_prompt_embeds.append(faceid_prompt_embeds)
                    all_uncond_faceid_prompt_embeds.append(uncond_faceid_prompt_embeds)

                faceid_prompt_embeds = torch.cat(all_faceid_prompt_embeds, dim=1)
                uncond_faceid_prompt_embeds = torch.cat(all_uncond_faceid_prompt_embeds, dim=1)
                faceid_prompt_embeds = faceid_prompt_embeds.repeat(num_prompts, 1, 1)
                uncond_faceid_prompt_embeds = uncond_faceid_prompt_embeds.repeat(num_prompts, 1, 1)
                prompt_embeds = torch.cat([prompt_embeds, faceid_prompt_embeds], dim=1)
                negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_faceid_prompt_embeds], dim=1)

            generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
            if zT is not None:
                h_, w_ = kwargs["image"][0].shape[-2:]
                latents = torch.randn(num_prompts, 4, h_ // 8, w_ // 8,
                                      device=self.device, generator=generator, dtype=self.pipe.unet.dtype)
                latents[0] = zT
            else:
                latents = None
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                latents=latents,
                **kwargs,
            ).images

        return images
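An end-to-end sketch of driving the pipeline for two identities. The checkpoint paths are illustrative placeholders (not pinned by this file), and configuring the global attn_args before calling generate is an assumption based on how the attention processors read their scales; exact values are illustrative.

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

sd_pipe = StableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", torch_dtype=torch.float16)
uniportrait = UniPortraitPipeline(
    sd_pipe,
    image_encoder_path="models/image_encoder",  # CLIP vision encoder directory
    ip_ckpt="models/ip-adapter_sd15.bin",
    face_backbone_ckpt="models/glint360k_curricular_face_r101_backbone.bin",
    uniportrait_faceid_ckpt="models/uniportrait-faceid_sd15.bin",
    uniportrait_router_ckpt="models/uniportrait-router_sd15.bin",
)

attn_args.reset()
attn_args.multi_id_lora_scale = 1.0  # multi-ID LoRA for two faces
attn_args.faceid_scale = 0.7
attn_args.num_faceids = 2

images = uniportrait.generate(
    prompt="a man and a woman walking on the beach",
    negative_prompt="blurry, low quality",
    cond_faceids=[{"refs": [Image.open("face1.jpg")]},
                  {"refs": [Image.open("face2.jpg")]}],
    face_structure_scale=0.3,
    seed=42, num_inference_steps=25,
    width=512, height=512,
)
images[0].save("output.png")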