[
  {
    "path": ".gitignore",
    "content": ".idea/\n.DS_Store\n*.dat\n*.mat\n\ntraining/\nlightning_logs/\nimage_log/\n\n*.png\n*.jpg\n*.jpeg\n*.webp\n\n*.pth\n*.pt\n*.ckpt\n*.safetensors\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<h1>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</h1>\n\n<a href='https://aigcdesigngroup.github.io/UniPortrait-Page/'><img src='https://img.shields.io/badge/Project-Page-green'></a>\n<a href='https://arxiv.org/abs/2408.05939'><img src='https://img.shields.io/badge/Technique-Report-red'></a>\n<a href='https://huggingface.co/spaces/Junjie96/UniPortrait'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>\n\n</div>\n\n<img src='assets/highlight.png'>\n\nUniPortrait is an innovative human image personalization framework. It customizes single- and multi-ID images in a\nunified manner, providing high-fidelity identity preservation, extensive facial editability, free-form text description,\nand no requirement for a predetermined layout.\n\n---\n\n## Release\n\n- [2025/05/01] 🔥 We release the code and demo for the `FLUX.1-dev` version of [AnyStory](https://github.com/junjiehe96/AnyStory), a unified approach to general subject personalization.\n- [2024/10/18] 🔥 We release the inference code and demo, which has simply\n  integrated [ControlNet](https://github.com/lllyasviel/ControlNet)\n  , [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter),\n  and [StyleAligned](https://github.com/google/style-aligned). The weight for this version is consistent with the\n  huggingface space and experiments in the paper. We are now working on generalizing our method to more advanced\n  diffusion models and more general custom concepts. Please stay tuned!\n- [2024/08/12] 🔥 We release the [technical report](https://arxiv.org/abs/2408.05939)\n  , [project page](https://aigcdesigngroup.github.io/UniPortrait-Page/),\n  and [HuggingFace demo](https://huggingface.co/spaces/Junjie96/UniPortrait) 🤗!\n\n## Quickstart\n\n```shell\n# Clone repository\ngit clone https://github.com/junjiehe96/UniPortrait.git\n\n# install requirements\ncd UniPortrait\npip install -r requirements.txt\n\n# download the models\ngit lfs install\ngit clone https://huggingface.co/Junjie96/UniPortrait models\n# download ip-adapter models \n# Note: recommend downloading manually. We do not require all IP adapter models.\ngit clone https://huggingface.co/h94/IP-Adapter models/IP-Adapter\n\n# then you can use the gradio app\npython gradio_app.py\n```\n\n## Applications\n\n<img src='assets/application.png'>\n\n## **Acknowledgements**\n\nThis code is built on some excellent repos, including [diffusers](https://github.com/huggingface/diffusers), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [StyleAligned](https://github.com/google/style-aligned). Highly appreciate their great work!\n\n## Cite\n\nIf you find UniPortrait useful for your research and applications, please cite us using this BibTeX:\n\n```bibtex\n@article{he2024uniportrait,\n    title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},\n    author={He, Junjie and Geng, Yifeng and Bo, Liefeng},\n    journal={arXiv preprint arXiv:2408.05939},\n    year={2024}\n}\n```\n\nFor any question, please feel free to open an issue or contact us via hejunjie1103@gmail.com.\n"
  },
  {
    "path": "gradio_app.py",
    "content": "import os\nfrom io import BytesIO\n\nimport cv2\nimport gradio as gr\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom diffusers import DDIMScheduler, AutoencoderKL, ControlNetModel, StableDiffusionControlNetPipeline\nfrom insightface.app import FaceAnalysis\nfrom insightface.utils import face_align\n\nfrom uniportrait import inversion\nfrom uniportrait.uniportrait_attention_processor import attn_args\nfrom uniportrait.uniportrait_pipeline import UniPortraitPipeline\n\nport = 7860\n\ndevice = \"cuda\"\ntorch_dtype = torch.float16\n\n# base\nbase_model_path = \"SG161222/Realistic_Vision_V5.1_noVAE\"\nvae_model_path = \"stabilityai/sd-vae-ft-mse\"\ncontrolnet_pose_ckpt = \"lllyasviel/control_v11p_sd15_openpose\"\n# specific\nimage_encoder_path = \"models/IP-Adapter/models/image_encoder\"\nip_ckpt = \"models/IP-Adapter/models/ip-adapter_sd15.bin\"\nface_backbone_ckpt = \"models/glint360k_curricular_face_r101_backbone.bin\"\nuniportrait_faceid_ckpt = \"models/uniportrait-faceid_sd15.bin\"\nuniportrait_router_ckpt = \"models/uniportrait-router_sd15.bin\"\n\n# load controlnet\npose_controlnet = ControlNetModel.from_pretrained(controlnet_pose_ckpt, torch_dtype=torch_dtype)\n\n# load SD pipeline\nnoise_scheduler = DDIMScheduler(\n    num_train_timesteps=1000,\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    clip_sample=False,\n    set_alpha_to_one=False,\n    steps_offset=1,\n)\nvae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch_dtype)\npipe = StableDiffusionControlNetPipeline.from_pretrained(\n    base_model_path,\n    controlnet=[pose_controlnet],\n    torch_dtype=torch_dtype,\n    scheduler=noise_scheduler,\n    vae=vae,\n    # feature_extractor=None,\n    # safety_checker=None,\n)\n\n# load uniportrait pipeline\nuniportrait_pipeline = UniPortraitPipeline(pipe, image_encoder_path, ip_ckpt=ip_ckpt,\n                                           face_backbone_ckpt=face_backbone_ckpt,\n                                           uniportrait_faceid_ckpt=uniportrait_faceid_ckpt,\n                                           uniportrait_router_ckpt=uniportrait_router_ckpt,\n                                           device=device, torch_dtype=torch_dtype)\n\n# load face detection assets\nface_app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=[\"detection\"])\nface_app.prepare(ctx_id=0, det_size=(640, 640))\n\n\ndef pad_np_bgr_image(np_image, scale=1.25):\n    assert scale >= 1.0, \"scale should be >= 1.0\"\n    pad_scale = scale - 1.0\n    h, w = np_image.shape[:2]\n    top = bottom = int(h * pad_scale)\n    left = right = int(w * pad_scale)\n    ret = cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))\n    return ret, (left, top)\n\n\ndef process_faceid_image(pil_faceid_image):\n    np_faceid_image = np.array(pil_faceid_image.convert(\"RGB\"))\n    img = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)\n    faces = face_app.get(img)  # bgr\n    if len(faces) == 0:\n        # padding, try again\n        _h, _w = img.shape[:2]\n        _img, left_top_coord = pad_np_bgr_image(img)\n        faces = face_app.get(_img)\n        if len(faces) == 0:\n            gr.Info(\"Warning: No face detected in the image. 
Continue processing...\")\n\n        min_coord = np.array([0, 0])\n        max_coord = np.array([_w, _h])\n        sub_coord = np.array([left_top_coord[0], left_top_coord[1]])\n        for face in faces:\n            face.bbox = np.minimum(np.maximum(face.bbox.reshape(-1, 2) - sub_coord, min_coord), max_coord).reshape(4)\n            face.kps = face.kps - sub_coord\n\n    faces = sorted(faces, key=lambda x: abs((x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])), reverse=True)\n    faceid_face = faces[0]\n    norm_face = face_align.norm_crop(img, landmark=faceid_face.kps, image_size=224)\n    pil_faceid_align_image = Image.fromarray(cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB))\n\n    return pil_faceid_align_image\n\n\ndef prepare_single_faceid_cond_kwargs(pil_faceid_image=None, pil_faceid_supp_images=None,\n                                      pil_faceid_mix_images=None, mix_scales=None):\n    pil_faceid_align_images = []\n    if pil_faceid_image:\n        pil_faceid_align_images.append(process_faceid_image(pil_faceid_image))\n    if pil_faceid_supp_images and len(pil_faceid_supp_images) > 0:\n        for pil_faceid_supp_image in pil_faceid_supp_images:\n            if isinstance(pil_faceid_supp_image, Image.Image):\n                pil_faceid_align_images.append(process_faceid_image(pil_faceid_supp_image))\n            else:\n                pil_faceid_align_images.append(\n                    process_faceid_image(Image.open(BytesIO(pil_faceid_supp_image)))\n                )\n\n    mix_refs = []\n    mix_ref_scales = []\n    if pil_faceid_mix_images:\n        for pil_faceid_mix_image, mix_scale in zip(pil_faceid_mix_images, mix_scales):\n            if pil_faceid_mix_image:\n                mix_refs.append(process_faceid_image(pil_faceid_mix_image))\n                mix_ref_scales.append(mix_scale)\n\n    single_faceid_cond_kwargs = None\n    if len(pil_faceid_align_images) > 0:\n        single_faceid_cond_kwargs = {\n            \"refs\": pil_faceid_align_images\n        }\n        if len(mix_refs) > 0:\n            single_faceid_cond_kwargs[\"mix_refs\"] = mix_refs\n            single_faceid_cond_kwargs[\"mix_scales\"] = mix_ref_scales\n\n    return single_faceid_cond_kwargs\n\n\ndef text_to_single_id_generation_process(\n        pil_faceid_image=None, pil_faceid_supp_images=None,\n        pil_faceid_mix_image_1=None, mix_scale_1=0.0,\n        pil_faceid_mix_image_2=None, mix_scale_2=0.0,\n        faceid_scale=0.0, face_structure_scale=0.0,\n        prompt=\"\", negative_prompt=\"\",\n        num_samples=1, seed=-1,\n        image_resolution=\"512x512\",\n        inference_steps=25,\n):\n    if seed == -1:\n        seed = None\n\n    single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,\n                                                                  pil_faceid_supp_images,\n                                                                  [pil_faceid_mix_image_1, pil_faceid_mix_image_2],\n                                                                  [mix_scale_1, mix_scale_2])\n\n    cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []\n\n    # reset attn args\n    attn_args.reset()\n    # set faceid condition\n    attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0  # single-faceid lora\n    attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0  # multi-faceid lora\n    attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0\n    attn_args.num_faceids = len(cond_faceids)\n    
print(attn_args)\n\n    h, w = int(image_resolution.split(\"x\")[0]), int(image_resolution.split(\"x\")[1])\n    prompt = [prompt] * num_samples\n    negative_prompt = [negative_prompt] * num_samples\n    images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,\n                                           cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,\n                                           seed=seed, guidance_scale=7.5,\n                                           num_inference_steps=inference_steps,\n                                           image=[torch.zeros([1, 3, h, w])],\n                                           controlnet_conditioning_scale=[0.0])\n    final_out = []\n    for pil_image in images:\n        final_out.append(pil_image)\n\n    for single_faceid_cond_kwargs in cond_faceids:\n        final_out.extend(single_faceid_cond_kwargs[\"refs\"])\n        if \"mix_refs\" in single_faceid_cond_kwargs:\n            final_out.extend(single_faceid_cond_kwargs[\"mix_refs\"])\n\n    return final_out\n\n\ndef text_to_multi_id_generation_process(\n        pil_faceid_image_1=None, pil_faceid_supp_images_1=None,\n        pil_faceid_mix_image_1_1=None, mix_scale_1_1=0.0,\n        pil_faceid_mix_image_1_2=None, mix_scale_1_2=0.0,\n        pil_faceid_image_2=None, pil_faceid_supp_images_2=None,\n        pil_faceid_mix_image_2_1=None, mix_scale_2_1=0.0,\n        pil_faceid_mix_image_2_2=None, mix_scale_2_2=0.0,\n        faceid_scale=0.0, face_structure_scale=0.0,\n        prompt=\"\", negative_prompt=\"\",\n        num_samples=1, seed=-1,\n        image_resolution=\"512x512\",\n        inference_steps=25,\n):\n    if seed == -1:\n        seed = None\n\n    faceid_cond_kwargs_1 = prepare_single_faceid_cond_kwargs(pil_faceid_image_1,\n                                                             pil_faceid_supp_images_1,\n                                                             [pil_faceid_mix_image_1_1,\n                                                              pil_faceid_mix_image_1_2],\n                                                             [mix_scale_1_1, mix_scale_1_2])\n    faceid_cond_kwargs_2 = prepare_single_faceid_cond_kwargs(pil_faceid_image_2,\n                                                             pil_faceid_supp_images_2,\n                                                             [pil_faceid_mix_image_2_1,\n                                                              pil_faceid_mix_image_2_2],\n                                                             [mix_scale_2_1, mix_scale_2_2])\n    cond_faceids = []\n    if faceid_cond_kwargs_1 is not None:\n        cond_faceids.append(faceid_cond_kwargs_1)\n    if faceid_cond_kwargs_2 is not None:\n        cond_faceids.append(faceid_cond_kwargs_2)\n\n    # reset attn args\n    attn_args.reset()\n    # set faceid condition\n    attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0  # single-faceid lora\n    attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0  # multi-faceid lora\n    attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0\n    attn_args.num_faceids = len(cond_faceids)\n    print(attn_args)\n\n    h, w = int(image_resolution.split(\"x\")[0]), int(image_resolution.split(\"x\")[1])\n    prompt = [prompt] * num_samples\n    negative_prompt = [negative_prompt] * num_samples\n    images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,\n                                           
cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,\n                                           seed=seed, guidance_scale=7.5,\n                                           num_inference_steps=inference_steps,\n                                           image=[torch.zeros([1, 3, h, w])],\n                                           controlnet_conditioning_scale=[0.0])\n\n    final_out = []\n    for pil_image in images:\n        final_out.append(pil_image)\n\n    for single_faceid_cond_kwargs in cond_faceids:\n        final_out.extend(single_faceid_cond_kwargs[\"refs\"])\n        if \"mix_refs\" in single_faceid_cond_kwargs:\n            final_out.extend(single_faceid_cond_kwargs[\"mix_refs\"])\n\n    return final_out\n\n\ndef image_to_single_id_generation_process(\n        pil_faceid_image=None, pil_faceid_supp_images=None,\n        pil_faceid_mix_image_1=None, mix_scale_1=0.0,\n        pil_faceid_mix_image_2=None, mix_scale_2=0.0,\n        faceid_scale=0.0, face_structure_scale=0.0,\n        pil_ip_image=None, ip_scale=1.0,\n        num_samples=1, seed=-1, image_resolution=\"768x512\",\n        inference_steps=25,\n):\n    if seed == -1:\n        seed = None\n\n    single_faceid_cond_kwargs = prepare_single_faceid_cond_kwargs(pil_faceid_image,\n                                                                  pil_faceid_supp_images,\n                                                                  [pil_faceid_mix_image_1, pil_faceid_mix_image_2],\n                                                                  [mix_scale_1, mix_scale_2])\n\n    cond_faceids = [single_faceid_cond_kwargs] if single_faceid_cond_kwargs else []\n\n    h, w = int(image_resolution.split(\"x\")[0]), int(image_resolution.split(\"x\")[1])\n\n    # Image Prompt and Style Aligned\n    if pil_ip_image is None:\n        gr.Error(\"Please upload a reference image\")\n    attn_args.reset()\n    pil_ip_image = pil_ip_image.convert(\"RGB\").resize((w, h))\n    zts = inversion.ddim_inversion(uniportrait_pipeline.pipe, np.array(pil_ip_image), \"\", inference_steps, 2)\n    zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0)\n\n    # reset attn args\n    attn_args.reset()\n    # set ip condition\n    attn_args.ip_scale = ip_scale if pil_ip_image else 0.0\n    # set faceid condition\n    attn_args.lora_scale = 1.0 if len(cond_faceids) == 1 else 0.0  # lora for single faceid\n    attn_args.multi_id_lora_scale = 1.0 if len(cond_faceids) > 1 else 0.0  # lora for >1 faceids\n    attn_args.faceid_scale = faceid_scale if len(cond_faceids) > 0 else 0.0\n    attn_args.num_faceids = len(cond_faceids)\n    # set shared self-attn\n    attn_args.enable_share_attn = True\n    attn_args.shared_score_shift = -0.5\n    print(attn_args)\n\n    prompt = [\"\"] * (1 + num_samples)\n    negative_prompt = [\"\"] * (1 + num_samples)\n    images = uniportrait_pipeline.generate(prompt=prompt, negative_prompt=negative_prompt,\n                                           pil_ip_image=pil_ip_image,\n                                           cond_faceids=cond_faceids, face_structure_scale=face_structure_scale,\n                                           seed=seed, guidance_scale=7.5,\n                                           num_inference_steps=inference_steps,\n                                           image=[torch.zeros([1, 3, h, w])],\n                                           controlnet_conditioning_scale=[0.0],\n                                           zT=zT, callback_on_step_end=inversion_callback)\n    
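# drop the first output, which reconstructs the inverted reference used as the style anchor, keeping only the new samples\n    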
images = images[1:]\n\n    final_out = []\n    for pil_image in images:\n        final_out.append(pil_image)\n\n    for single_faceid_cond_kwargs in cond_faceids:\n        final_out.extend(single_faceid_cond_kwargs[\"refs\"])\n        if \"mix_refs\" in single_faceid_cond_kwargs:\n            final_out.extend(single_faceid_cond_kwargs[\"mix_refs\"])\n\n    return final_out\n\n\ndef text_to_single_id_generation_block():\n    gr.Markdown(\"## Text-to-Single-ID Generation\")\n    gr.HTML(text_to_single_id_description)\n    gr.HTML(text_to_single_id_tips)\n    with gr.Row():\n        with gr.Column(scale=1, min_width=100):\n            prompt = gr.Textbox(value=\"\", label='Prompt', lines=2)\n            negative_prompt = gr.Textbox(value=\"nsfw\", label='Negative Prompt')\n\n            run_button = gr.Button(value=\"Run\")\n            with gr.Accordion(\"Options\", open=True):\n                image_resolution = gr.Dropdown(choices=[\"768x512\", \"512x512\", \"512x768\"], value=\"512x512\",\n                                               label=\"Image Resolution (HxW)\")\n                seed = gr.Slider(label=\"Seed (-1 indicates random)\", minimum=-1, maximum=2147483647, step=1,\n                                 value=2147483647)\n                num_samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=2, step=1)\n                inference_steps = gr.Slider(label=\"Steps\", minimum=1, maximum=100, value=25, step=1, visible=False)\n\n                faceid_scale = gr.Slider(label=\"Face ID Scale\", minimum=0.0, maximum=1.0, step=0.01, value=0.7)\n                face_structure_scale = gr.Slider(label=\"Face Structure Scale\", minimum=0.0, maximum=1.0,\n                                                 step=0.01, value=0.1)\n\n        with gr.Column(scale=2, min_width=100):\n            with gr.Row(equal_height=False):\n                pil_faceid_image = gr.Image(type=\"pil\", label=\"ID Image\")\n                with gr.Accordion(\"ID Supplements\", open=True):\n                    with gr.Row():\n                        pil_faceid_supp_images = gr.File(file_count=\"multiple\", file_types=[\"image\"],\n                                                         type=\"binary\", label=\"Additional ID Images\")\n                    with gr.Row():\n                        with gr.Column(scale=1, min_width=100):\n                            pil_faceid_mix_image_1 = gr.Image(type=\"pil\", label=\"Mix ID 1\")\n                            mix_scale_1 = gr.Slider(label=\"Mix Scale 1\", minimum=0.0, maximum=1.0, step=0.01, value=0.0)\n                        with gr.Column(scale=1, min_width=100):\n                            pil_faceid_mix_image_2 = gr.Image(type=\"pil\", label=\"Mix ID 2\")\n                            mix_scale_2 = gr.Slider(label=\"Mix Scale 2\", minimum=0.0, maximum=1.0, step=0.01, value=0.0)\n\n            with gr.Row():\n                example_output = gr.Image(type=\"pil\", label=\"(Example Output)\", visible=False)\n                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id=\"gallery\", columns=4, preview=True,\n                                            format=\"png\")\n    with gr.Row():\n        examples = [\n            [\n                \"A young man with short black hair, wearing a black hoodie with a hood, was paired with a blue denim jacket with yellow details.\",\n                \"assets/examples/1-newton.jpg\",\n                \"assets/examples/1-output-1.png\",\n            ],\n        ]\n        gr.Examples(\n            
label=\"Examples\",\n            examples=examples,\n            fn=lambda x, y, z: (x, y),\n            inputs=[prompt, pil_faceid_image, example_output],\n            outputs=[prompt, pil_faceid_image]\n        )\n    ips = [\n        pil_faceid_image, pil_faceid_supp_images,\n        pil_faceid_mix_image_1, mix_scale_1,\n        pil_faceid_mix_image_2, mix_scale_2,\n        faceid_scale, face_structure_scale,\n        prompt, negative_prompt,\n        num_samples, seed,\n        image_resolution,\n        inference_steps,\n    ]\n    run_button.click(fn=text_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])\n\n\ndef text_to_multi_id_generation_block():\n    gr.Markdown(\"## Text-to-Multi-ID Generation\")\n    gr.HTML(text_to_multi_id_description)\n    gr.HTML(text_to_multi_id_tips)\n    with gr.Row():\n        with gr.Column(scale=1, min_width=100):\n            prompt = gr.Textbox(value=\"\", label='Prompt', lines=2)\n            negative_prompt = gr.Textbox(value=\"nsfw\", label='Negative Prompt')\n            run_button = gr.Button(value=\"Run\")\n            with gr.Accordion(\"Options\", open=True):\n                image_resolution = gr.Dropdown(choices=[\"768x512\", \"512x512\", \"512x768\"], value=\"512x512\",\n                                               label=\"Image Resolution (HxW)\")\n                seed = gr.Slider(label=\"Seed (-1 indicates random)\", minimum=-1, maximum=2147483647, step=1,\n                                 value=2147483647)\n                num_samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=2, step=1)\n                inference_steps = gr.Slider(label=\"Steps\", minimum=1, maximum=100, value=25, step=1, visible=False)\n\n                faceid_scale = gr.Slider(label=\"Face ID Scale\", minimum=0.0, maximum=1.0, step=0.01, value=0.7)\n                face_structure_scale = gr.Slider(label=\"Face Structure Scale\", minimum=0.0, maximum=1.0,\n                                                 step=0.01, value=0.3)\n\n        with gr.Column(scale=2, min_width=100):\n            with gr.Row(equal_height=False):\n                with gr.Column(scale=1, min_width=100):\n                    pil_faceid_image_1 = gr.Image(type=\"pil\", label=\"First ID\")\n                    with gr.Accordion(\"First ID Supplements\", open=False):\n                        with gr.Row():\n                            pil_faceid_supp_images_1 = gr.File(file_count=\"multiple\", file_types=[\"image\"],\n                                                               type=\"binary\", label=\"Additional ID Images\")\n                        with gr.Row():\n                            with gr.Column(scale=1, min_width=100):\n                                pil_faceid_mix_image_1_1 = gr.Image(type=\"pil\", label=\"Mix ID 1\")\n                                mix_scale_1_1 = gr.Slider(label=\"Mix Scale 1\", minimum=0.0, maximum=1.0, step=0.01,\n                                                          value=0.0)\n                            with gr.Column(scale=1, min_width=100):\n                                pil_faceid_mix_image_1_2 = gr.Image(type=\"pil\", label=\"Mix ID 2\")\n                                mix_scale_1_2 = gr.Slider(label=\"Mix Scale 2\", minimum=0.0, maximum=1.0, step=0.01,\n                                                          value=0.0)\n                with gr.Column(scale=1, min_width=100):\n                    pil_faceid_image_2 = gr.Image(type=\"pil\", label=\"Second ID\")\n                    with 
gr.Accordion(\"Second ID Supplements\", open=False):\n                        with gr.Row():\n                            pil_faceid_supp_images_2 = gr.File(file_count=\"multiple\", file_types=[\"image\"],\n                                                               type=\"binary\", label=\"Additional ID Images\")\n                        with gr.Row():\n                            with gr.Column(scale=1, min_width=100):\n                                pil_faceid_mix_image_2_1 = gr.Image(type=\"pil\", label=\"Mix ID 1\")\n                                mix_scale_2_1 = gr.Slider(label=\"Mix Scale 1\", minimum=0.0, maximum=1.0, step=0.01,\n                                                          value=0.0)\n                            with gr.Column(scale=1, min_width=100):\n                                pil_faceid_mix_image_2_2 = gr.Image(type=\"pil\", label=\"Mix ID 2\")\n                                mix_scale_2_2 = gr.Slider(label=\"Mix Scale 2\", minimum=0.0, maximum=1.0, step=0.01,\n                                                          value=0.0)\n\n            with gr.Row():\n                example_output = gr.Image(type=\"pil\", label=\"(Example Output)\", visible=False)\n                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id=\"gallery\", columns=4, preview=True,\n                                            format=\"png\")\n    with gr.Row():\n        examples = [\n            [\n                \"The two female models, fair-skinned, wore a white V-neck short-sleeved top with a light smile on the corners of their mouths. The background was off-white.\",\n                \"assets/examples/2-stylegan2-ffhq-0100.png\",\n                \"assets/examples/2-stylegan2-ffhq-0293.png\",\n                \"assets/examples/2-output-1.png\",\n            ],\n        ]\n        gr.Examples(\n            label=\"Examples\",\n            examples=examples,\n            inputs=[prompt, pil_faceid_image_1, pil_faceid_image_2, example_output],\n        )\n    ips = [\n        pil_faceid_image_1, pil_faceid_supp_images_1,\n        pil_faceid_mix_image_1_1, mix_scale_1_1,\n        pil_faceid_mix_image_1_2, mix_scale_1_2,\n        pil_faceid_image_2, pil_faceid_supp_images_2,\n        pil_faceid_mix_image_2_1, mix_scale_2_1,\n        pil_faceid_mix_image_2_2, mix_scale_2_2,\n        faceid_scale, face_structure_scale,\n        prompt, negative_prompt,\n        num_samples, seed,\n        image_resolution,\n        inference_steps,\n    ]\n    run_button.click(fn=text_to_multi_id_generation_process, inputs=ips, outputs=[result_gallery])\n\n\ndef image_to_single_id_generation_block():\n    gr.Markdown(\"## Image-to-Single-ID Generation\")\n    gr.HTML(image_to_single_id_description)\n    gr.HTML(image_to_single_id_tips)\n    with gr.Row():\n        with gr.Column(scale=1, min_width=100):\n            run_button = gr.Button(value=\"Run\")\n            seed = gr.Slider(label=\"Seed (-1 indicates random)\", minimum=-1, maximum=2147483647, step=1,\n                             value=2147483647)\n            num_samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=2, step=1)\n            image_resolution = gr.Dropdown(choices=[\"768x512\", \"512x512\", \"512x768\"], value=\"512x512\",\n                                           label=\"Image Resolution (HxW)\")\n            inference_steps = gr.Slider(label=\"Steps\", minimum=1, maximum=100, value=25, step=1, visible=False)\n\n            ip_scale = gr.Slider(label=\"Reference Scale\", minimum=0.0, 
maximum=1.0, step=0.01, value=0.7)\n            faceid_scale = gr.Slider(label=\"Face ID Scale\", minimum=0.0, maximum=1.0, step=0.01, value=0.7)\n            face_structure_scale = gr.Slider(label=\"Face Structure Scale\", minimum=0.0, maximum=1.0, step=0.01,\n                                             value=0.3)\n\n        with gr.Column(scale=3, min_width=100):\n            with gr.Row(equal_height=False):\n                pil_ip_image = gr.Image(type=\"pil\", label=\"Portrait Reference\")\n                pil_faceid_image = gr.Image(type=\"pil\", label=\"ID Image\")\n                with gr.Accordion(\"ID Supplements\", open=True):\n                    with gr.Row():\n                        pil_faceid_supp_images = gr.File(file_count=\"multiple\", file_types=[\"image\"],\n                                                         type=\"binary\", label=\"Additional ID Images\")\n                    with gr.Row():\n                        with gr.Column(scale=1, min_width=100):\n                            pil_faceid_mix_image_1 = gr.Image(type=\"pil\", label=\"Mix ID 1\")\n                            mix_scale_1 = gr.Slider(label=\"Mix Scale 1\", minimum=0.0, maximum=1.0, step=0.01, value=0.0)\n                        with gr.Column(scale=1, min_width=100):\n                            pil_faceid_mix_image_2 = gr.Image(type=\"pil\", label=\"Mix ID 2\")\n                            mix_scale_2 = gr.Slider(label=\"Mix Scale 2\", minimum=0.0, maximum=1.0, step=0.01, value=0.0)\n            with gr.Row():\n                with gr.Column(scale=3, min_width=100):\n                    example_output = gr.Image(type=\"pil\", label=\"(Example Output)\", visible=False)\n                    result_gallery = gr.Gallery(label='Output', show_label=True, elem_id=\"gallery\", columns=4,\n                                                preview=True, format=\"png\")\n    with gr.Row():\n        examples = [\n            [\n                \"assets/examples/3-style-1.png\",\n                \"assets/examples/3-stylegan2-ffhq-0293.png\",\n                0.7,\n                0.3,\n                \"assets/examples/3-output-1.png\",\n            ],\n            [\n                \"assets/examples/3-style-1.png\",\n                \"assets/examples/3-stylegan2-ffhq-0293.png\",\n                0.6,\n                0.0,\n                \"assets/examples/3-output-2.png\",\n            ],\n            [\n                \"assets/examples/3-style-2.jpg\",\n                \"assets/examples/3-stylegan2-ffhq-0381.png\",\n                0.7,\n                0.3,\n                \"assets/examples/3-output-3.png\",\n            ],\n            [\n                \"assets/examples/3-style-3.jpg\",\n                \"assets/examples/3-stylegan2-ffhq-0381.png\",\n                0.6,\n                0.0,\n                \"assets/examples/3-output-4.png\",\n            ],\n        ]\n        gr.Examples(\n            label=\"Examples\",\n            examples=examples,\n            fn=lambda x, y, z, w, v: (x, y, z, w),\n            inputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale, example_output],\n            outputs=[pil_ip_image, pil_faceid_image, faceid_scale, face_structure_scale]\n        )\n    ips = [\n        pil_faceid_image, pil_faceid_supp_images,\n        pil_faceid_mix_image_1, mix_scale_1,\n        pil_faceid_mix_image_2, mix_scale_2,\n        faceid_scale, face_structure_scale,\n        pil_ip_image, ip_scale,\n        num_samples, seed, image_resolution,\n        
inference_steps,\n    ]\n    run_button.click(fn=image_to_single_id_generation_process, inputs=ips, outputs=[result_gallery])\n\n\nif __name__ == \"__main__\":\n    os.environ[\"no_proxy\"] = \"localhost,127.0.0.1,::1\"\n\n    title = r\"\"\"\n            <div style=\"text-align: center;\">\n                <h1> UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization </h1>\n                <div style=\"display: flex; justify-content: center; align-items: center; text-align: center;\">\n                    <a href=\"https://arxiv.org/pdf/2408.05939\"><img src=\"https://img.shields.io/badge/arXiv-2408.05939-red\"></a>\n                    &nbsp;\n                    <a href='https://aigcdesigngroup.github.io/UniPortrait-Page/'><img src='https://img.shields.io/badge/Project_Page-UniPortrait-green' alt='Project Page'></a>\n                    &nbsp;\n                    <a href=\"https://github.com/junjiehe96/UniPortrait\"><img src=\"https://img.shields.io/badge/Github-Code-blue\"></a>\n                </div>\n                </br>\n            </div>\n        \"\"\"\n\n    title_description = r\"\"\"\n        This is the <b>official 🤗 Gradio demo</b> for <a href='https://arxiv.org/pdf/2408.05939' target='_blank'><b>UniPortrait: A Unified Framework for Identity-Preserving Single- and Multi-Human Image Personalization</b></a>.<br>\n        The demo provides three capabilities: text-to-single-ID personalization, text-to-multi-ID personalization, and image-to-single-ID personalization. All of these are based on the <b>Stable Diffusion v1-5</b> model. Feel free to give them a try! 😊\n        \"\"\"\n\n    text_to_single_id_description = r\"\"\"🚀🚀🚀Quick start:<br>\n        1. Enter a text prompt (Chinese or English), Upload an image with a face, and Click the <b>Run</b> button. 🤗<br>\n        \"\"\"\n\n    text_to_single_id_tips = r\"\"\"💡💡💡Tips:<br>\n        1. Try to avoid creating too small faces, as this may lead to some artifacts. (Currently, the short side length of the generated image is limited to 512)<br>\n        2. It's a good idea to upload multiple reference photos of your face to improve the prompt and ID consistency. Additional references can be uploaded in the \"ID supplements\".<br>\n        3. The appropriate values of \"Face ID Scale\" and \"Face Structure Scale\" are important for balancing the ID and text alignment. We recommend using \"Face ID Scale\" (0.5~0.7) and \"Face Structure Scale\" (0.0~0.4).<br>\n        \"\"\"\n\n    text_to_multi_id_description = r\"\"\"🚀🚀🚀Quick start:<br>\n        1. Enter a text prompt (Chinese or English), Upload an image with a face in \"First ID\" and \"Second ID\" blocks respectively, and Click the <b>Run</b> button. 🤗<br>\n        \"\"\"\n\n    text_to_multi_id_tips = r\"\"\"💡💡💡Tips:<br>\n        1. Try to avoid creating too small faces, as this may lead to some artifacts. (Currently, the short side length of the generated image is limited to 512)<br>\n        2. It's a good idea to upload multiple reference photos of your face to improve the prompt and ID consistency. Additional references can be uploaded in the \"ID supplements\".<br>\n        3. The appropriate values of \"Face ID Scale\" and \"Face Structure Scale\" are important for balancing the ID and text alignment. 
We recommend using \"Face ID Scale\" (0.3~0.7) and \"Face Structure Scale\" (0.0~0.4).<br>\n        \"\"\"\n\n    image_to_single_id_description = r\"\"\"🚀🚀🚀Quick start: Upload an image as the portrait reference (can be any style), Upload a face image, and Click the <b>Run</b> button. 🤗<br>\"\"\"\n\n    image_to_single_id_tips = r\"\"\"💡💡💡Tips:<br>\n        1. Try to avoid creating too small faces, as this may lead to some artifacts. (Currently, the short side length of the generated image is limited to 512)<br>\n        2. It's a good idea to upload multiple reference photos of your face to improve ID consistency. Additional references can be uploaded in the \"ID supplements\".<br>\n        3. The appropriate values of \"Face ID Scale\" and \"Face Structure Scale\" are important for balancing the portrait reference and ID alignment. We recommend using \"Face ID Scale\" (0.5~0.7) and \"Face Structure Scale\" (0.0~0.4).<br>\n        \"\"\"\n\n    citation = r\"\"\"\n        ---\n        📝 **Citation**\n        <br>\n        If our work is helpful for your research or applications, please cite us via:\n        ```bibtex\n        @article{he2024uniportrait,\n          title={UniPortrait: A Unified Framework for Identity-Preserving Single-and Multi-Human Image Personalization},\n          author={He, Junjie and Geng, Yifeng and Bo, Liefeng},\n          journal={arXiv preprint arXiv:2408.05939},\n          year={2024}\n        }\n        ```\n        📧 **Contact**\n        <br>\n        If you have any questions, please feel free to open an issue or directly reach us out at <b>hejunjie1103@gmail.com</b>.\n        \"\"\"\n\n    block = gr.Blocks(title=\"UniPortrait\").queue()\n    with block:\n        gr.HTML(title)\n        gr.HTML(title_description)\n\n        with gr.TabItem(\"Text-to-Single-ID\"):\n            text_to_single_id_generation_block()\n\n        with gr.TabItem(\"Text-to-Multi-ID\"):\n            text_to_multi_id_generation_block()\n\n        with gr.TabItem(\"Image-to-Single-ID (Stylization)\"):\n            image_to_single_id_generation_block()\n\n        gr.Markdown(citation)\n\n    block.launch(server_name='0.0.0.0', share=False, server_port=port, allowed_paths=[\"/\"])\n"
  },
  {
    "path": "requirements.txt",
    "content": "diffusers\ngradio\nonnxruntime-gpu\ninsightface\ntorch\ntqdm\ntransformers\n"
  },
  {
    "path": "uniportrait/__init__.py",
    "content": ""
  },
  {
    "path": "uniportrait/curricular_face/__init__.py",
    "content": ""
  },
  {
    "path": "uniportrait/curricular_face/backbone/__init__.py",
    "content": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone\nfrom .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50,\n                         IR_SE_101, IR_SE_152, IR_SE_200)\nfrom .model_resnet import ResNet_50, ResNet_101, ResNet_152\n\n_model_dict = {\n    'ResNet_50': ResNet_50,\n    'ResNet_101': ResNet_101,\n    'ResNet_152': ResNet_152,\n    'IR_18': IR_18,\n    'IR_34': IR_34,\n    'IR_50': IR_50,\n    'IR_101': IR_101,\n    'IR_152': IR_152,\n    'IR_200': IR_200,\n    'IR_SE_50': IR_SE_50,\n    'IR_SE_101': IR_SE_101,\n    'IR_SE_152': IR_SE_152,\n    'IR_SE_200': IR_SE_200\n}\n\n\ndef get_model(key):\n    \"\"\" Get different backbone network by key,\n        support ResNet50, ResNet_101, ResNet_152\n        IR_18, IR_34, IR_50, IR_101, IR_152, IR_200,\n        IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200.\n    \"\"\"\n    if key in _model_dict.keys():\n        return _model_dict[key]\n    else:\n        raise KeyError('not support model {}'.format(key))\n"
  },
  {
    "path": "uniportrait/curricular_face/backbone/common.py",
    "content": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py\nimport torch.nn as nn\nfrom torch.nn import (Conv2d, Module, ReLU,\n                      Sigmoid)\n\n\ndef initialize_weights(modules):\n    \"\"\" Weight initilize, conv2d and linear is initialized with kaiming_normal\n    \"\"\"\n    for m in modules:\n        if isinstance(m, nn.Conv2d):\n            nn.init.kaiming_normal_(\n                m.weight, mode='fan_out', nonlinearity='relu')\n            if m.bias is not None:\n                m.bias.data.zero_()\n        elif isinstance(m, nn.BatchNorm2d):\n            m.weight.data.fill_(1)\n            m.bias.data.zero_()\n        elif isinstance(m, nn.Linear):\n            nn.init.kaiming_normal_(\n                m.weight, mode='fan_out', nonlinearity='relu')\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n\nclass Flatten(Module):\n    \"\"\" Flat tensor\n    \"\"\"\n\n    def forward(self, input):\n        return input.view(input.size(0), -1)\n\n\nclass SEModule(Module):\n    \"\"\" SE block\n    \"\"\"\n\n    def __init__(self, channels, reduction):\n        super(SEModule, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc1 = Conv2d(\n            channels,\n            channels // reduction,\n            kernel_size=1,\n            padding=0,\n            bias=False)\n\n        nn.init.xavier_uniform_(self.fc1.weight.data)\n\n        self.relu = ReLU(inplace=True)\n        self.fc2 = Conv2d(\n            channels // reduction,\n            channels,\n            kernel_size=1,\n            padding=0,\n            bias=False)\n\n        self.sigmoid = Sigmoid()\n\n    def forward(self, x):\n        module_input = x\n        x = self.avg_pool(x)\n        x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        x = self.sigmoid(x)\n\n        return module_input * x\n"
  },
  {
    "path": "uniportrait/curricular_face/backbone/model_irse.py",
    "content": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py\nfrom collections import namedtuple\n\nfrom torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,\n                      MaxPool2d, Module, PReLU, Sequential)\n\nfrom .common import Flatten, SEModule, initialize_weights\n\n\nclass BasicBlockIR(Module):\n    \"\"\" BasicBlock for IRNet\n    \"\"\"\n\n    def __init__(self, in_channel, depth, stride):\n        super(BasicBlockIR, self).__init__()\n        if in_channel == depth:\n            self.shortcut_layer = MaxPool2d(1, stride)\n        else:\n            self.shortcut_layer = Sequential(\n                Conv2d(in_channel, depth, (1, 1), stride, bias=False),\n                BatchNorm2d(depth))\n        self.res_layer = Sequential(\n            BatchNorm2d(in_channel),\n            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),\n            BatchNorm2d(depth), PReLU(depth),\n            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),\n            BatchNorm2d(depth))\n\n    def forward(self, x):\n        shortcut = self.shortcut_layer(x)\n        res = self.res_layer(x)\n\n        return res + shortcut\n\n\nclass BottleneckIR(Module):\n    \"\"\" BasicBlock with bottleneck for IRNet\n    \"\"\"\n\n    def __init__(self, in_channel, depth, stride):\n        super(BottleneckIR, self).__init__()\n        reduction_channel = depth // 4\n        if in_channel == depth:\n            self.shortcut_layer = MaxPool2d(1, stride)\n        else:\n            self.shortcut_layer = Sequential(\n                Conv2d(in_channel, depth, (1, 1), stride, bias=False),\n                BatchNorm2d(depth))\n        self.res_layer = Sequential(\n            BatchNorm2d(in_channel),\n            Conv2d(\n                in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False),\n            BatchNorm2d(reduction_channel), PReLU(reduction_channel),\n            Conv2d(\n                reduction_channel,\n                reduction_channel, (3, 3), (1, 1),\n                1,\n                bias=False), BatchNorm2d(reduction_channel),\n            PReLU(reduction_channel),\n            Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False),\n            BatchNorm2d(depth))\n\n    def forward(self, x):\n        shortcut = self.shortcut_layer(x)\n        res = self.res_layer(x)\n\n        return res + shortcut\n\n\nclass BasicBlockIRSE(BasicBlockIR):\n\n    def __init__(self, in_channel, depth, stride):\n        super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)\n        self.res_layer.add_module('se_block', SEModule(depth, 16))\n\n\nclass BottleneckIRSE(BottleneckIR):\n\n    def __init__(self, in_channel, depth, stride):\n        super(BottleneckIRSE, self).__init__(in_channel, depth, stride)\n        self.res_layer.add_module('se_block', SEModule(depth, 16))\n\n\nclass Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):\n    '''A named tuple describing a ResNet block.'''\n\n\ndef get_block(in_channel, depth, num_units, stride=2):\n    return [Bottleneck(in_channel, depth, stride)] + \\\n           [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]\n\n\ndef get_blocks(num_layers):\n    if num_layers == 18:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=2),\n            get_block(in_channel=64, depth=128, num_units=2),\n            
get_block(in_channel=128, depth=256, num_units=2),\n            get_block(in_channel=256, depth=512, num_units=2)\n        ]\n    elif num_layers == 34:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=3),\n            get_block(in_channel=64, depth=128, num_units=4),\n            get_block(in_channel=128, depth=256, num_units=6),\n            get_block(in_channel=256, depth=512, num_units=3)\n        ]\n    elif num_layers == 50:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=3),\n            get_block(in_channel=64, depth=128, num_units=4),\n            get_block(in_channel=128, depth=256, num_units=14),\n            get_block(in_channel=256, depth=512, num_units=3)\n        ]\n    elif num_layers == 100:\n        blocks = [\n            get_block(in_channel=64, depth=64, num_units=3),\n            get_block(in_channel=64, depth=128, num_units=13),\n            get_block(in_channel=128, depth=256, num_units=30),\n            get_block(in_channel=256, depth=512, num_units=3)\n        ]\n    elif num_layers == 152:\n        blocks = [\n            get_block(in_channel=64, depth=256, num_units=3),\n            get_block(in_channel=256, depth=512, num_units=8),\n            get_block(in_channel=512, depth=1024, num_units=36),\n            get_block(in_channel=1024, depth=2048, num_units=3)\n        ]\n    elif num_layers == 200:\n        blocks = [\n            get_block(in_channel=64, depth=256, num_units=3),\n            get_block(in_channel=256, depth=512, num_units=24),\n            get_block(in_channel=512, depth=1024, num_units=36),\n            get_block(in_channel=1024, depth=2048, num_units=3)\n        ]\n\n    return blocks\n\n\nclass Backbone(Module):\n\n    def __init__(self, input_size, num_layers, mode='ir'):\n        \"\"\" Args:\n            input_size: input_size of backbone\n            num_layers: num_layers of backbone\n            mode: support ir or ir_se\n        \"\"\"\n        super(Backbone, self).__init__()\n        assert input_size[0] in [112, 224], \\\n            'input_size should be [112, 112] or [224, 224]'\n        assert num_layers in [18, 34, 50, 100, 152, 200], \\\n            'num_layers should be 18, 34, 50, 100, 152 or 200'\n        assert mode in ['ir', 'ir_se'], \\\n            'mode should be ir or ir_se'\n        self.input_layer = Sequential(\n            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),\n            PReLU(64))\n        blocks = get_blocks(num_layers)\n        if num_layers <= 100:\n            if mode == 'ir':\n                unit_module = BasicBlockIR\n            elif mode == 'ir_se':\n                unit_module = BasicBlockIRSE\n            output_channel = 512\n        else:\n            if mode == 'ir':\n                unit_module = BottleneckIR\n            elif mode == 'ir_se':\n                unit_module = BottleneckIRSE\n            output_channel = 2048\n\n        if input_size[0] == 112:\n            self.output_layer = Sequential(\n                BatchNorm2d(output_channel), Dropout(0.4), Flatten(),\n                Linear(output_channel * 7 * 7, 512),\n                BatchNorm1d(512, affine=False))\n        else:\n            self.output_layer = Sequential(\n                BatchNorm2d(output_channel), Dropout(0.4), Flatten(),\n                Linear(output_channel * 14 * 14, 512),\n                BatchNorm1d(512, affine=False))\n\n        modules = []\n        mid_layer_indices = []  # [2, 15, 45, 48], total 49 layers for IR101\n        for block 
in blocks:\n            if len(mid_layer_indices) == 0:\n                mid_layer_indices.append(len(block) - 1)\n            else:\n                mid_layer_indices.append(len(block) + mid_layer_indices[-1])\n            for bottleneck in block:\n                modules.append(\n                    unit_module(bottleneck.in_channel, bottleneck.depth,\n                                bottleneck.stride))\n        self.body = Sequential(*modules)\n        self.mid_layer_indices = mid_layer_indices[-4:]\n\n        initialize_weights(self.modules())\n\n    def forward(self, x, return_mid_feats=False):\n        x = self.input_layer(x)\n        if not return_mid_feats:\n            x = self.body(x)\n            x = self.output_layer(x)\n            return x\n        else:\n            out_feats = []\n            for idx, module in enumerate(self.body):\n                x = module(x)\n                if idx in self.mid_layer_indices:\n                    out_feats.append(x)\n            x = self.output_layer(x)\n            return x, out_feats\n\n\ndef IR_18(input_size):\n    \"\"\" Constructs a ir-18 model.\n    \"\"\"\n    model = Backbone(input_size, 18, 'ir')\n\n    return model\n\n\ndef IR_34(input_size):\n    \"\"\" Constructs a ir-34 model.\n    \"\"\"\n    model = Backbone(input_size, 34, 'ir')\n\n    return model\n\n\ndef IR_50(input_size):\n    \"\"\" Constructs a ir-50 model.\n    \"\"\"\n    model = Backbone(input_size, 50, 'ir')\n\n    return model\n\n\ndef IR_101(input_size):\n    \"\"\" Constructs a ir-101 model.\n    \"\"\"\n    model = Backbone(input_size, 100, 'ir')\n\n    return model\n\n\ndef IR_152(input_size):\n    \"\"\" Constructs a ir-152 model.\n    \"\"\"\n    model = Backbone(input_size, 152, 'ir')\n\n    return model\n\n\ndef IR_200(input_size):\n    \"\"\" Constructs a ir-200 model.\n    \"\"\"\n    model = Backbone(input_size, 200, 'ir')\n\n    return model\n\n\ndef IR_SE_50(input_size):\n    \"\"\" Constructs a ir_se-50 model.\n    \"\"\"\n    model = Backbone(input_size, 50, 'ir_se')\n\n    return model\n\n\ndef IR_SE_101(input_size):\n    \"\"\" Constructs a ir_se-101 model.\n    \"\"\"\n    model = Backbone(input_size, 100, 'ir_se')\n\n    return model\n\n\ndef IR_SE_152(input_size):\n    \"\"\" Constructs a ir_se-152 model.\n    \"\"\"\n    model = Backbone(input_size, 152, 'ir_se')\n\n    return model\n\n\ndef IR_SE_200(input_size):\n    \"\"\" Constructs a ir_se-200 model.\n    \"\"\"\n    model = Backbone(input_size, 200, 'ir_se')\n\n    return model\n"
  },
  {
    "path": "uniportrait/curricular_face/backbone/model_resnet.py",
    "content": "# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at\n# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py\nimport torch.nn as nn\nfrom torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,\n                      MaxPool2d, Module, ReLU, Sequential)\n\nfrom .common import initialize_weights\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\" 3x3 convolution with padding\n    \"\"\"\n    return Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=1,\n        bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    \"\"\" 1x1 convolution\n    \"\"\"\n    return Conv2d(\n        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass Bottleneck(Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = conv1x1(inplanes, planes)\n        self.bn1 = BatchNorm2d(planes)\n        self.conv2 = conv3x3(planes, planes, stride)\n        self.bn2 = BatchNorm2d(planes)\n        self.conv3 = conv1x1(planes, planes * self.expansion)\n        self.bn3 = BatchNorm2d(planes * self.expansion)\n        self.relu = ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(Module):\n    \"\"\" ResNet backbone\n    \"\"\"\n\n    def __init__(self, input_size, block, layers, zero_init_residual=True):\n        \"\"\" Args:\n            input_size: input_size of backbone\n            block: block function\n            layers: layers in each block\n        \"\"\"\n        super(ResNet, self).__init__()\n        assert input_size[0] in [112, 224], \\\n            'input_size should be [112, 112] or [224, 224]'\n        self.inplanes = 64\n        self.conv1 = Conv2d(\n            3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n        self.bn1 = BatchNorm2d(64)\n        self.relu = ReLU(inplace=True)\n        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n\n        self.bn_o1 = BatchNorm2d(2048)\n        self.dropout = Dropout()\n        if input_size[0] == 112:\n            self.fc = Linear(2048 * 4 * 4, 512)\n        else:\n            self.fc = Linear(2048 * 7 * 7, 512)\n        self.bn_o2 = BatchNorm1d(512)\n\n        initialize_weights(self.modules)\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            
downsample = Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.bn_o1(x)\n        x = self.dropout(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n        x = self.bn_o2(x)\n\n        return x\n\n\ndef ResNet_50(input_size, **kwargs):\n    \"\"\" Constructs a ResNet-50 model.\n    \"\"\"\n    model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)\n\n    return model\n\n\ndef ResNet_101(input_size, **kwargs):\n    \"\"\" Constructs a ResNet-101 model.\n    \"\"\"\n    model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)\n\n    return model\n\n\ndef ResNet_152(input_size, **kwargs):\n    \"\"\" Constructs a ResNet-152 model.\n    \"\"\"\n    model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)\n\n    return model\n"
  },
  {
    "path": "uniportrait/curricular_face/inference.py",
    "content": "import glob\nimport os\n\nimport cv2\nimport numpy as np\nimport torch\nfrom tqdm.auto import tqdm\n\nfrom .backbone import get_model\n\n\n@torch.no_grad()\ndef inference(name, weight, src_norm_dir):\n    face_model = get_model(name)([112, 112])\n    face_model.load_state_dict(torch.load(weight, map_location=\"cpu\"))\n    face_model = face_model.to(\"cpu\")\n    face_model.eval()\n\n    id2src_norm = {}\n    for src_id in sorted(list(os.listdir(src_norm_dir))):\n        id2src_norm[src_id] = sorted(list(glob.glob(f\"{os.path.join(src_norm_dir, src_id)}/*\")))\n\n    total_sims = []\n    for id_name in tqdm(id2src_norm):\n        src_face_embeddings = []\n        for src_img_path in id2src_norm[id_name]:\n            src_img = cv2.imread(src_img_path)\n            src_img = cv2.resize(src_img, (112, 112))\n            src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)\n            src_img = np.transpose(src_img, (2, 0, 1))\n            src_img = torch.from_numpy(src_img).unsqueeze(0).float()\n            src_img.div_(255).sub_(0.5).div_(0.5)\n            embedding = face_model(src_img).detach().cpu().numpy()[0]\n            embedding = embedding / np.linalg.norm(embedding)\n            src_face_embeddings.append(embedding)  # 512\n\n        num = len(src_face_embeddings)\n        src_face_embeddings = np.stack(src_face_embeddings)  # n, 512\n        sim = src_face_embeddings @ src_face_embeddings.T  # n, n\n        mean_sim = (np.sum(sim) - num * 1.0) / ((num - 1) * num)\n        print(f\"{id_name}: {mean_sim}\")\n        total_sims.append(mean_sim)\n\n    return np.mean(total_sims)\n\n\nif __name__ == \"__main__\":\n    name = 'IR_101'\n    weight = \"models/glint360k_curricular_face_r101_backbone.bin\"\n    src_norm_dir = \"/disk1/hejunjie.hjj/data/normface-AFD-id-20\"\n    mean_sim = inference(name, weight, src_norm_dir)\n    print(f\"total: {mean_sim:.4f}\")  # total: 0.6299\n"
  },
  {
    "path": "uniportrait/inversion.py",
    "content": "# modified from https://github.com/google/style-aligned/blob/main/inversion.py\n\nfrom __future__ import annotations\n\nfrom typing import Callable\n\nimport numpy as np\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom tqdm import tqdm\n\nT = torch.Tensor\nInversionCallback = Callable[[StableDiffusionPipeline, int, T, dict[str, T]], dict[str, T]]\n\n\ndef _encode_text_with_negative(model: StableDiffusionPipeline, prompt: str) -> tuple[dict[str, T], T]:\n    device = model._execution_device\n    prompt_embeds = model._encode_prompt(\n        prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=True,\n        negative_prompt=\"\")\n    return prompt_embeds\n\n\ndef _encode_image(model: StableDiffusionPipeline, image: np.ndarray) -> T:\n    model.vae.to(dtype=torch.float32)\n    image = torch.from_numpy(image).float() / 255.\n    image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)\n    latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor\n    model.vae.to(dtype=torch.float16)\n    return latent\n\n\ndef _next_step(model: StableDiffusionPipeline, model_output: T, timestep: int, sample: T) -> T:\n    timestep, next_timestep = min(\n        timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep\n    alpha_prod_t = model.scheduler.alphas_cumprod[\n        int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod\n    alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]\n    beta_prod_t = 1 - alpha_prod_t\n    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5\n    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output\n    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction\n    return next_sample\n\n\ndef _get_noise_pred(model: StableDiffusionPipeline, latent: T, t: T, context: T, guidance_scale: float):\n    latents_input = torch.cat([latent] * 2)\n    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)[\"sample\"]\n    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)\n    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)\n    # latents = next_step(model, noise_pred, t, latent)\n    return noise_pred\n\n\ndef _ddim_loop(model: StableDiffusionPipeline, z0, prompt, guidance_scale) -> T:\n    all_latent = [z0]\n    text_embedding = _encode_text_with_negative(model, prompt)\n    image_embedding = torch.zeros_like(text_embedding[:, :1]).repeat(1, 4, 1)  # for ip embedding\n    text_embedding = torch.cat([text_embedding, image_embedding], dim=1)\n    latent = z0.clone().detach().half()\n    for i in tqdm(range(model.scheduler.num_inference_steps)):\n        t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]\n        noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale)\n        latent = _next_step(model, noise_pred, t, latent)\n        all_latent.append(latent)\n    return torch.cat(all_latent).flip(0)\n\n\ndef make_inversion_callback(zts, offset: int = 0) -> [T, InversionCallback]:\n    def callback_on_step_end(pipeline: StableDiffusionPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[\n        str, T]:\n        latents = callback_kwargs['latents']\n        latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)\n        return {'latents': latents}\n\n    return 
zts[offset], callback_on_step_end\n\n\n@torch.no_grad()\ndef ddim_inversion(model: StableDiffusionPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int,\n                   guidance_scale, ) -> T:\n    z0 = _encode_image(model, x0)\n    model.scheduler.set_timesteps(num_inference_steps, device=z0.device)\n    zs = _ddim_loop(model, z0, prompt, guidance_scale)\n    return zs\n"
  },
  {
    "path": "uniportrait/resampler.py",
    "content": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py\n\nimport math\n\nimport torch\nimport torch.nn as nn\n\n\n# FFN\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head ** -0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents, attention_mask=None):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latents (torch.Tensor): latent features\n                shape (b, n2, D)\n            attention_mask (torch.Tensor): attention mask\n                shape (b, n1, 1)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        if attention_mask is not None:\n            attention_mask = attention_mask.transpose(1, 2)  # (b, 1, n1)\n            attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :, :1]).repeat(1, 1, l)],\n                                       dim=2)  # b, 1, n1+n2\n            attention_mask = (attention_mask - 1.) * 100.  
# 0 means kept and -100 means dropped\n            attention_mask = attention_mask.unsqueeze(1)\n            weight = weight + attention_mask  # b, h, n2, n1+n2\n\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass UniPortraitFaceIDResampler(torch.nn.Module):\n    def __init__(\n            self,\n            intrinsic_id_embedding_dim=512,\n            structure_embedding_dim=64 + 128 + 256 + 1280,\n            num_tokens=16,\n            depth=6,\n            dim=768,\n            dim_head=64,\n            heads=12,\n            ff_mult=4,\n            output_dim=768,\n    ):\n        super().__init__()\n\n        self.latents = torch.nn.Parameter(torch.randn(1, num_tokens, dim) / dim ** 0.5)\n\n        self.proj_id = torch.nn.Sequential(\n            torch.nn.Linear(intrinsic_id_embedding_dim, intrinsic_id_embedding_dim * 2),\n            torch.nn.GELU(),\n            torch.nn.Linear(intrinsic_id_embedding_dim * 2, dim),\n        )\n        self.proj_clip = torch.nn.Sequential(\n            torch.nn.Linear(structure_embedding_dim, structure_embedding_dim * 2),\n            torch.nn.GELU(),\n            torch.nn.Linear(structure_embedding_dim * 2, dim),\n        )\n\n        self.layers = torch.nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                torch.nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                        FeedForward(dim=dim, mult=ff_mult),\n                    ]\n                )\n            )\n\n        self.proj_out = torch.nn.Linear(dim, output_dim)\n        self.norm_out = torch.nn.LayerNorm(output_dim)\n\n    def forward(\n            self,\n            intrinsic_id_embeds,\n            structure_embeds,\n            structure_scale=1.0,\n            intrinsic_id_attention_mask=None,\n            structure_attention_mask=None\n    ):\n\n        latents = self.latents.repeat(intrinsic_id_embeds.size(0), 1, 1)\n\n        intrinsic_id_embeds = self.proj_id(intrinsic_id_embeds)\n        structure_embeds = self.proj_clip(structure_embeds)\n\n        for attn1, attn2, ff in self.layers:\n            latents = attn1(intrinsic_id_embeds, latents, intrinsic_id_attention_mask) + latents\n            latents = structure_scale * attn2(structure_embeds, latents, structure_attention_mask) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n"
  },
  {
    "path": "uniportrait/uniportrait_attention_processor.py",
    "content": "# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers.models.lora import LoRALinearLayer\n\n\nclass AttentionArgs(object):\n    def __init__(self) -> None:\n        # ip condition\n        self.ip_scale = 0.0\n        self.ip_mask = None  # ip attention mask\n\n        # faceid condition\n        self.lora_scale = 0.0  # lora for single faceid\n        self.multi_id_lora_scale = 0.0  # lora for multiple faceids\n        self.faceid_scale = 0.0\n        self.num_faceids = 0\n        self.faceid_mask = None  # faceid attention mask; if not None, it will override the routing map\n\n        # style aligned\n        self.enable_share_attn: bool = False\n        self.adain_queries_and_keys: bool = False\n        self.shared_score_scale: float = 1.0\n        self.shared_score_shift: float = 0.0\n\n    def reset(self):\n        # ip condition\n        self.ip_scale = 0.0\n        self.ip_mask = None  # ip attention mask\n\n        # faceid condition\n        self.lora_scale = 0.0  # lora for single faceid\n        self.multi_id_lora_scale = 0.0  # lora for multiple faceids\n        self.faceid_scale = 0.0\n        self.num_faceids = 0\n        self.faceid_mask = None  # faceid attention mask; if not None, it will override the routing map\n\n        # style aligned\n        self.enable_share_attn: bool = False\n        self.adain_queries_and_keys: bool = False\n        self.shared_score_scale: float = 1.0\n        self.shared_score_shift: float = 0.0\n\n    def __repr__(self):\n        indent_str = '    '\n        s = f\",\\n{indent_str}\".join(f\"{attr}={value}\" for attr, value in vars(self).items())\n        return self.__class__.__name__ + '(' + f'\\n{indent_str}' + s + ')'\n\n\nattn_args = AttentionArgs()\n\n\ndef expand_first(feat, scale=1., ):\n    b = feat.shape[0]\n    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)\n    if scale == 1:\n        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])\n    else:\n        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)\n        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)\n    return feat_style.reshape(*feat.shape)\n\n\ndef concat_first(feat, dim=2, scale=1.):\n    feat_style = expand_first(feat, scale=scale)\n    return torch.cat((feat, feat_style), dim=dim)\n\n\ndef calc_mean_std(feat, eps: float = 1e-5):\n    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()\n    feat_mean = feat.mean(dim=-2, keepdims=True)\n    return feat_mean, feat_std\n\n\ndef adain(feat):\n    feat_mean, feat_std = calc_mean_std(feat)\n    feat_style_mean = expand_first(feat_mean)\n    feat_style_std = expand_first(feat_std)\n    feat = (feat - feat_mean) / feat_std\n    feat = feat * feat_style_std + feat_style_mean\n    return feat\n\n\nclass UniPortraitLoRAAttnProcessor2_0(nn.Module):\n\n    def __init__(\n            self,\n            hidden_size=None,\n            cross_attention_dim=None,\n            rank=128,\n            network_alpha=None,\n    ):\n        super().__init__()\n\n        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_out_lora = LoRALinearLayer(hidden_size, 
hidden_size, rank, network_alpha)\n\n        self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n\n    def __call__(\n            self,\n            attn,\n            hidden_states,\n            encoder_hidden_states=None,\n            attention_mask=None,\n            temb=None,\n            *args,\n            **kwargs,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n        if attn_args.lora_scale > 0.0:\n            query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)\n            key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)\n            value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)\n        elif attn_args.multi_id_lora_scale > 0.0:\n            query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)\n            key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)\n            value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn_args.enable_share_attn:\n            if attn_args.adain_queries_and_keys:\n                query = adain(query)\n                key = adain(key)\n            key = concat_first(key, -2, scale=attn_args.shared_score_scale)\n            value = concat_first(value, -2)\n            if attn_args.shared_score_shift != 0:\n                attention_mask = torch.zeros_like(key[:, :, :, :1]).transpose(-1, -2)  # b, h, 1, k\n                attention_mask[:, :, :, query.shape[2]:] += attn_args.shared_score_shift\n                hidden_states = F.scaled_dot_product_attention(\n                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale\n   
             )\n            else:\n                hidden_states = F.scaled_dot_product_attention(\n                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale\n                )\n        else:\n            hidden_states = F.scaled_dot_product_attention(\n                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale\n            )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        output_hidden_states = attn.to_out[0](hidden_states)\n        if attn_args.lora_scale > 0.0:\n            output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)\n        elif attn_args.multi_id_lora_scale > 0.0:\n            output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(\n                hidden_states)\n        hidden_states = output_hidden_states\n\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass UniPortraitLoRAIPAttnProcessor2_0(nn.Module):\n\n    def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None,\n                 num_ip_tokens=4, num_faceid_tokens=16):\n        super().__init__()\n\n        self.num_ip_tokens = num_ip_tokens\n        self.num_faceid_tokens = num_faceid_tokens\n\n        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n        self.to_k_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_faceid = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n        self.to_q_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.to_k_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_multi_id_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_out_multi_id_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n\n        self.to_q_router = nn.Sequential(\n            nn.Linear(hidden_size, hidden_size * 2),\n            nn.GELU(),\n            nn.Linear(hidden_size * 2, hidden_size, bias=False),\n        )\n        self.to_k_router = nn.Sequential(\n            nn.Linear(cross_attention_dim or hidden_size, (cross_attention_dim or hidden_size) * 2),\n            nn.GELU(),\n            nn.Linear((cross_attention_dim or hidden_size) * 2, hidden_size, bias=False),\n        )\n        
self.aggr_router = nn.Linear(num_faceid_tokens, 1)\n\n    def __call__(\n            self,\n            attn,\n            hidden_states,\n            encoder_hidden_states=None,\n            attention_mask=None,\n            temb=None,\n            *args,\n            **kwargs,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # split hidden states\n            faceid_end = encoder_hidden_states.shape[1]\n            ip_end = faceid_end - self.num_faceid_tokens * attn_args.num_faceids\n            text_end = ip_end - self.num_ip_tokens\n\n            prompt_hidden_states = encoder_hidden_states[:, :text_end]\n            ip_hidden_states = encoder_hidden_states[:, text_end: ip_end]\n            faceid_hidden_states = encoder_hidden_states[:, ip_end: faceid_end]\n\n            encoder_hidden_states = prompt_hidden_states\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        # for router\n        if attn_args.num_faceids > 1:\n            router_query = self.to_q_router(hidden_states)  # bs, s*s, dim\n            router_hidden_states = faceid_hidden_states.reshape(batch_size, attn_args.num_faceids,\n                                                                self.num_faceid_tokens, -1)  # bs, num, id_tokens, d\n            router_hidden_states = self.aggr_router(router_hidden_states.transpose(-1, -2)).squeeze(-1)  # bs, num, d\n            router_key = self.to_k_router(router_hidden_states)  # bs, num, dim\n            router_logits = torch.bmm(router_query, router_key.transpose(-1, -2))  # bs, s*s, num\n            index = router_logits.max(dim=-1, keepdim=True)[1]\n            routing_map = torch.zeros_like(router_logits).scatter_(-1, index, 1.0)\n            routing_map = routing_map.transpose(1, 2).unsqueeze(-1)  # bs, num, s*s, 1\n        else:\n            routing_map = hidden_states.new_ones(size=(1, 1, hidden_states.shape[1], 1))\n\n        # for text\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n        if attn_args.lora_scale > 0.0:\n            query = query + attn_args.lora_scale * self.to_q_lora(hidden_states)\n            key = key + attn_args.lora_scale * self.to_k_lora(encoder_hidden_states)\n            value = value + attn_args.lora_scale * self.to_v_lora(encoder_hidden_states)\n        elif attn_args.multi_id_lora_scale > 0.0:\n            query = query + attn_args.multi_id_lora_scale * self.to_q_multi_id_lora(hidden_states)\n            key = key + attn_args.multi_id_lora_scale * self.to_k_multi_id_lora(encoder_hidden_states)\n            
value = value + attn_args.multi_id_lora_scale * self.to_v_multi_id_lora(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale\n        )\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # for ip-adapter\n        if attn_args.ip_scale > 0.0:\n            ip_key = self.to_k_ip(ip_hidden_states)\n            ip_value = self.to_v_ip(ip_hidden_states)\n\n            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n            ip_hidden_states = F.scaled_dot_product_attention(\n                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale\n            )\n            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n            ip_hidden_states = ip_hidden_states.to(query.dtype)\n\n            if attn_args.ip_mask is not None:\n                ip_mask = attn_args.ip_mask\n                h, w = ip_mask.shape[-2:]\n                ratio = (h * w / query.shape[2]) ** 0.5\n                ip_mask = torch.nn.functional.interpolate(ip_mask, scale_factor=1 / ratio,\n                                                          mode='nearest').reshape(\n                    [1, -1, 1])\n                ip_hidden_states = ip_hidden_states * ip_mask\n\n            if attn_args.enable_share_attn:\n                ip_hidden_states[0] = 0.\n                ip_hidden_states[batch_size // 2] = 0.\n        else:\n            ip_hidden_states = torch.zeros_like(hidden_states)\n\n        # for faceid-adapter\n        if attn_args.faceid_scale > 0.0:\n            faceid_key = self.to_k_faceid(faceid_hidden_states)\n            faceid_value = self.to_v_faceid(faceid_hidden_states)\n\n            faceid_query = query[:, None].expand(-1, attn_args.num_faceids, -1, -1,\n                                                 -1)  # 2*bs, num, heads, s*s, dim/heads\n            faceid_key = faceid_key.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,\n                                         head_dim).transpose(2, 3)\n            faceid_value = faceid_value.view(batch_size, attn_args.num_faceids, self.num_faceid_tokens, attn.heads,\n                                             head_dim).transpose(2, 3)\n\n            faceid_hidden_states = F.scaled_dot_product_attention(\n                faceid_query, faceid_key, faceid_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale\n            )  # 2*bs, num, heads, s*s, dim/heads\n\n            faceid_hidden_states = faceid_hidden_states.transpose(2, 3).reshape(batch_size, attn_args.num_faceids, -1,\n                                                                                attn.heads * head_dim)\n            faceid_hidden_states = faceid_hidden_states.to(query.dtype)  # 2*bs, num, s*s, dim\n\n            if attn_args.faceid_mask is not None:\n                
faceid_mask = attn_args.faceid_mask  # 1, num, h, w\n                h, w = faceid_mask.shape[-2:]\n                ratio = (h * w / query.shape[2]) ** 0.5\n                faceid_mask = F.interpolate(faceid_mask, scale_factor=1 / ratio,\n                                            mode='bilinear').flatten(2).unsqueeze(-1)  # 1, num, s*s, 1\n                faceid_mask = faceid_mask / faceid_mask.sum(1, keepdim=True).clip(min=1e-3)  # 1, num, s*s, 1\n                faceid_hidden_states = (faceid_mask * faceid_hidden_states).sum(1)  # 2*bs, s*s, dim\n            else:\n                faceid_hidden_states = (routing_map * faceid_hidden_states).sum(1)  # 2*bs, s*s, dim\n\n            if attn_args.enable_share_attn:\n                faceid_hidden_states[0] = 0.\n                faceid_hidden_states[batch_size // 2] = 0.\n        else:\n            faceid_hidden_states = torch.zeros_like(hidden_states)\n\n        hidden_states = hidden_states + \\\n                        attn_args.ip_scale * ip_hidden_states + \\\n                        attn_args.faceid_scale * faceid_hidden_states\n\n        # linear proj\n        output_hidden_states = attn.to_out[0](hidden_states)\n        if attn_args.lora_scale > 0.0:\n            output_hidden_states = output_hidden_states + attn_args.lora_scale * self.to_out_lora(hidden_states)\n        elif attn_args.multi_id_lora_scale > 0.0:\n            output_hidden_states = output_hidden_states + attn_args.multi_id_lora_scale * self.to_out_multi_id_lora(\n                hidden_states)\n        hidden_states = output_hidden_states\n\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\n# for controlnet\nclass UniPortraitCNAttnProcessor2_0:\n    def __init__(self, num_ip_tokens=4, num_faceid_tokens=16):\n\n        self.num_ip_tokens = num_ip_tokens\n        self.num_faceid_tokens = num_faceid_tokens\n\n    def __call__(\n            self,\n            attn,\n            hidden_states,\n            encoder_hidden_states=None,\n            attention_mask=None,\n            temb=None,\n            *args,\n            **kwargs,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = 
attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            text_end = encoder_hidden_states.shape[1] - self.num_faceid_tokens * attn_args.num_faceids \\\n                       - self.num_ip_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :text_end]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale\n        )\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n"
  },
  {
    "path": "uniportrait/uniportrait_pipeline.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom diffusers import ControlNetModel\nfrom diffusers.pipelines.controlnet import MultiControlNetModel\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom .curricular_face.backbone import get_model\nfrom .resampler import UniPortraitFaceIDResampler\nfrom .uniportrait_attention_processor import UniPortraitCNAttnProcessor2_0 as UniPortraitCNAttnProcessor\nfrom .uniportrait_attention_processor import UniPortraitLoRAAttnProcessor2_0 as UniPortraitLoRAAttnProcessor\nfrom .uniportrait_attention_processor import UniPortraitLoRAIPAttnProcessor2_0 as UniPortraitLoRAIPAttnProcessor\n\n\nclass ImageProjModel(nn.Module):\n    \"\"\"Projection Model\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.cross_attention_dim = cross_attention_dim\n        self.clip_extra_context_tokens = clip_extra_context_tokens\n        self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = nn.LayerNorm(cross_attention_dim)\n\n    def forward(self, image_embeds):\n        embeds = image_embeds  # b, c\n        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens,\n                                                              self.cross_attention_dim)\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\nclass UniPortraitPipeline:\n\n    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt=None, face_backbone_ckpt=None, uniportrait_faceid_ckpt=None,\n                 uniportrait_router_ckpt=None, num_ip_tokens=4, num_faceid_tokens=16,\n                 lora_rank=128, device=torch.device(\"cuda\"), torch_dtype=torch.float16):\n\n        self.image_encoder_path = image_encoder_path\n        self.ip_ckpt = ip_ckpt\n        self.uniportrait_faceid_ckpt = uniportrait_faceid_ckpt\n        self.uniportrait_router_ckpt = uniportrait_router_ckpt\n\n        self.num_ip_tokens = num_ip_tokens\n        self.num_faceid_tokens = num_faceid_tokens\n        self.lora_rank = lora_rank\n\n        self.device = device\n        self.torch_dtype = torch_dtype\n\n        self.pipe = sd_pipe.to(self.device)\n\n        # load clip image encoder\n        self.clip_image_processor = CLIPImageProcessor(size={\"shortest_edge\": 224}, do_center_crop=False,\n                                                       use_square_size=True)\n        self.clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(\n            self.device, dtype=self.torch_dtype)\n        # load face backbone\n        self.facerecog_model = get_model(\"IR_101\")([112, 112])\n        self.facerecog_model.load_state_dict(torch.load(face_backbone_ckpt, map_location=\"cpu\"))\n        self.facerecog_model = self.facerecog_model.to(self.device, dtype=torch_dtype)\n        self.facerecog_model.eval()\n        # image proj model\n        self.image_proj_model = self.init_image_proj()\n        # faceid proj model\n        self.faceid_proj_model = self.init_faceid_proj()\n        # set uniportrait and ip adapter\n        self.set_uniportrait_and_ip_adapter()\n        # load uniportrait and ip adapter\n        self.load_uniportrait_and_ip_adapter()\n\n    def init_image_proj(self):\n        image_proj_model = ImageProjModel(\n            
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,\n            clip_embeddings_dim=self.clip_image_encoder.config.projection_dim,\n            clip_extra_context_tokens=self.num_ip_tokens,\n        ).to(self.device, dtype=self.torch_dtype)\n        return image_proj_model\n\n    def init_faceid_proj(self):\n        faceid_proj_model = UniPortraitFaceIDResampler(\n            intrinsic_id_embedding_dim=512,\n            structure_embedding_dim=64 + 128 + 256 + self.clip_image_encoder.config.hidden_size,\n            num_tokens=16, depth=6,\n            dim=self.pipe.unet.config.cross_attention_dim, dim_head=64,\n            heads=12, ff_mult=4,\n            output_dim=self.pipe.unet.config.cross_attention_dim\n        ).to(self.device, dtype=self.torch_dtype)\n        return faceid_proj_model\n\n    def set_uniportrait_and_ip_adapter(self):\n        unet = self.pipe.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None:\n                attn_procs[name] = UniPortraitLoRAAttnProcessor(\n                    hidden_size=hidden_size,\n                    cross_attention_dim=cross_attention_dim,\n                    rank=self.lora_rank,\n                ).to(self.device, dtype=self.torch_dtype).eval()\n            else:\n                attn_procs[name] = UniPortraitLoRAIPAttnProcessor(\n                    hidden_size=hidden_size,\n                    cross_attention_dim=cross_attention_dim,\n                    rank=self.lora_rank,\n                    num_ip_tokens=self.num_ip_tokens,\n                    num_faceid_tokens=self.num_faceid_tokens,\n                ).to(self.device, dtype=self.torch_dtype).eval()\n        unet.set_attn_processor(attn_procs)\n        if hasattr(self.pipe, \"controlnet\"):\n            if isinstance(self.pipe.controlnet, ControlNetModel):\n                self.pipe.controlnet.set_attn_processor(\n                    UniPortraitCNAttnProcessor(\n                        num_ip_tokens=self.num_ip_tokens,\n                        num_faceid_tokens=self.num_faceid_tokens,\n                    )\n                )\n            elif isinstance(self.pipe.controlnet, MultiControlNetModel):\n                for module in self.pipe.controlnet.nets:\n                    module.set_attn_processor(\n                        UniPortraitCNAttnProcessor(\n                            num_ip_tokens=self.num_ip_tokens,\n                            num_faceid_tokens=self.num_faceid_tokens,\n                        )\n                    )\n            else:\n                raise ValueError\n\n    def load_uniportrait_and_ip_adapter(self):\n        if self.ip_ckpt:\n            print(f\"loading from {self.ip_ckpt}...\")\n            state_dict = torch.load(self.ip_ckpt, map_location=\"cpu\")\n            self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=False)\n            
ip_layers = nn.ModuleList(self.pipe.unet.attn_processors.values())\n            ip_layers.load_state_dict(state_dict[\"ip_adapter\"], strict=False)\n\n        if self.uniportrait_faceid_ckpt:\n            print(f\"loading from {self.uniportrait_faceid_ckpt}...\")\n            state_dict = torch.load(self.uniportrait_faceid_ckpt, map_location=\"cpu\")\n            self.faceid_proj_model.load_state_dict(state_dict[\"faceid_proj\"], strict=True)\n            ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())\n            ip_layers.load_state_dict(state_dict[\"faceid_adapter\"], strict=False)\n\n            if self.uniportrait_router_ckpt:\n                print(f\"loading from {self.uniportrait_router_ckpt}...\")\n                state_dict = torch.load(self.uniportrait_router_ckpt, map_location=\"cpu\")\n                router_state_dict = {}\n                for k, v in state_dict[\"faceid_adapter\"].items():\n                    if \"lora.\" in k:\n                        router_state_dict[k.replace(\"lora.\", \"multi_id_lora.\")] = v\n                    elif \"router.\" in k:\n                        router_state_dict[k] = v\n                ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())\n                ip_layers.load_state_dict(router_state_dict, strict=False)\n\n    @torch.inference_mode()\n    def get_ip_embeds(self, pil_ip_image):\n        ip_image = self.clip_image_processor(images=pil_ip_image, return_tensors=\"pt\").pixel_values\n        ip_image = ip_image.to(self.device, dtype=self.torch_dtype)  # (b, 3, 224, 224), values being normalized\n        ip_embeds = self.clip_image_encoder(ip_image).image_embeds\n        ip_prompt_embeds = self.image_proj_model(ip_embeds)\n        uncond_ip_prompt_embeds = self.image_proj_model(torch.zeros_like(ip_embeds))\n        return ip_prompt_embeds, uncond_ip_prompt_embeds\n\n    @torch.inference_mode()\n    def get_single_faceid_embeds(self, pil_face_images, face_structure_scale):\n        face_clip_image = self.clip_image_processor(images=pil_face_images, return_tensors=\"pt\").pixel_values\n        face_clip_image = face_clip_image.to(self.device, dtype=self.torch_dtype)  # (b, 3, 224, 224)\n        face_clip_embeds = self.clip_image_encoder(\n            face_clip_image, output_hidden_states=True).hidden_states[-2][:, 1:]  # b, 256, 1280\n\n        OPENAI_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=self.device,\n                                        dtype=self.torch_dtype).reshape(-1, 1, 1)\n        OPENAI_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=self.device,\n                                       dtype=self.torch_dtype).reshape(-1, 1, 1)\n        facerecog_image = face_clip_image * OPENAI_CLIP_STD + OPENAI_CLIP_MEAN  # [0, 1]\n        facerecog_image = torch.clamp((facerecog_image - 0.5) / 0.5, -1, 1)  # [-1, 1]\n        facerecog_image = F.interpolate(facerecog_image, size=(112, 112), mode=\"bilinear\", align_corners=False)\n        facerecog_embeds = self.facerecog_model(facerecog_image, return_mid_feats=True)[1]\n\n        face_intrinsic_id_embeds = facerecog_embeds[-1]  # (b, 512, 7, 7)\n        face_intrinsic_id_embeds = face_intrinsic_id_embeds.flatten(2).permute(0, 2, 1)  # b, 49, 512\n\n        facerecog_structure_embeds = facerecog_embeds[:-1]  # (b, 64, 56, 56), (b, 128, 28, 28), (b, 256, 14, 14)\n        facerecog_structure_embeds = torch.cat([\n            F.interpolate(feat, size=(16, 16), mode=\"bilinear\", 
align_corners=False)\n            for feat in facerecog_structure_embeds], dim=1)  # b, 448, 16, 16\n        facerecog_structure_embeds = facerecog_structure_embeds.flatten(2).permute(0, 2, 1)  # b, 256, 448\n        face_structure_embeds = torch.cat([facerecog_structure_embeds, face_clip_embeds], dim=-1)  # b, 256, 1728\n\n        uncond_face_clip_embeds = self.clip_image_encoder(\n            torch.zeros_like(face_clip_image[:1]), output_hidden_states=True).hidden_states[-2][:, 1:]  # 1, 256, 1280\n        uncond_face_structure_embeds = torch.cat(\n            [torch.zeros_like(facerecog_structure_embeds[:1]), uncond_face_clip_embeds], dim=-1)  # 1, 256, 1728\n\n        faceid_prompt_embeds = self.faceid_proj_model(\n            face_intrinsic_id_embeds.flatten(0, 1).unsqueeze(0),\n            face_structure_embeds.flatten(0, 1).unsqueeze(0),\n            structure_scale=face_structure_scale,\n        )  # [b, 16, 768]\n\n        uncond_faceid_prompt_embeds = self.faceid_proj_model(\n            torch.zeros_like(face_intrinsic_id_embeds[:1]),\n            uncond_face_structure_embeds,\n            structure_scale=face_structure_scale,\n        )  # [1, 16, 768]\n\n        return faceid_prompt_embeds, uncond_faceid_prompt_embeds\n\n    def generate(\n            self,\n            prompt=None,\n            negative_prompt=None,\n            pil_ip_image=None,\n            cond_faceids=None,\n            face_structure_scale=0.0,\n            seed=-1,\n            guidance_scale=7.5,\n            num_inference_steps=30,\n            zT=None,\n            **kwargs,\n    ):\n        \"\"\"\n        Args:\n            prompt:\n            negative_prompt:\n            pil_ip_image:\n            cond_faceids: [\n                {\n                    \"refs\": [PIL.Image] or PIL.Image,\n                    (Optional) \"mix_refs\": [PIL.Image],\n                    (Optional) \"mix_scales\": [float],\n                },\n                ...\n            ]\n            face_structure_scale:\n            seed:\n            guidance_scale:\n            num_inference_steps:\n            zT:\n            **kwargs:\n        Returns:\n        \"\"\"\n\n        if seed is not None:\n            torch.manual_seed(seed)\n            torch.cuda.manual_seed_all(seed)\n\n        with torch.inference_mode():\n            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(\n                prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt)\n            num_prompts = prompt_embeds.shape[0]\n\n            if pil_ip_image is not None:\n                ip_prompt_embeds, uncond_ip_prompt_embeds = self.get_ip_embeds(pil_ip_image)\n                ip_prompt_embeds = ip_prompt_embeds.repeat(num_prompts, 1, 1)\n                uncond_ip_prompt_embeds = uncond_ip_prompt_embeds.repeat(num_prompts, 1, 1)\n            else:\n                ip_prompt_embeds = uncond_ip_prompt_embeds = \\\n                    torch.zeros_like(prompt_embeds[:, :1]).repeat(1, self.num_ip_tokens, 1)\n\n            prompt_embeds = torch.cat([prompt_embeds, ip_prompt_embeds], dim=1)\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_ip_prompt_embeds], dim=1)\n\n            if cond_faceids and len(cond_faceids) > 0:\n                all_faceid_prompt_embeds = []\n                all_uncond_faceid_prompt_embeds = []\n                for curr_faceid_info in cond_faceids:\n                    refs = 
curr_faceid_info[\"refs\"]\n                    faceid_prompt_embeds, uncond_faceid_prompt_embeds = \\\n                        self.get_single_faceid_embeds(refs, face_structure_scale)\n                    if \"mix_refs\" in curr_faceid_info:\n                        mix_refs = curr_faceid_info[\"mix_refs\"]\n                        mix_scales = curr_faceid_info[\"mix_scales\"]\n\n                        master_face_mix_scale = 1.0 - sum(mix_scales)\n                        faceid_prompt_embeds = faceid_prompt_embeds * master_face_mix_scale\n                        for mix_ref, mix_scale in zip(mix_refs, mix_scales):\n                            faceid_mix_prompt_embeds, _ = self.get_single_faceid_embeds(mix_ref, face_structure_scale)\n                            faceid_prompt_embeds = faceid_prompt_embeds + faceid_mix_prompt_embeds * mix_scale\n\n                    all_faceid_prompt_embeds.append(faceid_prompt_embeds)\n                    all_uncond_faceid_prompt_embeds.append(uncond_faceid_prompt_embeds)\n\n                faceid_prompt_embeds = torch.cat(all_faceid_prompt_embeds, dim=1)\n                uncond_faceid_prompt_embeds = torch.cat(all_uncond_faceid_prompt_embeds, dim=1)\n                faceid_prompt_embeds = faceid_prompt_embeds.repeat(num_prompts, 1, 1)\n                uncond_faceid_prompt_embeds = uncond_faceid_prompt_embeds.repeat(num_prompts, 1, 1)\n\n                prompt_embeds = torch.cat([prompt_embeds, faceid_prompt_embeds], dim=1)\n                negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_faceid_prompt_embeds], dim=1)\n\n        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None\n        if zT is not None:\n            h_, w_ = kwargs[\"image\"][0].shape[-2:]\n            latents = torch.randn(num_prompts, 4, h_ // 8, w_ // 8, device=self.device, generator=generator,\n                                  dtype=self.pipe.unet.dtype)\n            latents[0] = zT\n        else:\n            latents = None\n\n        images = self.pipe(\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            guidance_scale=guidance_scale,\n            num_inference_steps=num_inference_steps,\n            generator=generator,\n            latents=latents,\n            **kwargs,\n        ).images\n\n        return images\n"
  }
]