Full Code of junjiehe96/UniPortrait for AI

Repository: junjiehe96/UniPortrait
Branch: main
Commit: a4deff2b48e3
Files: 16
Total size: 113.0 KB

Directory structure:
UniPortrait/

├── .gitignore
├── LICENSE.txt
├── README.md
├── gradio_app.py
├── requirements.txt
└── uniportrait/
    ├── __init__.py
    ├── curricular_face/
    │   ├── __init__.py
    │   ├── backbone/
    │   │   ├── __init__.py
    │   │   ├── common.py
    │   │   ├── model_irse.py
    │   │   └── model_resnet.py
    │   └── inference.py
    ├── inversion.py
    ├── resampler.py
    ├── uniportrait_attention_processor.py
    └── uniportrait_pipeline.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea/
.DS_Store
*.dat
*.mat

training/
lightning_logs/
image_log/

*.png
*.jpg
*.jpeg
*.webp

*.pth
*.pt
*.ckpt
*.safetensors

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: LICENSE.txt
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">
<h1>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</h1>

<a href='https://aigcdesigngroup.github.io/UniPortrait-Page/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/abs/2408.05939'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/spaces/Junjie96/UniPortrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>

</div>

<img src='assets/highlight.png'>

UniPortrait is an innovative human image personalization framework. It customizes single- and multi-ID images in a
unified manner, providing high-fidelity identity preservation, extensive facial editability, support for free-form text
descriptions, and no requirement for a predetermined layout.

---

## Release

- [2025/05/01] 🔥 We release the code and demo for the `FLUX.1-dev` version of [AnyStory](https://github.com/junjiehe96/AnyStory), a unified approach to general subject personalization.
- [2024/10/18] 🔥 We release the inference code and demo, which simply
  integrate [ControlNet](https://github.com/lllyasviel/ControlNet),
  [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter),
  and [StyleAligned](https://github.com/google/style-aligned). The weights for this version are consistent with the
  Hugging Face space and the experiments in the paper. We are now working on generalizing our method to more advanced
  diffusion models and more general custom concepts. Please stay tuned!
- [2024/08/12] 🔥 We release the [technical report](https://arxiv.org/abs/2408.05939)
  , [project page](https://aigcdesigngroup.github.io/UniPortrait-Page/),
  and [HuggingFace demo](https://huggingface.co/spaces/Junjie96/UniPortrait) 🤗!

## Quickstart

```shell
# Clone repository
git clone https://github.com/junjiehe96/UniPortrait.git

# install requirements
cd UniPortrait
pip install -r requirements.txt

# download the models
git lfs install
git clone https://huggingface.co/Junjie96/UniPortrait models
# download IP-Adapter models
# Note: we recommend downloading manually, since not all IP-Adapter models are required.
git clone https://huggingface.co/h94/IP-Adapter models/IP-Adapter

# then you can use the gradio app
python gradio_app.py
```
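
If you want to drive the pipeline from Python rather than through the Gradio UI, the following minimal sketch mirrors the setup in `gradio_app.py`. It is an illustration, not an official API example: model paths follow the Quickstart above, `face_crop_224.png` is a placeholder for a face crop already aligned to 224x224 (e.g. produced by `process_faceid_image` in `gradio_app.py`), and the prompt and seed are arbitrary.

```python
import torch
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline

from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

torch_dtype = torch.float16

# Base SD v1-5 pipeline; a pose ControlNet slot is kept (and disabled below), as in gradio_app.py.
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch_dtype)
scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False, steps_offset=1)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", controlnet=[controlnet],
    torch_dtype=torch_dtype, scheduler=scheduler, vae=vae)

uniportrait = UniPortraitPipeline(
    pipe, "models/IP-Adapter/models/image_encoder",
    ip_ckpt="models/IP-Adapter/models/ip-adapter_sd15.bin",
    face_backbone_ckpt="models/glint360k_curricular_face_r101_backbone.bin",
    uniportrait_faceid_ckpt="models/uniportrait-faceid_sd15.bin",
    uniportrait_router_ckpt="models/uniportrait-router_sd15.bin",
    device="cuda", torch_dtype=torch_dtype)

# One identity, conditioned on an aligned 224x224 face crop (placeholder path).
cond_faceids = [{"refs": [Image.open("face_crop_224.png")]}]

attn_args.reset()
attn_args.lora_scale = 1.0  # single-ID LoRA; use attn_args.multi_id_lora_scale for >1 identities
attn_args.faceid_scale = 0.7
attn_args.num_faceids = len(cond_faceids)

images = uniportrait.generate(
    prompt=["a man reading a book in a cafe"], negative_prompt=["nsfw"],
    cond_faceids=cond_faceids, face_structure_scale=0.1,
    seed=42, guidance_scale=7.5, num_inference_steps=25,
    image=[torch.zeros([1, 3, 512, 512])],  # empty pose map; ControlNet is disabled
    controlnet_conditioning_scale=[0.0])
images[0].save("output.png")
```

For two or more identities, pass one dict per identity in `cond_faceids` and enable the multi-ID LoRA (`attn_args.multi_id_lora_scale = 1.0` instead of `attn_args.lora_scale`), as in `text_to_multi_id_generation_process` of `gradio_app.py`.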

## Applications

<img src='assets/application.png'>

## **Acknowledgements**

This code is built upon some excellent repos, including [diffusers](https://github.com/huggingface/diffusers), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [StyleAligned](https://github.com/google/style-aligned). We highly appreciate their great work!

## Cite

If you find UniPortrait useful for your research and applications, please cite us using this BibTeX:

```bibtex
@article{he2024uniportrait,
    title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},
    author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
    journal={arXiv preprint arXiv:2408.05939},
    year={2024}
}
```

For any questions, please feel free to open an issue or contact us at hejunjie1103@gmail.com.


================================================
FILE: gradio_app.py
================================================
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline
from insightface.app import FaceAnalysis
from insightface.utils import face_align

from uniportrait import inversion
from uniportrait.uniportrait_attention_processor import attn_args
from uniportrait.uniportrait_pipeline import UniPortraitPipeline

port = 7860

device = "cuda"
torch_dtype = torch.float16

# base
base_model_path = "SG161222/Realistic_Vision_V5.1_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
controlnet_pose_ckpt = "lllyasviel/control_v11p_sd15_openpose"
# specific
image_encoder_path = "models/IP-Adapter/models/image_encoder"
ip_ckpt = "models/IP-Adapter/models/ip-adapter_sd15.bin"
face_backbone_ckpt = "models/glint360k_curricular_face_r101_backbone.bin"
uniportrait_faceid_ckpt = "models/uniportrait-faceid_sd15.bin"
uniportrait_router_ckpt = "models/uniportrait-router_sd15.bin"

# load controlnet
pose_controlnet = ControlNetModel.from_pretrained(controlnet_pose_ckpt, torch_dtype=torch_dtype)

# load SD pipeline
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=[pose_controlnet],
    torch_dtype=torch_dtype,
    scheduler=noise_scheduler,
    vae=vae,
    # feature_extractor=None,
    # safety_checker=None,
)

# load uniportrait pipeline
uniportrait_pipeline = UniPortraitPipeline(pipe, image_encoder_path, ip_ckpt=ip_ckpt,
                                           face_backbone_ckpt=face_backbone_ckpt,
                                           uniportrait_faceid_ckpt=uniportrait_faceid_ckpt,
                                           uniportrait_router_ckpt=uniportrait_router_ckpt,
                                           device=device, torch_dtype=torch_dtype)

# load face detection assets
face_app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=["detection"])
face_app.prepare(ctx_id=0, det_size=(640, 640))


def pad_np_bgr_image(np_image, scale=1.25):
    assert scale >= 1.0, "scale should be >= 1.0"
    pad_scale = scale - 1.0
    h, w = np_image.shape[:2]
    top = bottom = int(h * pad_scale)
    left = right = int(w * pad_scale)
    ret = cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))
    return ret, (left, top)


def process_faceid_image(pil_faceid_image):
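    # Detect faces in the input image (retrying once on a gray-padded copy if none are found),
    # keep the largest detected face, and return a 224x224 aligned RGB crop of it.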
    np_faceid_image = np.array(pil_faceid_image.convert("RGB"))
    img = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
    faces = face_app.get(img)  # bgr
    if len(faces) == 0:
        # padding, try again
        _h, _w = img.shape[:2]
        _img, left_top_coord = pad_np_bgr_image(img)
        faces = face_app.get(_img)
        if len(faces) == 0:
            gr.Info("Warning: No face detected in the image. Continue processing...")

        min_coord = np.array([0, 0])
        max_coord = np.array([_w, _h])
        sub_coord = np.array([left_top_coord[0], left_top_coord[1]])
        for face in faces:
            face.bbox = np.minimum(np.maximum(face.bbox.reshape(-1, 2) - sub_coord, min_coord), max_coord).reshape(4)
            face.kps = face.kps - sub_coord

    faces = sorted(faces, key=lambda x: abs((x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])), reverse=True)
    faceid_face = faces[0]
    norm_face = face_align.norm_crop(img, landmark=faceid_face.kps, image_size=224)
    pil_faceid_align_image = Image.fromarray(cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB))

    return pil_faceid_align_image


def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_supp_images=None,
                                      pil_faceid_mix_images=None, mix_scales=None):
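    # Build the condition dict for one identity: a list of aligned face crops ("refs"),
    # plus optional identity-mixing references ("mix_refs") and their scales ("mix_scales").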
    pil_faceid_align_images = []
    if pil_faceid_image:
        pil_faceid_align_images.append(process_faceid_image(pil_faceid_image))
    if pil_faceid_supp_images and len(pil_faceid_supp_images) > 0:
        for pil_faceid_supp_image in pil_faceid_supp_images:
            if isinstance(pil_faceid_supp_image, Image.Image):
                pil_faceid_align_images.append(process_faceid_image(pil_faceid_supp_image))
            else:
                pil_faceid_align_images.append(
                    process_faceid_image(Image.open(BytesIO(pil_faceid_supp_image)))
                )

    mix_refs = []
    mix_ref_scales = []
    if pil_faceid_mix_images:
        for pil_faceid_mix_image, mix_scale in zip(pil_faceid_mix_images, mix_scales):
            if pil_faceid_mix_image:
                mix_refs.append(process_faceid_image(pil_faceid_mix_image))
                mix_ref_scales.append(mix_scale)

    single_faceid_cond_kwargs = None
    if len(pil_faceid_align_images) > 0:
        single_faceid_cond_kwargs = {
            "refs": pil_faceid_align_images
        }
        if len(mix_refs) > 0:
            single_faceid_cond_kwargs["mix_refs"] = mix_refs
            single_faceid_cond_kwargs["mix_scales"] = mix_ref_scales

    return single_faceid_cond_kwargs


def text_to_single_id_generation_process(
        pil_faceid_image=None, pil_faceid_supp_images=None,
        pil_faceid_mix_image_1=None, mix_scale_1=0.0,
        pil_faceid_mix_image_2=None, mix_scale_2=0.0,
        faceid_scale=0.0, face_structure_scale=0.0,
        prompt="", negative_prompt="",
        num_samples=1, seed=-1,
        image_resolution="512x512",
        inference_steps=25,
):
    if seed == -1:
        seed = None

    single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,
                                                                  pil_faceid_supp_images,
                                                                  [pil_faceid_mix_image_1, pil_faceid_mix_image_2],
                                                                  [mix_scale_1, mix_scale_2])

    cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []

    # reset attn args
    attn_args.reset()
    # set faceid condition
    attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0  # single-faceid lora
    attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0  # multi-faceid lora
    attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
    attn_args.num_faceids = len(cond_faceids)
    print(attn_args)

    h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
    prompt = [prompt] * num_samples
    negative_prompt = [negative_prompt] * num_samples
    images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
                                           cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
                                           seed=seed, guidance_scale=7.5,
                                           num_inference_steps=inference_steps,
                                           image=[torch.zeros([1, 3, h, w])],
                                           controlnet_conditioning_scale=[0.0])
    final_out = []
    for pil_image in images:
        final_out.append(pil_image)

    for single_faceid_cond_kwargs in cond_faceids:
        final_out.extend(single_faceid_cond_kwargs["refs"])
        if "mix_refs" in single_faceid_cond_kwargs:
            final_out.extend(single_faceid_cond_kwargs["mix_refs"])

    return final_out


def text_to_multi_id_generation_process(
        pil_faceid_image_1=None, pil_faceid_supp_images_1=None,
        pil_faceid_mix_image_1_1=None, mix_scale_1_1=0.0,
        pil_faceid_mix_image_1_2=None, mix_scale_1_2=0.0,
        pil_faceid_image_2=None, pil_faceid_supp_images_2=None,
        pil_faceid_mix_image_2_1=None, mix_scale_2_1=0.0,
        pil_faceid_mix_image_2_2=None, mix_scale_2_2=0.0,
        faceid_scale=0.0, face_structure_scale=0.0,
        prompt="", negative_prompt="",
        num_samples=1, seed=-1,
        image_resolution="512x512",
        inference_steps=25,
):
    if seed == -1:
        seed = None

    faceid_cond_kwargs_1 = prepare_single_faceid_cond_kwargs(pil_faceid_image_1,
                                                             pil_faceid_supp_images_1,
                                                             [pil_faceid_mix_image_1_1,
                                                              pil_faceid_mix_image_1_2],
                                                             [mix_scale_1_1, mix_scale_1_2])
    faceid_cond_kwargs_2 = prepare_single_faceid_cond_kwargs(pil_faceid_image_2,
                                                             pil_faceid_supp_images_2,
                                                             [pil_faceid_mix_image_2_1,
                                                              pil_faceid_mix_image_2_2],
                                                             [mix_scale_2_1, mix_scale_2_2])
    cond_faceids = []
    if faceid_cond_kwargs_1 is not None:
        cond_faceids.append(faceid_cond_kwargs_1)
    if faceid_cond_kwargs_2 is not None:
        cond_faceids.append(faceid_cond_kwargs_2)

    # reset attn args
    attn_args.reset()
    # set faceid condition
    attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0  # single-faceid lora
    attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0  # multi-faceid lora
    attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
    attn_args.num_faceids = len(cond_faceids)
    print(attn_args)

    h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])
    prompt = [prompt] * num_samples
    negative_prompt = [negative_prompt] * num_samples
    images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
                                           cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
                                           seed=seed, guidance_scale=7.5,
                                           num_inference_steps=inference_steps,
                                           image=[torch.zeros([1, 3, h, w])],
                                           controlnet_conditioning_scale=[0.0])

    final_out = []
    for pil_image in images:
        final_out.append(pil_image)

    for single_faceid_cond_kwargs in cond_faceids:
        final_out.extend(single_faceid_cond_kwargs["refs"])
        if "mix_refs" in single_faceid_cond_kwargs:
            final_out.extend(single_faceid_cond_kwargs["mix_refs"])

    return final_out


def image_to_single_id_generation_process(
        pil_faceid_image=None, pil_faceid_supp_images=None,
        pil_faceid_mix_image_1=None, mix_scale_1=0.0,
        pil_faceid_mix_image_2=None, mix_scale_2=0.0,
        faceid_scale=0.0, face_structure_scale=0.0,
        pil_ip_image=None, ip_scale=1.0,
        num_samples=1, seed=-1, image_resolution="768x512",
        inference_steps=25,
):
    if seed == -1:
        seed = None

    single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,
                                                                  pil_faceid_supp_images,
                                                                  [pil_faceid_mix_image_1, pil_faceid_mix_image_2],
                                                                  [mix_scale_1, mix_scale_2])

    cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []

    h, w = int(image_resolution.split("x")[0]), int(image_resolution.split("x")[1])

    # Image Prompt and Style Aligned
    if pil_ip_image is None:
        gr.Error("Please upload a reference image")
    attn_args.reset()
    pil_ip_image = pil_ip_image.convert("RGB").resize((w, h))
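    # DDIM-invert the reference image to obtain its latent trajectory; zT is the starting noise
    # and the callback feeds the inverted latents back in during sampling (StyleAligned-style
    # reference guidance, see uniportrait/inversion.py).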
    zts = inversion.ddim_inversion(uniportrait_pipeline.pipe, np.array(pil_ip_image), "", inference_steps, 2)
    zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0)

    # reset attn args
    attn_args.reset()
    # set ip condition
    attn_args.ip_scale = ip_scale if pil_ip_image else 0.0
    # set faceid condition
    attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0  # lora for single faceid
    attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0  # lora for >1 faceids
    attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0
    attn_args.num_faceids = len(cond_faceids)
    # set shared self-attn
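    # (StyleAligned) generated samples also attend to the reference image's self-attention
    # keys/values; the negative shared_score_shift down-weights attention to the reference.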
    attn_args.enable_share_attn = True
    attn_args.shared_score_shift = -0.5
    print(attn_args)

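    # The first batch slot is reserved for the (reconstructed) reference image; its output is
    # dropped below via images[1:].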
    prompt = [""] * (1 + num_samples)
    negative_prompt = [""] * (1 + num_samples)
    images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,
                                           pil_ip_image=pil_ip_image,
                                           cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,
                                           seed=seed, guidance_scale=7.5,
                                           num_inference_steps=inference_steps,
                                           image=[torch.zeros([1, 3, h, w])],
                                           controlnet_conditioning_scale=[0.0],
                                           zT=zT, callback_on_step_end=inversion_callback)
    images = images[1:]

    final_out = []
    for pil_image in images:
        final_out.append(pil_image)

    for single_faceid_cond_kwargs in cond_faceids:
        final_out.extend(single_faceid_cond_kwargs["refs"])
        if "mix_refs" in single_faceid_cond_kwargs:
            final_out.extend(single_faceid_cond_kwargs["mix_refs"])

    return final_out


def text_to_single_id_generation_block():
    gr.Markdown("## Text-to-Single-ID Generation")
    gr.HTML(text_to_single_id_description)
    gr.HTML(text_to_single_id_tips)
    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            prompt = gr.Textbox(value="", label='Prompt', lines=2)
            negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt')

            run_button = gr.Button(value="Run")
            with gr.Accordion("Options", open=True):
                image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
                                               label="Image Resolution (HxW)")
                seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
                                 value=2147483647)
                num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
                inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)

                faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
                face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0,
                                                 step=0.01, value=0.1)

        with gr.Column(scale=2, min_width=100):
            with gr.Row(equal_height=False):
                pil_faceid_image = gr.Image(type="pil", label="ID Image")
                with gr.Accordion("ID Supplements", open=True):
                    with gr.Row():
                        pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"],
                                                         type="binary", label="Additional ID Images")
                    with gr.Row():
                        with gr.Column(scale=1, min_width=100):
                            pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1")
                            mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
                        with gr.Column(scale=1, min_width=100):
                            pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2")
                            mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0)

            with gr.Row():
                example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True,
                                            format="png")
    with gr.Row():
        examples = [
            [
                "A young man with short black hair, wearing a black hoodie with a hood, was paired with a blue denim jacket with yellow details.",
                "assets/examples/1-newton.jpg",
                "assets/examples/1-output-1.png",
            ],
        ]
        gr.Examples(
            label="Examples",
            examples=examples,
            fn=lambda x, y, z: (x, y),
            inputs=[prompt, pil_faceid_image, example_output],
            outputs=[prompt, pil_faceid_image]
        )
    ips = [
        pil_faceid_image, pil_faceid_supp_images,
        pil_faceid_mix_image_1, mix_scale_1,
        pil_faceid_mix_image_2, mix_scale_2,
        faceid_scale, face_structure_scale,
        prompt, negative_prompt,
        num_samples, seed,
        image_resolution,
        inference_steps,
    ]
    run_button.click(fn=text_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])


def text_to_multi_id_generation_block():
    gr.Markdown("## Text-to-Multi-ID Generation")
    gr.HTML(text_to_multi_id_description)
    gr.HTML(text_to_multi_id_tips)
    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            prompt = gr.Textbox(value="", label='Prompt', lines=2)
            negative_prompt = gr.Textbox(value="nsfw", label='Negative Prompt')
            run_button = gr.Button(value="Run")
            with gr.Accordion("Options", open=True):
                image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
                                               label="Image Resolution (HxW)")
                seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
                                 value=2147483647)
                num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
                inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)

                faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
                face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0,
                                                 step=0.01, value=0.3)

        with gr.Column(scale=2, min_width=100):
            with gr.Row(equal_height=False):
                with gr.Column(scale=1, min_width=100):
                    pil_faceid_image_1 = gr.Image(type="pil", label="First ID")
                    with gr.Accordion("First ID Supplements", open=False):
                        with gr.Row():
                            pil_faceid_supp_images_1 = gr.File(file_count="multiple", file_types=["image"],
                                                               type="binary", label="Additional ID Images")
                        with gr.Row():
                            with gr.Column(scale=1, min_width=100):
                                pil_faceid_mix_image_1_1 = gr.Image(type="pil", label="Mix ID 1")
                                mix_scale_1_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01,
                                                          value=0.0)
                            with gr.Column(scale=1, min_width=100):
                                pil_faceid_mix_image_1_2 = gr.Image(type="pil", label="Mix ID 2")
                                mix_scale_1_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01,
                                                          value=0.0)
                with gr.Column(scale=1, min_width=100):
                    pil_faceid_image_2 = gr.Image(type="pil", label="Second ID")
                    with gr.Accordion("Second ID Supplements", open=False):
                        with gr.Row():
                            pil_faceid_supp_images_2 = gr.File(file_count="multiple", file_types=["image"],
                                                               type="binary", label="Additional ID Images")
                        with gr.Row():
                            with gr.Column(scale=1, min_width=100):
                                pil_faceid_mix_image_2_1 = gr.Image(type="pil", label="Mix ID 1")
                                mix_scale_2_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01,
                                                          value=0.0)
                            with gr.Column(scale=1, min_width=100):
                                pil_faceid_mix_image_2_2 = gr.Image(type="pil", label="Mix ID 2")
                                mix_scale_2_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01,
                                                          value=0.0)

            with gr.Row():
                example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4, preview=True,
                                            format="png")
    with gr.Row():
        examples = [
            [
                "The two female models, fair-skinned, wore a white V-neck short-sleeved top with a light smile on the corners of their mouths. The background was off-white.",
                "assets/examples/2-stylegan2-ffhq-0100.png",
                "assets/examples/2-stylegan2-ffhq-0293.png",
                "assets/examples/2-output-1.png",
            ],
        ]
        gr.Examples(
            label="Examples",
            examples=examples,
            inputs=[prompt, pil_faceid_image_1, pil_faceid_image_2, example_output],
        )
    ips = [
        pil_faceid_image_1, pil_faceid_supp_images_1,
        pil_faceid_mix_image_1_1, mix_scale_1_1,
        pil_faceid_mix_image_1_2, mix_scale_1_2,
        pil_faceid_image_2, pil_faceid_supp_images_2,
        pil_faceid_mix_image_2_1, mix_scale_2_1,
        pil_faceid_mix_image_2_2, mix_scale_2_2,
        faceid_scale, face_structure_scale,
        prompt, negative_prompt,
        num_samples, seed,
        image_resolution,
        inference_steps,
    ]
    run_button.click(fn=text_to_multi_id_generation_process, inputs=ips, outputs=[result_gallery])


def image_to_single_id_generation_block():
    gr.Markdown("## Image-to-Single-ID Generation")
    gr.HTML(image_to_single_id_description)
    gr.HTML(image_to_single_id_tips)
    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            run_button = gr.Button(value="Run")
            seed = gr.Slider(label="Seed (-1 indicates random)", minimum=-1, maximum=2147483647, step=1,
                             value=2147483647)
            num_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            image_resolution = gr.Dropdown(choices=["768x512", "512x512", "512x768"], value="512x512",
                                           label="Image Resolution (HxW)")
            inference_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, visible=False)

            ip_scale = gr.Slider(label="Reference Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
            faceid_scale = gr.Slider(label="Face ID Scale", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
            face_structure_scale = gr.Slider(label="Face Structure Scale", minimum=0.0, maximum=1.0, step=0.01,
                                             value=0.3)

        with gr.Column(scale=3, min_width=100):
            with gr.Row(equal_height=False):
                pil_ip_image = gr.Image(type="pil", label="Portrait Reference")
                pil_faceid_image = gr.Image(type="pil", label="ID Image")
                with gr.Accordion("ID Supplements", open=True):
                    with gr.Row():
                        pil_faceid_supp_images = gr.File(file_count="multiple", file_types=["image"],
                                                         type="binary", label="Additional ID Images")
                    with gr.Row():
                        with gr.Column(scale=1, min_width=100):
                            pil_faceid_mix_image_1 = gr.Image(type="pil", label="Mix ID 1")
                            mix_scale_1 = gr.Slider(label="Mix Scale 1", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
                        with gr.Column(scale=1, min_width=100):
                            pil_faceid_mix_image_2 = gr.Image(type="pil", label="Mix ID 2")
                            mix_scale_2 = gr.Slider(label="Mix Scale 2", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
            with gr.Row():
                with gr.Column(scale=3, min_width=100):
                    example_output = gr.Image(type="pil", label="(Example Output)", visible=False)
                    result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=4,
                                                preview=True, format="png")
    with gr.Row():
        examples = [
            [
                "assets/examples/3-style-1.png",
                "assets/examples/3-stylegan2-ffhq-0293.png",
                0.7,
                0.3,
                "assets/examples/3-output-1.png",
            ],
            [
                "assets/examples/3-style-1.png",
                "assets/examples/3-stylegan2-ffhq-0293.png",
                0.6,
                0.0,
                "assets/examples/3-output-2.png",
            ],
            [
                "assets/examples/3-style-2.jpg",
                "assets/examples/3-stylegan2-ffhq-0381.png",
                0.7,
                0.3,
                "assets/examples/3-output-3.png",
            ],
            [
                "assets/examples/3-style-3.jpg",
                "assets/examples/3-stylegan2-ffhq-0381.png",
                0.6,
                0.0,
                "assets/examples/3-output-4.png",
            ],
        ]
        gr.Examples(
            label="Examples",
            examples=examples,
            fn=lambda x, y, z, w, v: (x, y, z, w),
            inputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale, example_output],
            outputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale]
        )
    ips = [
        pil_faceid_image, pil_faceid_supp_images,
        pil_faceid_mix_image_1, mix_scale_1,
        pil_faceid_mix_image_2, mix_scale_2,
        faceid_scale, face_structure_scale,
        pil_ip_image, ip_scale,
        num_samples, seed, image_resolution,
        inference_steps,
    ]
    run_button.click(fn=image_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])


if __name__ == "__main__":
    os.environ["no_proxy"] = "localhost,127.0.0.1,::1"

    title = r"""
            <div style="text-align: center;">
                <h1> UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization </h1>
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                    <a href="https://arxiv.org/pdf/2408.05939"><img src="https://img.shields.io/badge/arXiv-2408.05939-red"></a>
                    &nbsp;
                    <a href='https://aigcdesigngroup.github.io/UniPortrait-Page/'><img src='https://img.shields.io/badge/Project_Page-UniPortrait-green' alt='Project Page'></a>
                    &nbsp;
                    <a href="https://github.com/junjiehe96/UniPortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
                </div>
                </br>
            </div>
        """

    title_description = r"""
        This is the <b>official 🤗 Gradio demo</b> for <a href='https://arxiv.org/pdf/2408.05939' target='_blank'><b>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</b></a>.<br>
        The demo provides three capabilities: text-to-single-ID personalization, text-to-multi-ID personalization, and image-to-single-ID personalization. All of these are based on the <b>Stable Diffusion v1-5</b> model. Feel free to give them a try! 😊
        """

    text_to_single_id_description = r"""🚀🚀🚀Quick start:<br>
        1. Enter a text prompt (Chinese or English), upload an image with a face, and click the <b>Run</b> button. 🤗<br>
        """

    text_to_single_id_tips = r"""💡💡💡Tips:<br>
        1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512.)<br>
        2. Uploading multiple reference photos of the same face improves prompt and ID consistency. Additional references can be uploaded under "ID Supplements".<br>
        3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.<br>
        """

    text_to_multi_id_description = r"""🚀🚀🚀Quick start:<br>
        1. Enter a text prompt (Chinese or English), upload a face image in each of the "First ID" and "Second ID" blocks, and click the <b>Run</b> button. 🤗<br>
        """

    text_to_multi_id_tips = r"""💡💡💡Tips:<br>
        1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512.)<br>
        2. Uploading multiple reference photos of the same face improves prompt and ID consistency. Additional references can be uploaded under "ID Supplements".<br>
        3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing ID fidelity and text alignment. We recommend a "Face ID Scale" of 0.3~0.7 and a "Face Structure Scale" of 0.0~0.4.<br>
        """

    image_to_single_id_description = r"""🚀🚀🚀Quick start: Upload an image as the portrait reference (any style), upload a face image, and click the <b>Run</b> button. 🤗<br>"""

    image_to_single_id_tips = r"""💡💡💡Tips:<br>
        1. Try to avoid generating faces that are too small, as this may lead to artifacts. (Currently, the short side of the generated image is limited to 512.)<br>
        2. Uploading multiple reference photos of the same face improves ID consistency. Additional references can be uploaded under "ID Supplements".<br>
        3. Appropriate values of "Face ID Scale" and "Face Structure Scale" are important for balancing the portrait reference and ID alignment. We recommend a "Face ID Scale" of 0.5~0.7 and a "Face Structure Scale" of 0.0~0.4.<br>
        """

    citation = r"""
        ---
        📝 **Citation**
        <br>
        If our work is helpful for your research or applications, please cite us via:
        ```bibtex
        @article{he2024uniportrait,
          title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},
          author={He, Junjie and Geng, Yifeng and Bo, Liefeng},
          journal={arXiv preprint arXiv:2408.05939},
          year={2024}
        }
        ```
        📧 **Contact**
        <br>
        If you have any questions, please feel free to open an issue or reach out to us directly at <b>hejunjie1103@gmail.com</b>.
        """

    block = gr.Blocks(title="UniPortrait").queue()
    with block:
        gr.HTML(title)
        gr.HTML(title_description)

        with gr.TabItem("Text-to-Single-ID"):
            text_to_single_id_generation_block()

        with gr.TabItem("Text-to-Multi-ID"):
            text_to_multi_id_generation_block()

        with gr.TabItem("Image-to-Single-ID (Stylization)"):
            image_to_single_id_generation_block()

        gr.Markdown(citation)

    block.launch(server_name='0.0.0.0', share=False, server_port=port, allowed_paths=["/"])


================================================
FILE: requirements.txt
================================================
diffusers
gradio
onnxruntime-gpu
insightface
torch
tqdm
transformers


================================================
FILE: uniportrait/__init__.py
================================================


================================================
FILE: uniportrait/curricular_face/__init__.py
================================================


================================================
FILE: uniportrait/curricular_face/backbone/__init__.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone
from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50,
                         IR_SE_101, IR_SE_152, IR_SE_200)
from .model_resnet import ResNet_50, ResNet_101, ResNet_152

_model_dict = {
    'ResNet_50': ResNet_50,
    'ResNet_101': ResNet_101,
    'ResNet_152': ResNet_152,
    'IR_18': IR_18,
    'IR_34': IR_34,
    'IR_50': IR_50,
    'IR_101': IR_101,
    'IR_152': IR_152,
    'IR_200': IR_200,
    'IR_SE_50': IR_SE_50,
    'IR_SE_101': IR_SE_101,
    'IR_SE_152': IR_SE_152,
    'IR_SE_200': IR_SE_200
}


def get_model(key):
    """ Get different backbone network by key,
        support ResNet50, ResNet_101, ResNet_152
        IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,
        IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.
    """
    if key in _model_dict.keys():
        return _model_dict[key]
    else:
        raise KeyError('unsupported model: {}'.format(key))


================================================
FILE: uniportrait/curricular_face/backbone/common.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py
import torch.nn as nn
from torch.nn import (Conv2d, Module, ReLU,
                      Sigmoid)


def initialize_weights(modules):
    """ Weight initilize, conv2d and linear is initialized with kaiming_normal
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()


class Flatten(Module):
    """ Flat tensor
    """

    def forward(self, input):
        return input.view(input.size(0), -1)


class SEModule(Module):
    """ SE block
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels,
            channels // reduction,
            kernel_size=1,
            padding=0,
            bias=False)

        nn.init.xavier_uniform_(self.fc1.weight.data)

        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction,
            channels,
            kernel_size=1,
            padding=0,
            bias=False)

        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)

        return module_input * x


================================================
FILE: uniportrait/curricular_face/backbone/model_irse.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
from collections import namedtuple

from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
                      MaxPool2d, Module, PReLU, Sequential)

from .common import Flatten, SEModule, initialize_weights


class BasicBlockIR(Module):
    """ BasicBlock for IRNet
    """

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut


class BottleneckIR(Module):
    """ BasicBlock with bottleneck for IRNet
    """

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        reduction_channel = depth // 4
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(
                in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(reduction_channel), PReLU(reduction_channel),
            Conv2d(
                reduction_channel,
                reduction_channel, (3, 3), (1, 1),
                1,
                bias=False), BatchNorm2d(reduction_channel),
            PReLU(reduction_channel),
            Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut


class BasicBlockIRSE(BasicBlockIR):

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
        self.res_layer.add_module('se_block', SEModule(depth, 16))


class BottleneckIRSE(BottleneckIR):

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
        self.res_layer.add_module('se_block', SEModule(depth, 16))


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + \
           [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 18:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=2),
            get_block(in_channel=64, depth=128, num_units=2),
            get_block(in_channel=128, depth=256, num_units=2),
            get_block(in_channel=256, depth=512, num_units=2)
        ]
    elif num_layers == 34:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=6),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=8),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    elif num_layers == 200:
        blocks = [
            get_block(in_channel=64, depth=256, num_units=3),
            get_block(in_channel=256, depth=512, num_units=24),
            get_block(in_channel=512, depth=1024, num_units=36),
            get_block(in_channel=1024, depth=2048, num_units=3)
        ]
    else:
        raise ValueError(
            f"num_layers should be one of 18, 34, 50, 100, 152 or 200, got {num_layers}")

    return blocks
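

# Worked note (added for illustration): get_blocks(100) returns four stages with
# [3, 13, 30, 3] residual-unit specs, i.e. 3 + 13 + 30 + 3 = 49 units in total;
# the cumulative stage-end indices [2, 15, 45, 48] are the mid_layer_indices that
# Backbone records below for IR-101.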


class Backbone(Module):

    def __init__(self, input_size, num_layers, mode='ir'):
        """ Args:
            input_size: input_size of backbone
            num_layers: num_layers of backbone
            mode: support ir or irse
        """
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], \
            'input_size should be [112, 112] or [224, 224]'
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            'num_layers should be 18, 34, 50, 100, 152 or 200'
        assert mode in ['ir', 'ir_se'], \
            'mode should be ir or ir_se'
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            if mode == 'ir':
                unit_module = BasicBlockIR
            elif mode == 'ir_se':
                unit_module = BasicBlockIRSE
            output_channel = 512
        else:
            if mode == 'ir':
                unit_module = BottleneckIR
            elif mode == 'ir_se':
                unit_module = BottleneckIRSE
            output_channel = 2048

        if input_size[0] == 112:
            self.output_layer = Sequential(
                BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
                Linear(output_channel * 7 * 7, 512),
                BatchNorm1d(512, affine=False))
        else:
            self.output_layer = Sequential(
                BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
                Linear(output_channel * 14 * 14, 512),
                BatchNorm1d(512, affine=False))

        modules = []
        mid_layer_indices = []  # [2, 15, 45, 48], total 49 layers for IR101
        for block in blocks:
            if len(mid_layer_indices) == 0:
                mid_layer_indices.append(len(block) - 1)
            else:
                mid_layer_indices.append(len(block) + mid_layer_indices[-1])
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)
        self.mid_layer_indices = mid_layer_indices[-4:]

        initialize_weights(self.modules())

    def forward(self, x, return_mid_feats=False):
        x = self.input_layer(x)
        if not return_mid_feats:
            x = self.body(x)
            x = self.output_layer(x)
            return x
        else:
            out_feats = []
            for idx, module in enumerate(self.body):
                x = module(x)
                if idx in self.mid_layer_indices:
                    out_feats.append(x)
            x = self.output_layer(x)
            return x, out_feats


def IR_18(input_size):
    """ Constructs a ir-18 model.
    """
    model = Backbone(input_size, 18, 'ir')

    return model


def IR_34(input_size):
    """ Constructs a ir-34 model.
    """
    model = Backbone(input_size, 34, 'ir')

    return model


def IR_50(input_size):
    """ Constructs a ir-50 model.
    """
    model = Backbone(input_size, 50, 'ir')

    return model


def IR_101(input_size):
    """ Constructs a ir-101 model.
    """
    model = Backbone(input_size, 100, 'ir')

    return model


def IR_152(input_size):
    """ Constructs a ir-152 model.
    """
    model = Backbone(input_size, 152, 'ir')

    return model


def IR_200(input_size):
    """ Constructs a ir-200 model.
    """
    model = Backbone(input_size, 200, 'ir')

    return model


def IR_SE_50(input_size):
    """ Constructs a ir_se-50 model.
    """
    model = Backbone(input_size, 50, 'ir_se')

    return model


def IR_SE_101(input_size):
    """ Constructs a ir_se-101 model.
    """
    model = Backbone(input_size, 100, 'ir_se')

    return model


def IR_SE_152(input_size):
    """ Constructs a ir_se-152 model.
    """
    model = Backbone(input_size, 152, 'ir_se')

    return model


def IR_SE_200(input_size):
    """ Constructs a ir_se-200 model.
    """
    model = Backbone(input_size, 200, 'ir_se')

    return model
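

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# UniPortraitPipeline builds this backbone as get_model("IR_101")([112, 112]) and
# calls it with return_mid_feats=True. With a 112x112 input, the four intermediate
# feature maps come out at strides 2/4/8/16; the shapes below are assumptions
# based on that configuration (random weights, illustration only).
if __name__ == "__main__":
    import torch

    model = IR_101([112, 112]).eval()
    with torch.no_grad():
        emb, mid_feats = model(torch.randn(1, 3, 112, 112), return_mid_feats=True)
    print(emb.shape)  # torch.Size([1, 512]) -- the identity embedding
    print([tuple(f.shape) for f in mid_feats])
    # [(1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)]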


================================================
FILE: uniportrait/curricular_face/backbone/model_resnet.py
================================================
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py
import torch.nn as nn
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
                      MaxPool2d, Module, ReLU, Sequential)

from .common import initialize_weights


def conv3x3(in_planes, out_planes, stride=1):
    """ 3x3 convolution with padding
    """
    return Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """ 1x1 convolution
    """
    return Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = BatchNorm2d(planes * self.expansion)
        self.relu = ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(Module):
    """ ResNet backbone
    """

    def __init__(self, input_size, block, layers, zero_init_residual=True):
        """ Args:
            input_size: input_size of backbone
            block: block function
            layers: layers in each block
        """
        super(ResNet, self).__init__()
        assert input_size[0] in [112, 224], \
            'input_size should be [112, 112] or [224, 224]'
        self.inplanes = 64
        self.conv1 = Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.bn_o1 = BatchNorm2d(2048)
        self.dropout = Dropout()
        if input_size[0] == 112:
            self.fc = Linear(2048 * 4 * 4, 512)
        else:
            self.fc = Linear(2048 * 7 * 7, 512)
        self.bn_o2 = BatchNorm1d(512)

        initialize_weights(self.modules())
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn_o1(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn_o2(x)

        return x


def ResNet_50(input_size, **kwargs):
    """ Constructs a ResNet-50 model.
    """
    model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)

    return model


def ResNet_101(input_size, **kwargs):
    """ Constructs a ResNet-101 model.
    """
    model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)

    return model


def ResNet_152(input_size, **kwargs):
    """ Constructs a ResNet-152 model.
    """
    model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)

    return model


================================================
FILE: uniportrait/curricular_face/inference.py
================================================
import glob
import os

import cv2
import numpy as np
import torch
from tqdm.auto import tqdm

from .backbone import get_model


@torch.no_grad()
def inference(name, weight, src_norm_dir):
    face_model = get_model(name)([112, 112])
    face_model.load_state_dict(torch.load(weight, map_location="cpu"))
    face_model = face_model.to("cpu")
    face_model.eval()

    id2src_norm = {}
    for src_id in sorted(list(os.listdir(src_norm_dir))):
        id2src_norm[src_id] = sorted(list(glob.glob(f"{os.path.join(src_norm_dir, src_id)}/*")))

    total_sims = []
    for id_name in tqdm(id2src_norm):
        src_face_embeddings = []
        for src_img_path in id2src_norm[id_name]:
            src_img = cv2.imread(src_img_path)
            src_img = cv2.resize(src_img, (112, 112))
            src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
            src_img = np.transpose(src_img, (2, 0, 1))
            src_img = torch.from_numpy(src_img).unsqueeze(0).float()
            src_img.div_(255).sub_(0.5).div_(0.5)
            embedding = face_model(src_img).detach().cpu().numpy()[0]
            embedding = embedding / np.linalg.norm(embedding)
            src_face_embeddings.append(embedding)  # 512

        num = len(src_face_embeddings)
        src_face_embeddings = np.stack(src_face_embeddings)  # n, 512
        sim = src_face_embeddings @ src_face_embeddings.T  # n, n
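        # the diagonal of `sim` is all 1.0 (self-similarity of L2-normalized
        # embeddings), so subtracting `num` and dividing by num * (num - 1) gives
        # the mean cosine similarity over the off-diagonal (cross-image) pairs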
        mean_sim = (np.sum(sim) - num * 1.0) / ((num - 1) * num)
        print(f"{id_name}: {mean_sim}")
        total_sims.append(mean_sim)

    return np.mean(total_sims)


if __name__ == "__main__":
    name = 'IR_101'
    weight = "models/glint360k_curricular_face_r101_backbone.bin"
    src_norm_dir = "/disk1/hejunjie.hjj/data/normface-AFD-id-20"
    mean_sim = inference(name, weight, src_norm_dir)
    print(f"total: {mean_sim:.4f}")  # total: 0.6299


================================================
FILE: uniportrait/inversion.py
================================================
# modified from https://github.com/google/style-aligned/blob/main/inversion.py

from __future__ import annotations

from typing import Callable

import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

T = torch.Tensor
InversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]]


def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> T:
    device = model._execution_device
    prompt_embeds = model._encode_prompt(
        prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True,
        negative_prompt="")
    return prompt_embeds


def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:
    model.vae.to(dtype=torch.float32)
    image = torch.from_numpy(image).float() / 255.
    image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
    latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
    model.vae.to(dtype=torch.float16)
    return latent


def _next_step(model: StableDiffusionPipeline, model_output: T, timestep: int, sample: T) -> T:
    timestep, next_timestep = min(
        timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = model.scheduler.alphas_cumprod[
        int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
    alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, context: T, guidance_scale: float):
    latents_input = torch.cat([latent] * 2)
    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # latents = next_step(model, noise_pred, t, latent)
    return noise_pred


def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scale) -> T:
    all_latent = [z0]
    text_embedding = _encode_text_with_negative(model, prompt)
    image_embedding = torch.zeros_like(text_embedding[:, :1]).repeat(1, 4, 1)  # for ip embedding
    text_embedding = torch.cat([text_embedding, image_embedding], dim=1)
    latent = z0.clone().detach().half()
    for i in tqdm(range(model.scheduler.num_inference_steps)):
        t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
        noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale)
        latent = _next_step(model, noise_pred, t, latent)
        all_latent.append(latent)
    return torch.cat(all_latent).flip(0)


def make_inversion_callback(zts, offset: int = 0) -> tuple[T, InversionCallback]:
    def callback_on_step_end(pipeline: StableDiffusionPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[
        str, T]:
        latents = callback_kwargs['latents']
        latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
        return {'latents': latents}

    return zts[offset], callback_on_step_end


@torch.no_grad()
def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int,
                   guidance_scale, ) -> T:
    z0 = _encode_image(model, x0)
    model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
    zs = _ddim_loop(model, z0, prompt, guidance_scale)
    return zs
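

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Rough outline of how these helpers are typically combined: invert a reference
# image into its DDIM latent trajectory, then replay that trajectory for the first
# batch element via the callback while the rest of the batch is generated. Note
# that _ddim_loop appends four zero image (IP) tokens to the text embedding, so
# the UNet is expected to carry the UniPortrait attention processors. The model
# id, image source, step count and offset below are assumptions only.
#
#     import numpy as np
#     from diffusers import DDIMScheduler, StableDiffusionPipeline
#
#     pipe = StableDiffusionPipeline.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
#     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
#     x0 = ...  # HxWx3 uint8 numpy array, e.g. a 512x512 reference photo
#     zts = ddim_inversion(pipe, x0, "a portrait photo", num_inference_steps=50,
#                          guidance_scale=2.0)
#     zT, inversion_callback = make_inversion_callback(zts, offset=5)
#     # use `zT` as the first latent of the batch and pass `inversion_callback`
#     # as callback_on_step_end to the pipeline call.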


================================================
FILE: uniportrait/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # force a contiguous tensor of shape (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, attention_mask=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
            attention_mask (torch.Tensor): attention mask
                shape (b, n1, 1)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        if attention_mask is not None:
            attention_mask = attention_mask.transpose(1, 2)  # (b, 1, n1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :, :1]).repeat(1, 1, l)],
                                       dim=2)  # b, 1, n1+n2
            attention_mask = (attention_mask - 1.) * 100.  # 0 means kept and -100 means dropped
            attention_mask = attention_mask.unsqueeze(1)
            weight = weight + attention_mask  # b, h, n2, n1+n2

        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class UniPortraitFaceIDResampler(torch.nn.Module):
    def __init__(
            self,
            intrinsic_id_embedding_dim=512,
            structure_embedding_dim=64 + 128 + 256 + 1280,
            num_tokens=16,
            depth=6,
            dim=768,
            dim_head=64,
            heads=12,
            ff_mult=4,
            output_dim=768,
    ):
        super().__init__()

        self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5)

        self.proj_id = torch.nn.Sequential(
            torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim),
        )
        self.proj_clip = torch.nn.Sequential(
            torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(structure_embedding_dim * 2, dim),
        )

        self.layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)

    def forward(
            self,
            intrinsic_id_embeds,
            structure_embeds,
            structure_scale=1.0,
            intrinsic_id_attention_mask=None,
            structure_attention_mask=None
    ):

        latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1)

        intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds)
        structure_embeds = self.proj_clip(structure_embeds)

        for attn1, attn2, ff in self.layers:
            latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents
            latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
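

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# The resampler cross-attends a fixed set of learnable latent tokens over the
# intrinsic ID features (49 tokens of dim 512 from the face recognition backbone)
# and the structure features (256 tokens of dim 64+128+256+1280), producing
# `num_tokens` face-ID prompt tokens. The shapes mirror the defaults above;
# weights are random here, for illustration only.
if __name__ == "__main__":
    resampler = UniPortraitFaceIDResampler()
    id_embeds = torch.randn(2, 49, 512)  # b, 49, 512
    structure_embeds = torch.randn(2, 256, 64 + 128 + 256 + 1280)  # b, 256, 1728
    tokens = resampler(id_embeds, structure_embeds, structure_scale=0.3)
    print(tokens.shape)  # torch.Size([2, 16, 768])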


================================================
FILE: uniportrait/uniportrait_attention_processor.py
================================================
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.lora import LoRALinearLayer


class AttentionArgs(object):
    def __init__(self) -> None:
        # all fields are defined (and documented) in reset(); start from the
        # disabled / default state
        self.reset()

    def reset(self):
        # ip condition
        self.ip_scale = 0.0
        self.ip_mask = None  # ip attention mask

        # faceid condition
        self.lora_scale = 0.0  # lora for single faceid
        self.multi_id_lora_scale = 0.0  # lora for multiple faceids
        self.faceid_scale = 0.0
        self.num_faceids = 0
        self.faceid_mask = None  # faceid attention mask; if not None, it will override the routing map

        # style aligned
        self.enable_share_attn: bool = False
        self.adain_queries_and_keys: bool = False
        self.shared_score_scale: float = 1.0
        self.shared_score_shift: float = 0.0

    def __repr__(self):
        indent_str = '    '
        s = f",\n{indent_str}".join(f"{attr}={value}" for attr, value in vars(self).items())
        return self.__class__.__name__ + '(' + f'\n{indent_str}' + s + ')'


attn_args = AttentionArgs()
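

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# `attn_args` is a module-level singleton read by every UniPortrait attention
# processor at call time: scales and masks are set on it before invoking the
# diffusion pipeline and cleared afterwards. The values below are illustrative
# assumptions, not recommended settings.
#
#     attn_args.reset()
#     attn_args.lora_scale = 1.0    # enable the single-ID LoRA branch
#     attn_args.faceid_scale = 0.7  # strength of the face-ID tokens
#     attn_args.num_faceids = 1
#     # ... run the StableDiffusionPipeline here ...
#     attn_args.reset()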


def expand_first(feat, scale=1., ):
    b = feat.shape[0]
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
    else:
        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
    return feat_style.reshape(*feat.shape)


def concat_first(feat, dim=2, scale=1.):
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)


def calc_mean_std(feat, eps: float = 1e-5):
    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdims=True)
    return feat_mean, feat_std


def adain(feat):
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std
    feat = feat * feat_style_std + feat_style_mean
    return feat
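

# Note (added for illustration): expand_first/concat_first/adain implement the
# style-aligned shared-attention trick referenced by the "style aligned" flags in
# AttentionArgs: within each classifier-free-guidance half of the batch, the first
# sample acts as the style reference, and its keys/values (or its AdaIN mean/std)
# are shared with the remaining samples of that half.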


class UniPortraitLoRAAttnProcessor2_0(nn.Module):

    def __init__(
            self,
            hidden_size=None,
            cross_attention_dim=None,
            rank=128,
            network_alpha=None,
    ):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if attn_args.lora_scale > 0.0:
            query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
            key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
            value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
            key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
            value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn_args.enable_share_attn:
            if attn_args.adain_queries_and_keys:
                query = adain(query)
                key = adain(key)
            key = concat_first(key, -2, scale=attn_args.shared_score_scale)
            value = concat_first(value, -2)
            if attn_args.shared_score_shift != 0:
                attention_mask = torch.zeros_like(key[:, :, :, :1]).transpose(-1, -2)  # b, h, 1, k
                attention_mask[:, :, :, query.shape[2]:] += attn_args.shared_score_shift
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
                )
            else:
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
                )
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        output_hidden_states = attn.to_out[0](hidden_states)
        if attn_args.lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
                hidden_states)
        hidden_states = output_hidden_states

        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class UniPortraitLoRAIPAttnProcessor2_0(nn.Module):

    def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None,
                 num_ip_tokens=4, num_faceid_tokens=16):
        super().__init__()

        self.num_ip_tokens = num_ip_tokens
        self.num_faceid_tokens = num_faceid_tokens

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.to_k_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.to_q_router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.GELU(),
            nn.Linear(hidden_size * 2, hidden_size, bias=False),
        )
        self.to_k_router = nn.Sequential(
            nn.Linear(cross_attention_dim or hidden_size, (cross_attention_dim or hidden_size) * 2),
            nn.GELU(),
            nn.Linear((cross_attention_dim or hidden_size) * 2, hidden_size, bias=False),
        )
        self.aggr_router = nn.Linear(num_faceid_tokens, 1)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split hidden states
            faceid_end = encoder_hidden_states.shape[1]
            ip_end = faceid_end - self.num_faceid_tokens * attn_args.num_faceids
            text_end = ip_end - self.num_ip_tokens

            prompt_hidden_states = encoder_hidden_states[:, :text_end]
            ip_hidden_states = encoder_hidden_states[:, text_end: ip_end]
            faceid_hidden_states = encoder_hidden_states[:, ip_end: faceid_end]

            encoder_hidden_states = prompt_hidden_states
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # for router
        if attn_args.num_faceids > 1:
            router_query = self.to_q_router(hidden_states)  # bs, s*s, dim
            router_hidden_states = faceid_hidden_states.reshape(batch_size, attn_args.num_faceids,
                                                                self.num_faceid_tokens, -1)  # bs, num, id_tokens, d
            router_hidden_states = self.aggr_router(router_hidden_states.transpose(-1, -2)).squeeze(-1)  # bs, num, d
            router_key = self.to_k_router(router_hidden_states)  # bs, num, dim
            router_logits = torch.bmm(router_query, router_key.transpose(-1, -2))  # bs, s*s, num
            index = router_logits.max(dim=-1, keepdim=True)[1]
            routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0)
            routing_map = routing_map.transpose(1, 2).unsqueeze(-1)  # bs, num, s*s, 1
        else:
            routing_map = hidden_states.new_ones(size=(1, 1, hidden_states.shape[1], 1))

        # for text
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if attn_args.lora_scale > 0.0:
            query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)
            key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)
            value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)
            key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)
            value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        if attn_args.ip_scale > 0.0:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
            )
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            if attn_args.ip_mask is not None:
                ip_mask = attn_args.ip_mask
                h, w = ip_mask.shape[-2:]
                ratio = (h * w / query.shape[2]) ** 0.5
                ip_mask = torch.nn.functional.interpolate(ip_mask, scale_factor=1 / ratio,
                                                          mode='nearest').reshape(
                    [1, -1, 1])
                ip_hidden_states = ip_hidden_states * ip_mask

            if attn_args.enable_share_attn:
                ip_hidden_states[0] = 0.
                ip_hidden_states[batch_size // 2] = 0.
        else:
            ip_hidden_states = torch.zeros_like(hidden_states)

        # for faceid-adapter
        if attn_args.faceid_scale > 0.0:
            faceid_key = self.to_k_faceid(faceid_hidden_states)
            faceid_value = self.to_v_faceid(faceid_hidden_states)

            faceid_query = query[:, None].expand(-1, attn_args.num_faceids, -1, -1,
                                                 -1)  # 2*bs, num, heads, s*s, dim/heads
            faceid_key = faceid_key.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,
                                         head_dim).transpose(2, 3)
            faceid_value = faceid_value.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,
                                             head_dim).transpose(2, 3)

            faceid_hidden_states = F.scaled_dot_product_attention(
                faceid_query, faceid_key, faceid_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale
            )  # 2*bs, num, heads, s*s, dim/heads

            faceid_hidden_states = faceid_hidden_states.transpose(2, 3).reshape(batch_size, attn_args.num_faceids, -1,
                                                                                attn.heads * head_dim)
            faceid_hidden_states = faceid_hidden_states.to(query.dtype)  # 2*bs, num, s*s, dim

            if attn_args.faceid_mask is not None:
                faceid_mask = attn_args.faceid_mask  # 1, num, h, w
                h, w = faceid_mask.shape[-2:]
                ratio = (h * w / query.shape[2]) ** 0.5
                faceid_mask = F.interpolate(faceid_mask, scale_factor=1 / ratio,
                                            mode='bilinear').flatten(2).unsqueeze(-1)  # 1, num, s*s, 1
                faceid_mask = faceid_mask / faceid_mask.sum(1, keepdim=True).clip(min=1e-3)  # 1, num, s*s, 1
                faceid_hidden_states = (faceid_mask * faceid_hidden_states).sum(1)  # 2*bs, s*s, dim
            else:
                faceid_hidden_states = (routing_map * faceid_hidden_states).sum(1)  # 2*bs, s*s, dim

            if attn_args.enable_share_attn:
                faceid_hidden_states[0] = 0.
                faceid_hidden_states[batch_size // 2] = 0.
        else:
            faceid_hidden_states = torch.zeros_like(hidden_states)

        hidden_states = hidden_states + \
                        attn_args.ip_scale * ip_hidden_states + \
                        attn_args.faceid_scale * faceid_hidden_states

        # linear proj
        output_hidden_states = attn.to_out[0](hidden_states)
        if attn_args.lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)
        elif attn_args.multi_id_lora_scale > 0.0:
            output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(
                hidden_states)
        hidden_states = output_hidden_states

        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
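

# Layout note (added for illustration): the cross-attention `encoder_hidden_states`
# consumed by UniPortraitLoRAIPAttnProcessor2_0 is the concatenation
#     [ text tokens | num_ip_tokens IP tokens | num_faceids * num_faceid_tokens face-ID tokens ]
# along the sequence dimension, which is why the processor slices it with
# text_end / ip_end / faceid_end above, and why the ControlNet processor below
# keeps only the leading text tokens.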


# for controlnet
class UniPortraitCNAttnProcessor2_0:
    def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):

        self.num_ip_tokens = num_ip_tokens
        self.num_faceid_tokens = num_faceid_tokens

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            text_end = encoder_hidden_states.shape[1] - self.num_faceid_tokens * attn_args.num_faceids \
                       - self.num_ip_tokens
            encoder_hidden_states = encoder_hidden_states[:, :text_end]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


================================================
FILE: uniportrait/uniportrait_pipeline.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from .curricular_face.backbone import get_model
from .resampler import UniPortraitFaceIDResampler
from .uniportrait_attention_processor import UniPortraitCNAttnProcessor2_0 as UniPortraitCNAttnProcessor
from .uniportrait_attention_processor import UniPortraitLoRAAttnProcessor2_0 as UniPortraitLoRAAttnProcessor
from .uniportrait_attention_processor import UniPortraitLoRAIPAttnProcessor2_0 as UniPortraitLoRAIPAttnProcessor


class ImageProjModel(nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds  # b, c
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens,
                                                              self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
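

# Hedged shape note (added for illustration): with the configuration used by
# UniPortraitPipeline below, ImageProjModel maps a pooled CLIP image embedding of
# shape (b, projection_dim) to (b, clip_extra_context_tokens, cross_attention_dim),
# e.g. (b, 1024) -> (b, 4, 768) for an SD1.5 UNet with num_ip_tokens = 4.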


class UniPortraitPipeline:

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_backbone_ckpt=None, uniportrait_faceid_ckpt=None,
                 uniportrait_router_ckpt=None, num_ip_tokens=4, num_faceid_tokens=16,
                 lora_rank=128, device=torch.device("cuda"), torch_dtype=torch.float16):

        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.uniportrait_faceid_ckpt = uniportrait_faceid_ckpt
        self.uniportrait_router_ckpt = uniportrait_router_ckpt

        self.num_ip_tokens = num_ip_tokens
        self.num_faceid_tokens = num_faceid_tokens
        self.lora_rank = lora_rank

        self.device = device
        self.torch_dtype = torch_dtype

        self.pipe = sd_pipe.to(self.device)

        # load clip image encoder
        self.clip_image_processor = CLIPImageProcessor(size={"shortest_edge": 224}, do_center_crop=False,
                                                       use_square_size=True)
        self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=self.torch_dtype)
        # load face backbone
        self.facerecog_model = get_model("IR_101")([112, 112])
        self.facerecog_model.load_state_dict(torch.load(face_backbone_ckpt, map_location="cpu"))
        self.facerecog_model = self.facerecog_model.to(self.device, dtype=torch_dtype)
        self.facerecog_model.eval()
        # image proj model
        self.image_proj_model = self.init_image_proj()
        # faceid proj model
        self.faceid_proj_model = self.init_faceid_proj()
        # set uniportrait and ip adapter
        self.set_uniportrait_and_ip_adapter()
        # load uniportrait and ip adapter
        self.load_uniportrait_and_ip_adapter()

    def init_image_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.clip_image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_ip_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        return image_proj_model

    def init_faceid_proj(self):
        faceid_proj_model = UniPortraitFaceIDResampler(
            intrinsic_id_embedding_dim=512,
            structure_embedding_dim=64 + 128 + 256 + self.clip_image_encoder.config.hidden_size,
            num_tokens=16, depth=6,
            dim=self.pipe.unet.config.cross_attention_dim, dim_head=64,
            heads=12, ff_mult=4,
            output_dim=self.pipe.unet.config.cross_attention_dim
        ).to(self.device, dtype=self.torch_dtype)
        return faceid_proj_model

    def set_uniportrait_and_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = UniPortraitLoRAAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    rank=self.lora_rank,
                ).to(self.device, dtype=self.torch_dtype).eval()
            else:
                attn_procs[name] = UniPortraitLoRAIPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    rank=self.lora_rank,
                    num_ip_tokens=self.num_ip_tokens,
                    num_faceid_tokens=self.num_faceid_tokens,
                ).to(self.device, dtype=self.torch_dtype).eval()
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, ControlNetModel):
                self.pipe.controlnet.set_attn_processor(
                    UniPortraitCNAttnProcessor(
                        num_ip_tokens=self.num_ip_tokens,
                        num_faceid_tokens=self.num_faceid_tokens,
                    )
                )
            elif isinstance(self.pipe.controlnet, MultiControlNetModel):
                for module in self.pipe.controlnet.nets:
                    module.set_attn_processor(
                        UniPortraitCNAttnProcessor(
                            num_ip_tokens=self.num_ip_tokens,
                            num_faceid_tokens=self.num_faceid_tokens,
                        )
                    )
            else:
                raise ValueError(f"Unsupported controlnet type: {type(self.pipe.controlnet)}")

    def load_uniportrait_and_ip_adapter(self):
        if self.ip_ckpt:
            print(f"loading from {self.ip_ckpt}...")
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
            ip_layers = nn.ModuleList(self.pipe.unet.attn_processors.values())
            ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

        if self.uniportrait_faceid_ckpt:
            print(f"loading from {self.uniportrait_faceid_ckpt}...")
            state_dict = torch.load(self.uniportrait_faceid_ckpt, map_location="cpu")
            self.faceid_proj_model.load_state_dict(state_dict["faceid_proj"], strict=True)
            ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
            ip_layers.load_state_dict(state_dict["faceid_adapter"], strict=False)

            if self.uniportrait_router_ckpt:
                print(f"loading from {self.uniportrait_router_ckpt}...")
                state_dict = torch.load(self.uniportrait_router_ckpt, map_location="cpu")
                router_state_dict = {}
                for k, v in state_dict["faceid_adapter"].items():
                    if "lora." in k:
                        router_state_dict[k.replace("lora.", "multi_id_lora.")] = v
                    elif "router." in k:
                        router_state_dict[k] = v
                ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
                ip_layers.load_state_dict(router_state_dict, strict=False)

    @torch.inference_mode()
    def get_ip_embeds(self, pil_ip_image):
        ip_image = self.clip_image_processor(images=pil_ip_image, return_tensors="pt").pixel_values
        ip_image = ip_image.to(self.device, dtype=self.torch_dtype)  # (b, 3, 224, 224), values being normalized
        ip_embeds = self.clip_image_encoder(ip_image).image_embeds
        ip_prompt_embeds = self.image_proj_model(ip_embeds)
        uncond_ip_prompt_embeds = self.image_proj_model(torch.zeros_like(ip_embeds))
        return ip_prompt_embeds, uncond_ip_prompt_embeds

    @torch.inference_mode()
    def get_single_faceid_embeds(self, pil_face_images, face_structure_scale):
        face_clip_image = self.clip_image_processor(images=pil_face_images, return_tensors="pt").pixel_values
        face_clip_image = face_clip_image.to(self.device, dtype=self.torch_dtype)  # (b, 3, 224, 224)
        face_clip_embeds = self.clip_image_encoder(
            face_clip_image, output_hidden_states=True).hidden_states[-2][:, 1:]  # b, 256, 1280

        OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=self.device,
                                        dtype=self.torch_dtype).reshape(-1, 1, 1)
        OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=self.device,
                                       dtype=self.torch_dtype).reshape(-1, 1, 1)
        facerecog_image = face_clip_image * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN  # [0, 1]
        facerecog_image = torch.clamp((facerecog_image - 0.5) / 0.5, -1, 1)  # [-1, 1]
        facerecog_image = F.interpolate(facerecog_image, size=(112, 112), mode="bilinear", align_corners=False)
        facerecog_embeds = self.facerecog_model(facerecog_image, return_mid_feats=True)[1]

        face_intrinsic_id_embeds = facerecog_embeds[-1]  # (b, 512, 7, 7)
        face_intrinsic_id_embeds = face_intrinsic_id_embeds.flatten(2).permute(0, 2, 1)  # b, 49, 512

        facerecog_structure_embeds = facerecog_embeds[:-1]  # (b, 64, 56, 56), (b, 128, 28, 28), (b, 256, 14, 14)
        facerecog_structure_embeds = torch.cat([
            F.interpolate(feat, size=(16, 16), mode="bilinear", align_corners=False)
            for feat in facerecog_structure_embeds], dim=1)  # b, 448, 16, 16
        facerecog_structure_embeds = facerecog_structure_embeds.flatten(2).permute(0, 2, 1)  # b, 256, 448
        face_structure_embeds = torch.cat([facerecog_structure_embeds, face_clip_embeds], dim=-1)  # b, 256, 1728

        uncond_face_clip_embeds = self.clip_image_encoder(
            torch.zeros_like(face_clip_image[:1]), output_hidden_states=True).hidden_states[-2][:, 1:]  # 1, 256, 1280
        uncond_face_structure_embeds = torch.cat(
            [torch.zeros_like(facerecog_structure_embeds[:1]), uncond_face_clip_embeds], dim=-1)  # 1, 256, 1728

        faceid_prompt_embeds = self.faceid_proj_model(
            face_intrinsic_id_embeds.flatten(0, 1).unsqueeze(0),
            face_structure_embeds.flatten(0, 1).unsqueeze(0),
            structure_scale=face_structure_scale,
        )  # (1, 16, 768)

        uncond_faceid_prompt_embeds = self.faceid_proj_model(
            torch.zeros_like(face_intrinsic_id_embeds[:1]),
            uncond_face_structure_embeds,
            structure_scale=face_structure_scale,
        )  # [1, 16, 768]

        return faceid_prompt_embeds, uncond_faceid_prompt_embeds

    def generate(
            self,
            prompt=None,
            negative_prompt=None,
            pil_ip_image=None,
            cond_faceids=None,
            face_structure_scale=0.0,
            seed=-1,
            guidance_scale=7.5,
            num_inference_steps=30,
            zT=None,
            **kwargs,
    ):
        """
        Args:
            prompt:
            negative_prompt:
            pil_ip_image:
            cond_faceids: [
                {
                    "refs": [PIL.Image] or PIL.Image,
                    (Optional) "mix_refs": [PIL.Image],
                    (Optional) "mix_scales": [float],
                },
                ...
            ]
            face_structure_scale:
            seed:
            guidance_scale:
            num_inference_steps:
            zT:
            **kwargs:
        Returns:
        """

        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True,
                negative_prompt=negative_prompt)
            num_prompts = prompt_embeds.shape[0]

            if pil_ip_image is not None:
                ip_prompt_embeds, uncond_ip_prompt_embeds = self.get_ip_embeds(pil_ip_image)
                ip_prompt_embeds = ip_prompt_embeds.repeat(num_prompts, 1, 1)
                uncond_ip_prompt_embeds = uncond_ip_prompt_embeds.repeat(num_prompts, 1, 1)
            else:
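                # no image prompt: use zero-valued IP tokens so the token layout expected by the attention processors stays fixed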
                ip_prompt_embeds = uncond_ip_prompt_embeds = \
                    torch.zeros_like(prompt_embeds[:, :1]).repeat(1, self.num_ip_tokens, 1)

            prompt_embeds = torch.cat([prompt_embeds, ip_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_ip_prompt_embeds], dim=1)

            if cond_faceids and len(cond_faceids) > 0:
                all_faceid_prompt_embeds = []
                all_uncond_faceid_prompt_embeds = []
                for curr_faceid_info in cond_faceids:
                    refs = curr_faceid_info["refs"]
                    faceid_prompt_embeds, uncond_faceid_prompt_embeds = \
                        self.get_single_faceid_embeds(refs, face_structure_scale)
                    if "mix_refs" in curr_faceid_info:
                        mix_refs = curr_faceid_info["mix_refs"]
                        mix_scales = curr_faceid_info["mix_scales"]

                        master_face_mix_scale = 1.0 - sum(mix_scales)
                        faceid_prompt_embeds = faceid_prompt_embeds * master_face_mix_scale
                        for mix_ref, mix_scale in zip(mix_refs, mix_scales):
                            faceid_mix_prompt_embeds, _ = self.get_single_faceid_embeds(mix_ref, face_structure_scale)
                            faceid_prompt_embeds = faceid_prompt_embeds + faceid_mix_prompt_embeds * mix_scale

                    all_faceid_prompt_embeds.append(faceid_prompt_embeds)
                    all_uncond_faceid_prompt_embeds.append(uncond_faceid_prompt_embeds)

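                # concatenate the (16) face-ID tokens of each identity along the token dimension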
                faceid_prompt_embeds = torch.cat(all_faceid_prompt_embeds, dim=1)
                uncond_faceid_prompt_embeds = torch.cat(all_uncond_faceid_prompt_embeds, dim=1)
                faceid_prompt_embeds = faceid_prompt_embeds.repeat(num_prompts, 1, 1)
                uncond_faceid_prompt_embeds = uncond_faceid_prompt_embeds.repeat(num_prompts, 1, 1)

                prompt_embeds = torch.cat([prompt_embeds, faceid_prompt_embeds], dim=1)
                negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_faceid_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
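        # if an inverted latent zT is supplied (e.g. from inversion.ddim_inversion), use it as the first sample's starting noise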
        if zT is not None:
            h_, w_ = kwargs["image"][0].shape[-2:]
            latents = torch.randn(num_prompts, 4, h_ // 8, w_ // 8, device=self.device, generator=generator,
                                  dtype=self.pipe.unet.dtype)
            latents[0] = zT
        else:
            latents = None

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            latents=latents,
            **kwargs,
        ).images

        return images
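
================================================
EXAMPLE USAGE (illustrative sketch, not part of the repository)
================================================

The block below is a minimal, hypothetical sketch of how UniPortraitPipeline.generate
could be called for two identities. The pipeline construction is elided (see
UniPortraitPipeline.__init__ above and gradio_app.py for the actual checkpoint and
diffusers-pipeline wiring), and the image paths, prompts, and scale values are
placeholder assumptions.

from PIL import Image

# `uniportrait` is assumed to be an already constructed UniPortraitPipeline instance.

cond_faceids = [
    {"refs": [Image.open("person_a.jpg")]},  # first identity, single reference image
    {
        "refs": [Image.open("person_b.jpg")],      # second identity
        "mix_refs": [Image.open("person_c.jpg")],  # optional: blend a third face into it
        "mix_scales": [0.3],                       # person_b keeps 1 - 0.3 = 0.7 of the mix
    },
]

images = uniportrait.generate(
    prompt="two people smiling at the camera, studio photo",
    negative_prompt="lowres, blurry, bad anatomy",
    cond_faceids=cond_faceids,
    face_structure_scale=0.3,   # 0.0 keeps only the intrinsic ID; larger values keep more facial structure
    seed=42,
    guidance_scale=7.5,
    num_inference_steps=30,
    # extra keyword arguments are forwarded to the underlying diffusers pipeline call;
    # depending on how the pipeline was built (e.g. with a ControlNet), additional
    # inputs such as a conditioning image may be required here
)
images[0].save("result.jpg")

The mix_scales entry follows the mixing logic in generate(): each mixed reference is
weighted by its scale and the primary reference keeps 1 - sum(mix_scales).
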
SYMBOL INDEX (99 symbols across 10 files)

FILE: gradio_app.py
  function pad_np_bgr_image (line 69) | def pad_np_bgr_image(np_image, scale=1.25):
  function process_faceid_image (line 79) | def process_faceid_image(pil_faceid_image):
  function prepare_single_faceid_cond_kwargs (line 106) | def prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_...
  function text_to_single_id_generation_process (line 140) | def text_to_single_id_generation_process(
  function text_to_multi_id_generation_process (line 190) | def text_to_multi_id_generation_process(
  function image_to_single_id_generation_process (line 253) | def image_to_single_id_generation_process(
  function text_to_single_id_generation_block (line 320) | def text_to_single_id_generation_block():
  function text_to_multi_id_generation_block (line 389) | def text_to_multi_id_generation_block():
  function image_to_single_id_generation_block (line 477) | def image_to_single_id_generation_block():

FILE: uniportrait/curricular_face/backbone/__init__.py
  function get_model (line 24) | def get_model(key):

FILE: uniportrait/curricular_face/backbone/common.py
  function initialize_weights (line 8) | def initialize_weights(modules):
  class Flatten (line 27) | class Flatten(Module):
    method forward (line 31) | def forward(self, input):
  class SEModule (line 35) | class SEModule(Module):
    method __init__ (line 39) | def __init__(self, channels, reduction):
    method forward (line 61) | def forward(self, x):

FILE: uniportrait/curricular_face/backbone/model_irse.py
  class BasicBlockIR (line 11) | class BasicBlockIR(Module):
    method __init__ (line 15) | def __init__(self, in_channel, depth, stride):
    method forward (line 30) | def forward(self, x):
  class BottleneckIR (line 37) | class BottleneckIR(Module):
    method __init__ (line 41) | def __init__(self, in_channel, depth, stride):
    method forward (line 64) | def forward(self, x):
  class BasicBlockIRSE (line 71) | class BasicBlockIRSE(BasicBlockIR):
    method __init__ (line 73) | def __init__(self, in_channel, depth, stride):
  class BottleneckIRSE (line 78) | class BottleneckIRSE(BottleneckIR):
    method __init__ (line 80) | def __init__(self, in_channel, depth, stride):
  class Bottleneck (line 85) | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
  function get_block (line 89) | def get_block(in_channel, depth, num_units, stride=2):
  function get_blocks (line 94) | def get_blocks(num_layers):
  class Backbone (line 141) | class Backbone(Module):
    method __init__ (line 143) | def __init__(self, input_size, num_layers, mode='ir'):
    method forward (line 200) | def forward(self, x, return_mid_feats=False):
  function IR_18 (line 216) | def IR_18(input_size):
  function IR_34 (line 224) | def IR_34(input_size):
  function IR_50 (line 232) | def IR_50(input_size):
  function IR_101 (line 240) | def IR_101(input_size):
  function IR_152 (line 248) | def IR_152(input_size):
  function IR_200 (line 256) | def IR_200(input_size):
  function IR_SE_50 (line 264) | def IR_SE_50(input_size):
  function IR_SE_101 (line 272) | def IR_SE_101(input_size):
  function IR_SE_152 (line 280) | def IR_SE_152(input_size):
  function IR_SE_200 (line 288) | def IR_SE_200(input_size):

FILE: uniportrait/curricular_face/backbone/model_resnet.py
  function conv3x3 (line 10) | def conv3x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 22) | def conv1x1(in_planes, out_planes, stride=1):
  class Bottleneck (line 29) | class Bottleneck(Module):
    method __init__ (line 32) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 44) | def forward(self, x):
  class ResNet (line 67) | class ResNet(Module):
    method __init__ (line 71) | def __init__(self, input_size, block, layers, zero_init_residual=True):
    method _make_layer (line 105) | def _make_layer(self, block, planes, blocks, stride=1):
    method forward (line 121) | def forward(self, x):
  function ResNet_50 (line 141) | def ResNet_50(input_size, **kwargs):
  function ResNet_101 (line 149) | def ResNet_101(input_size, **kwargs):
  function ResNet_152 (line 157) | def ResNet_152(input_size, **kwargs):

FILE: uniportrait/curricular_face/inference.py
  function inference (line 13) | def inference(name, weight, src_norm_dir):

FILE: uniportrait/inversion.py
  function _encode_text_with_negative (line 16) | def _encode_text_with_negative(model: StableDiffusionPipeline, prompt: s...
  function _encode_image (line 24) | def _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:
  function _next_step (line 33) | def _next_step(model: StableDiffusionPipeline, model_output: T, timestep...
  function _get_noise_pred (line 46) | def _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, con...
  function _ddim_loop (line 55) | def _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scal...
  function make_inversion_callback (line 69) | def make_inversion_callback(zts, offset: int = 0) -> [T, InversionCallba...
  function ddim_inversion (line 80) | def ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, promp...

FILE: uniportrait/resampler.py
  function FeedForward (line 11) | def FeedForward(dim, mult=4):
  function reshape_tensor (line 21) | def reshape_tensor(x, heads):
  class PerceiverAttention (line 32) | class PerceiverAttention(nn.Module):
    method __init__ (line 33) | def __init__(self, *, dim, dim_head=64, heads=8):
    method forward (line 47) | def forward(self, x, latents, attention_mask=None):
  class UniPortraitFaceIDResampler (line 89) | class UniPortraitFaceIDResampler(torch.nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 132) | def forward(

FILE: uniportrait/uniportrait_attention_processor.py
  class AttentionArgs (line 8) | class AttentionArgs(object):
    method __init__ (line 9) | def __init__(self) -> None:
    method reset (line 27) | def reset(self):
    method __repr__ (line 45) | def __repr__(self):
  function expand_first (line 54) | def expand_first(feat, scale=1., ):
  function concat_first (line 65) | def concat_first(feat, dim=2, scale=1.):
  function calc_mean_std (line 70) | def calc_mean_std(feat, eps: float = 1e-5):
  function adain (line 76) | def adain(feat):
  class UniPortraitLoRAAttnProcessor2_0 (line 85) | class UniPortraitLoRAAttnProcessor2_0(nn.Module):
    method __init__ (line 87) | def __init__(
    method __call__ (line 106) | def __call__(
  class UniPortraitLoRAIPAttnProcessor2_0 (line 206) | class UniPortraitLoRAIPAttnProcessor2_0(nn.Module):
    method __init__ (line 208) | def __init__(self, hidden_size, cross_attention_dim=None, rank=128, ne...
    method __call__ (line 243) | def __call__(
  class UniPortraitCNAttnProcessor2_0 (line 422) | class UniPortraitCNAttnProcessor2_0:
    method __init__ (line 423) | def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):
    method __call__ (line 428) | def __call__(

FILE: uniportrait/uniportrait_pipeline.py
  class ImageProjModel (line 15) | class ImageProjModel(nn.Module):
    method __init__ (line 18) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 26) | def forward(self, image_embeds):
  class UniPortraitPipeline (line 34) | class UniPortraitPipeline:
    method __init__ (line 36) | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_bac...
    method init_image_proj (line 73) | def init_image_proj(self):
    method init_faceid_proj (line 81) | def init_faceid_proj(self):
    method set_uniportrait_and_ip_adapter (line 92) | def set_uniportrait_and_ip_adapter(self):
    method load_uniportrait_and_ip_adapter (line 139) | def load_uniportrait_and_ip_adapter(self):
    method get_ip_embeds (line 167) | def get_ip_embeds(self, pil_ip_image):
    method get_single_faceid_embeds (line 176) | def get_single_faceid_embeds(self, pil_face_images, face_structure_sca...
    method generate (line 220) | def generate(
Condensed preview — 16 files, each showing path, character count, and a content snippet of the full structured content (≈120K characters).
[
  {
    "path": ".gitignore",
    "chars": 1927,
    "preview": ".idea/\n.DS_Store\n*.dat\n*.mat\n\ntraining/\nlightning_logs/\nimage_log/\n\n*.png\n*.jpg\n*.jpeg\n*.webp\n\n*.pth\n*.pt\n*.ckpt\n*.safet"
  },
  {
    "path": "LICENSE.txt",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 3126,
    "preview": "<div align=\"center\">\n<h1>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personal"
  },
  {
    "path": "gradio_app.py",
    "chars": 32564,
    "preview": "import os\nfrom io import BytesIO\n\nimport cv2\nimport gradio as gr\nimport numpy as np\nimport torch\nfrom PIL import Image\nf"
  },
  {
    "path": "requirements.txt",
    "chars": 69,
    "preview": "diffusers\ngradio\nonnxruntime-gpu\ninsightface\ntorch\ntqdm\ntransformers\n"
  },
  {
    "path": "uniportrait/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "uniportrait/curricular_face/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "uniportrait/curricular_face/backbone/__init__.py",
    "chars": 1080,
    "preview": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/T"
  },
  {
    "path": "uniportrait/curricular_face/backbone/common.py",
    "chars": 1930,
    "preview": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/T"
  },
  {
    "path": "uniportrait/curricular_face/backbone/model_irse.py",
    "chars": 9232,
    "preview": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/T"
  },
  {
    "path": "uniportrait/curricular_face/backbone/model_resnet.py",
    "chars": 4745,
    "preview": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/T"
  },
  {
    "path": "uniportrait/curricular_face/inference.py",
    "chars": 1845,
    "preview": "import glob\nimport os\n\nimport cv2\nimport numpy as np\nimport torch\nfrom tqdm.auto import tqdm\n\nfrom .backbone import get_"
  },
  {
    "path": "uniportrait/inversion.py",
    "chars": 3836,
    "preview": "# modified from https://github.com/google/style-aligned/blob/main/inversion.py\n\nfrom __future__ import annotations\n\nfrom"
  },
  {
    "path": "uniportrait/resampler.py",
    "chars": 5328,
    "preview": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://gith"
  },
  {
    "path": "uniportrait/uniportrait_attention_processor.py",
    "chars": 22528,
    "preview": "# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py\nimport to"
  },
  {
    "path": "uniportrait/uniportrait_pipeline.py",
    "chars": 16100,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers import ControlNetModel\nfrom diffusers."
  }
]

About this extraction

This page contains the full source code of the junjiehe96/UniPortrait GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 16 files (113.0 KB), approximately 28.5k tokens, and a symbol index with 99 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
