[
  {
    "path": ".gitignore",
    "content": "*.ckpt\n*.pt\n*.pyc\n*.safetensors\n\n__pycache__/\noutput/\ncheckpoints/\ntrain/\nconfigs/\n\n*.wav\n*.mp3\n*.gif\n*.jpg\n*.png\n*.log\n*.ckpt\n*.json\n\n*.csv\n*.txt\n*.bin\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    # Ruff version.\n    rev: v0.3.5\n    hooks:\n      # Run the linter.\n      - id: ruff\n        args: [ --fix ]\n      # Run the formatter.\n      - id: ruff-format\n  - repo: https://github.com/codespell-project/codespell\n    rev: v2.2.1\n    hooks:\n      - id: codespell\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.3.0\n    hooks:\n      - id: trailing-whitespace\n      - id: check-yaml\n      - id: end-of-file-fixer\n      - id: requirements-txt-fixer\n      - id: fix-encoding-pragma\n        args: [\"--remove\"]\n      - id: mixed-line-ending\n        args: [\"--fix=lf\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "                               Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, 
that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n<img src='assets/foleycrafter.png' style=\"text-align: center; width: 134px\" >\n</p>\n\n<div align=\"center\">\n\n[![arXiv](https://img.shields.io/badge/arXiv-2407.01494-b31b1b.svg)](https://arxiv.org/abs/2407.01494)\n[![Project Page](https://img.shields.io/badge/FoleyCrafter-Website-green)](https://foleycrafter.github.io)\n<a target=\"_blank\" href=\"https://huggingface.co/spaces/ymzhang319/FoleyCrafter\">\n  <img src=\"https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg\" alt=\"Open in HugginFace\"/>\n</a>\n[![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/ymzhang319/FoleyCrafter)\n[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/ymzhang319/FoleyCrafter)\n\n</div>\n\n# FoleyCrafter\n\nSound effects are the unsung heroes of cinema and gaming, enhancing realism, impact, and emotional depth for an immersive audiovisual experience. **FoleyCrafter** is a video-to-audio generation framework which can produce realistic sound effects semantically relevant and synchronized with videos.\n\n**Your star is our fuel! <img alt=\"\" width=\"30\" src=\"https://camo.githubusercontent.com/2f4f0d02cdf79dc1ff8d2b053b4410b13bc2e39cbc8a96fcdc6f06538a3d6d2b/68747470733a2f2f656d2d636f6e74656e742e7a6f626a2e6e65742f736f757263652f616e696d617465642d6e6f746f2d636f6c6f722d656d6f6a692f3335362f736d696c696e672d666163652d776974682d6865617274735f31663937302e676966\"> We're revving up the engines with it! 
<img alt=\"\" width=\"30\" src=\"https://camo.githubusercontent.com/028a75f875b8c3aa1b3c80bbf7dd27973c4bb654fffcf0bdc0b6f1b0674ce481/68747470733a2f2f656d2d636f6e74656e742e7a6f626a2e6e65742f736f757263652f74656c656772616d2f3338362f737061726b6c65735f323732382e77656270\">**\n\n\n[FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds]()\n\n[Yiming Zhang](https://github.com/ymzhang0319),\n[Yicheng Gu](https://github.com/VocodexElysium),\n[Yanhong Zeng†](https://zengyh1900.github.io/),\n[Zhening Xing](https://github.com/LeoXing1996/),\n[Yuancheng Wang](https://github.com/HeCheng0625),\n[Zhizheng Wu](https://drwuz.com/),\n[Kai Chen†](https://chenkai.site/)\n\n(†Corresponding Author)\n\n\n## What's New\n- [ ] A more powerful one :stuck_out_tongue_closed_eyes: .\n- [ ] Release training code.\n- [x] `2024/07/01` Release the model and code of FoleyCrafter.\n\n## Setup\n\n### Prepare Environment\nUse the following command to install dependencies:\n```bash\n# install conda environment\nconda env create -f requirements/environment.yaml\nconda activate foleycrafter\n\n# install GIT LFS for checkpoints download\nconda install git-lfs\ngit lfs install\n```\n\n### Download Checkpoints\nThe checkpoints will be downloaded automatically by running `inference.py`.\n\nYou can also download manually using following commands.\n<li> Download the text-to-audio base model. 
We use Auffusion</li>\n\n```bash\ngit clone https://huggingface.co/auffusion/auffusion-full-no-adapter checkpoints/auffusion\n```\n\n<li> Download FoleyCrafter</li>\n\n```bash\ngit clone https://huggingface.co/ymzhang319/FoleyCrafter checkpoints/\n```\n\nPut checkpoints as follows:\n```\n└── checkpoints\n    ├── semantic\n    │   ├── semantic_adapter.bin\n    ├── vocoder\n    │   ├── vocoder.pt\n    │   ├── config.json\n    ├── temporal_adapter.ckpt\n    │   │\n    └── timestamp_detector.pth.tar\n```\n\n## Gradio demo\n\nYou can launch the Gradio interface for FoleyCrafter by running the following command:\n\n```bash\npython app.py --share\n```\n\n\n\n## Inference\n### Video To Audio Generation\n```bash\npython inference.py --save_dir=output/sora/\n```\n\nResults:\n<table class='center'>\n<tr>\n  <td><p style=\"text-align: center\">Input Video</p></td>\n  <td><p style=\"text-align: center\">Generated Audio</p></td>\n<tr>\n<tr>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309262-d7c89984-c567-4ca7-8e2d-8f49d84bda4a.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122032Z&X-Amz-Expires=300&X-Amz-Signature=5b13f216056dedca2705233038dbb22f73023d2c1deaf3b03972d7b91c1bbab5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n </td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309725-0dfa72a2-1466-46e6-9611-3e1cbff707fe.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122123Z&X-Amz-Expires=300&X-Amz-Signature=314648ed216620b2d926395d34602c70da500eb9e865e839de6907ed1b0d0bd1&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  
<td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309166-16206bb8-9c5e-4e9d-9d73-bc251e5658fd.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122143Z&X-Amz-Expires=300&X-Amz-Signature=43c3e2c687846eb3ba118237628b747a78c403ea4f21739fe2d423724f7b426c&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309768-90c42af6-0d24-4a05-98d4-64e23467c4bb.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122213Z&X-Amz-Expires=300&X-Amz-Signature=cfed43cd2710bf73b84b6c3ebe8debd1e0b098bdc24a1a14f6531499d01c278e&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309601-e711b7c5-1614-4d39-8b1e-c54e28eec809.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122221Z&X-Amz-Expires=300&X-Amz-Signature=4c4680eb6c541433e4505fb2b5f5a5cc8d3e5708d9f2675a98cbb556cd5d59f5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309802-2db7f130-0c25-45c2-ad4d-bf86c5468b1f.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122243Z&X-Amz-Expires=300&X-Amz-Signature=2c32069318f60f03ee9a3185a7e2833c534ea601a02a5ff025b46c2abbc5b120&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  
<td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309637-6c2f106d-6b98-41ac-80ba-734636321f8c.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122305Z&X-Amz-Expires=300&X-Amz-Signature=04adb2cd80785a245ce704837ba9932d3646d375c51c53e7b70e6861ec7f6b4a&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309836-77391524-9b31-4602-ad42-0876e0c16794.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122317Z&X-Amz-Expires=300&X-Amz-Signature=3020c12591592a106efcb1aaa22237093737afaa55ede3880d3cbf9cd80b7482&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n</table>\n\n- Temporal Alignment with Visual Cues\n```bash\npython inference.py \\\n--temporal_align \\\n--input=input/avsync \\\n--save_dir=output/avsync/\n```\n\nResults:\n<table class='center'>\n<tr>\n  <td><p style=\"text-align: center\">Ground Truth</p></td>\n  <td><p style=\"text-align: center\">Generated Audio</p></td>\n<tr>\n<tr>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310778-bcc0f16d-6d1b-468d-a775-81b8f2d98ea6.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122327Z&X-Amz-Expires=300&X-Amz-Signature=205b8e190a428b3ddee41fe2549080b4f50fd8bb10ef78d650fc05add85ccbab&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n  
<td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310418-8433e05c-8600-4cd6-8a68-ead536159204.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122337Z&X-Amz-Expires=300&X-Amz-Signature=3fc37a305511c1c8b7bdfc9b9b5bd0485fd584400af087939a4c08218ab33538&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310801-3d6fd80d-de6b-4815-ac6a-f81772709e4c.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122349Z&X-Amz-Expires=300&X-Amz-Signature=fc31f139a1f9c7606657fa457f1271a8a44cb39a13454939c030ccdafe2d3068&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310491-dfaf41e7-487e-47ff-8e8a-fe7cb4fb1942.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122356Z&X-Amz-Expires=300&X-Amz-Signature=6353a935194a08bc081fa873e3c6582fb175874d3112f8f8f96614a5e542ef03&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310825-6834f00f-95e8-4a2c-b864-b4fe57801836.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122406Z&X-Amz-Expires=300&X-Amz-Signature=c67b2f113b0db790a8495d1a4ab4c0d230db5f53a9062a497bcca3e57f9600aa&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n  
<td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310543-5a2c363b-623c-4329-be0e-a151e5bb56a6.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122415Z&X-Amz-Expires=300&X-Amz-Signature=d8b0bfc28716e0e03694b3e590aca29450fa7788aa39e748130ee12a15d614e9&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n</table>\n\n### Text-based Video to Audio Generation\n\n- Using Prompt\n\n```bash\n# case1\npython inference.py \\\n--input=input/PromptControl/case1/ \\\n--seed=10201304011203481429 \\\n--save_dir=output/PromptControl/case1/\n\npython inference.py \\\n--input=input/PromptControl/case1/ \\\n--seed=10201304011203481429 \\\n--prompt='noisy, people talking' \\\n--save_dir=output/PromptControl/case1_prompt/\n\n# case2\npython inference.py \\\n--input=input/PromptControl/case2/ \\\n--seed=10021049243103289113 \\\n--save_dir=output/PromptControl/case2/\n\npython inference.py \\\n--input=input/PromptControl/case2/ \\\n--seed=10021049243103289113 \\\n--prompt='seagulls' \\\n--save_dir=output/PromptControl/case2_prompt/\n```\nResults:\n<table class='center'>\n<tr>\n  <td><p style=\"text-align: center\">Generated Audio</p></td>\n  <td><p style=\"text-align: center\">Generated Audio</p></td>\n<tr>\n<tr>\n  <td><p style=\"text-align: center\">Without Prompt</p></td>\n  <td><p style=\"text-align: center\">Prompt: <b>noisy, people talking</b></p></td>\n<tr>\n<tr>\n  <td>\n\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311425-8dd543cb-0df2-441e-b6d0-86048dbeb73d.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122426Z&X-Amz-Expires=300&X-Amz-Signature=b872b8eaf51a5022aee1daf0283d92e53a70e109c6b9f1e6a4da238a3708ea45&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n\n</td>\n  
<td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311493-62a08024-581c-4716-a030-aef194beddc5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122439Z&X-Amz-Expires=300&X-Amz-Signature=647eb0dc32bf7c0d739ccbe875826b1a67f54e0dc84e0be70e0e128ae2fdb73d&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  <td><p style=\"text-align: center\">Without Prompt</p></td>\n  <td><p style=\"text-align: center\">Prompt: <b>seagulls</b></p></td>\n<tr>\n<tr>\n  <td>\n\n\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311538-1f81f91e-efc0-41ed-bdcb-c5c6ff976c5b.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122447Z&X-Amz-Expires=300&X-Amz-Signature=d19761788775893e77e42b9312b4c26cb85aedd7c6dc249eaf68ff1f650e1942&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311595-695668ed-46a1-47b2-b5fd-3aa4286d695e.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122500Z&X-Amz-Expires=300&X-Amz-Signature=bf7d69ab8c74154ee8ac5682f64ce29a310ea2d0365f620893a45899d62a3f80&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n</table>\n\n- Using Negative Prompt\n```bash\n# case 3\npython inference.py \\\n--input=input/PromptControl/case3/ \\\n--seed=10041042941301238011 \\\n--save_dir=output/PromptControl/case3/\n\npython inference.py \\\n--input=input/PromptControl/case3/ \\\n--seed=10041042941301238011 \\\n--nprompt='river flows' \\\n--save_dir=output/PromptControl/case3_nprompt/\n\n# case4\npython inference.py \\\n--input=input/PromptControl/case4/ \\\n--seed=10014024412012338096 
\\\n--save_dir=output/PromptControl/case4/\n\npython inference.py \\\n--input=input/PromptControl/case4/ \\\n--seed=10014024412012338096 \\\n--nprompt='noisy, wind noise' \\\n--save_dir=output/PromptControl/case4_nprompt/\n\n```\nResults:\n<table class='center'>\n<tr>\n  <td><p style=\"text-align: center\">Generated Audio</p></td>\n  <td><p style=\"text-align: center\">Generated Audio</p></td>\n<tr>\n<tr>\n  <td><p style=\"text-align: center\">Without Prompt</p></td>\n  <td><p style=\"text-align: center\">Negative Prompt: <b>river flows</b></p></td>\n<tr>\n<tr>\n  <td>\n\n\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311656-cdc69cf1-88f8-4861-b888-bdb82358b9c5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122517Z&X-Amz-Expires=300&X-Amz-Signature=1731e655b2f0bb7f4a7af737ee065a01f98fccb1c54ef48ae775ad65ec67eda5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311702-cd259522-84f4-44cb-862f-c4dcfb57e5c4.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122527Z&X-Amz-Expires=300&X-Amz-Signature=d846ac60df25ff18de1861daeff380b7bc8ca21c04e7dce139ed44abbf9aaa22&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n<tr>\n  <td><p style=\"text-align: center\">Without Prompt</p></td>\n  <td><p style=\"text-align: center\">Negative Prompt: <b>noisy, wind noise</b></p></td>\n<tr>\n<tr>\n  
<td>\n\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311785-5ca9c050-a928-4dc2-b620-d843a3ae72f5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122533Z&X-Amz-Expires=300&X-Amz-Signature=151fe4da521f9f48ff245ef5bd7c6964f1dfc652be0fe8de4c151ba59e87d2d6&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n\n</td>\n  <td>\n\nhttps://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311844-28d6abe3-d5a8-4a7f-9f4d-3cc8411affba.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122544Z&X-Amz-Expires=300&X-Amz-Signature=39a259ca76a12a57d47ad74c8cf92af12af9e3db34084f5f82c79b6a62356e9a&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188\n\n</td>\n<tr>\n</table>\n\n### Commandline Usage Parameters\n```console\noptions:\n  -h, --help            show this help message and exit\n  --prompt PROMPT       prompt for audio generation\n  --nprompt NPROMPT     negative prompt for audio generation\n  --seed SEED           random seed\n  --temporal_align TEMPORAL_ALIGN\n                        use temporal adapter or not\n  --temporal_scale TEMPORAL_SCALE\n                        temporal align scale\n  --semantic_scale SEMANTIC_SCALE\n                        visual content scale\n  --input INPUT         input video folder path\n  --ckpt CKPT           checkpoints folder path\n  --save_dir SAVE_DIR   generation result save path\n  --pretrain PRETRAIN   generator checkpoint path\n  --device DEVICE\n```\n\n\n## BibTex\n```\n@misc{zhang2024foleycrafter,\n  title={FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds},\n  author={Yiming Zhang, Yicheng Gu, Yanhong Zeng, Zhening Xing, Yuancheng Wang, Zhizheng Wu, Kai Chen},\n  year={2024},\n  eprint={2407.01494},\n  archivePrefix={arXiv},\n  
primaryClass={cs.CV}\n}\n```\n\n\n## Contact Us\n\n**Yiming Zhang**: zhangyiming@pjlab.org.cn\n\n**Yicheng Gu**: yichenggu@link.cuhk.edu.cn\n\n**Yanhong Zeng**: zengyanhong@pjlab.org.cn\n\n## LICENSE\nPlease check [LICENSE](./LICENSE) for details on the FoleyCrafter part.\nIf you are using it for commercial purposes, please check the license of the [Auffusion](https://github.com/happylittlecat2333/Auffusion).\n\n## Acknowledgements\nThe code is built upon [Auffusion](https://github.com/happylittlecat2333/Auffusion), [CondFoleyGen](https://github.com/XYPB/CondFoleyGen) and [SpecVQGAN](https://github.com/v-iashin/SpecVQGAN).\n\nWe recommend a toolkit for Audio, Music, and Speech Generation [Amphion](https://github.com/open-mmlab/Amphion) :gift_heart:.\n"
  },
  {
    "path": "app.py",
    "content": "import os\nimport os.path as osp\nimport random\nfrom argparse import ArgumentParser\nfrom datetime import datetime\n\nimport gradio as gr\nimport soundfile as sf\nimport torch\nimport torchvision\nfrom huggingface_hub import snapshot_download\nfrom moviepy.editor import AudioFileClip, VideoFileClip\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler\nfrom foleycrafter.models.onset import torch_utils\nfrom foleycrafter.models.time_detector.model import VideoOnsetNet\nfrom foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram\nfrom foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy\n\n\nos.environ[\"GRADIO_TEMP_DIR\"] = \"./tmp\"\n\nsample_idx = 0\nscheduler_dict = {\n    \"DDIM\": DDIMScheduler,\n    \"Euler\": EulerDiscreteScheduler,\n    \"PNDM\": PNDMScheduler,\n}\n\ncss = \"\"\"\n.toolbutton {\n    margin-buttom: 0em 0em 0em 0em;\n    max-width: 2.5em;\n    min-width: 2.5em !important;\n    height: 2.5em;\n}\n\"\"\"\n\nparser = ArgumentParser()\nparser.add_argument(\"--config\", type=str, default=\"example/config/base.yaml\")\nparser.add_argument(\"--server-name\", type=str, default=\"0.0.0.0\")\nparser.add_argument(\"--port\", type=int, default=7860)\nparser.add_argument(\"--share\", type=bool, default=False)\n\nparser.add_argument(\"--save-path\", default=\"samples\")\nparser.add_argument(\"--ckpt\", type=str, default=\"checkpoints/\")\n\nargs = parser.parse_args()\n\n\nN_PROMPT = \"\"\n\n\nclass FoleyController:\n    def __init__(self):\n        # config dirs\n        self.basedir = os.getcwd()\n        self.model_dir = os.path.join(self.basedir, args.ckpt)\n        self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime(\"Gradio-%Y-%m-%dT%H-%M-%S\"))\n        self.savedir_sample = os.path.join(self.savedir, \"sample\")\n        os.makedirs(self.savedir, 
exist_ok=True)\n\n        self.pipeline = None\n\n        self.loaded = False\n\n        self.load_model()\n\n    def load_model(self):\n        gr.Info(\"Start Load Models...\")\n        print(\"Start Load Models...\")\n\n        # download ckpt\n        pretrained_model_name_or_path = \"auffusion/auffusion-full-no-adapter\"\n        if not os.path.isdir(pretrained_model_name_or_path):\n            pretrained_model_name_or_path = snapshot_download(\n                pretrained_model_name_or_path, local_dir=osp.join(self.model_dir, \"auffusion\")\n            )\n\n        fc_ckpt = \"ymzhang319/FoleyCrafter\"\n        if not os.path.isdir(fc_ckpt):\n            fc_ckpt = snapshot_download(fc_ckpt, local_dir=self.model_dir)\n\n        # set model config\n        temporal_ckpt_path = osp.join(self.model_dir, \"temporal_adapter.ckpt\")\n\n        # load vocoder\n        vocoder_config_path = osp.join(self.model_dir, \"auffusion\")\n        self.vocoder = Generator.from_pretrained(vocoder_config_path, subfolder=\"vocoder\")\n\n        # load time detector\n        time_detector_ckpt = osp.join(osp.join(self.model_dir, \"timestamp_detector.pth.tar\"))\n        time_detector = VideoOnsetNet(False)\n        self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)\n\n        self.pipeline = build_foleycrafter()\n        ckpt = torch.load(temporal_ckpt_path)\n\n        # load temporal adapter\n        if \"state_dict\" in ckpt.keys():\n            ckpt = ckpt[\"state_dict\"]\n        load_gligen_ckpt = {}\n        for key, value in ckpt.items():\n            if key.startswith(\"module.\"):\n                load_gligen_ckpt[key[len(\"module.\") :]] = value\n            else:\n                load_gligen_ckpt[key] = value\n        m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)\n        print(f\"### Control Net missing keys: {len(m)}; \\n### unexpected keys: {len(u)};\")\n\n        self.image_processor = 
CLIPImageProcessor()\n        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n            \"h94/IP-Adapter\", subfolder=\"models/image_encoder\"\n        )\n\n        self.pipeline.load_ip_adapter(\n            fc_ckpt, subfolder=\"semantic\", weight_name=\"semantic_adapter.bin\", image_encoder_folder=None\n        )\n\n        gr.Info(\"Load Finish!\")\n        print(\"Load Finish!\")\n        self.loaded = True\n\n        return \"Load\"\n\n    def foley(\n        self,\n        input_video,\n        prompt_textbox,\n        negative_prompt_textbox,\n        ip_adapter_scale,\n        temporal_scale,\n        sampler_dropdown,\n        sample_step_slider,\n        cfg_scale_slider,\n        seed_textbox,\n    ):\n        device = \"cuda\"\n        # move to gpu\n        self.time_detector = controller.time_detector.to(device)\n        self.pipeline = controller.pipeline.to(device)\n        self.vocoder = controller.vocoder.to(device)\n        self.image_encoder = controller.image_encoder.to(device)\n        vision_transform_list = [\n            torchvision.transforms.Resize((128, 128)),\n            torchvision.transforms.CenterCrop((112, 112)),\n            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n        ]\n        video_transform = torchvision.transforms.Compose(vision_transform_list)\n        # if not self.loaded:\n        #     raise gr.Error(\"Error with loading model\")\n        generator = torch.Generator()\n        if seed_textbox != \"\":\n            torch.manual_seed(int(seed_textbox))\n            generator.manual_seed(int(seed_textbox))\n        max_frame_nums = 150\n        frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)\n        if duration >= 10:\n            duration = 10\n        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(device)\n        time_frames = video_transform(time_frames)\n        time_frames = {\"frames\": 
time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}\n        preds = self.time_detector(time_frames)\n        preds = torch.sigmoid(preds)\n\n        # duration\n        time_condition = [\n            -1 if preds[0][int(i / (1024 / 10 * duration) * max_frame_nums)] < 0.5 else 1\n            for i in range(int(1024 / 10 * duration))\n        ]\n        time_condition = time_condition + [-1] * (1024 - len(time_condition))\n        # w -> b c h w\n        time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)\n\n        # Note that clip need fewer frames\n        frames = frames[::10]\n        images = self.image_processor(images=frames, return_tensors=\"pt\").to(device)\n        image_embeddings = self.image_encoder(**images).image_embeds\n        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)\n        neg_image_embeddings = torch.zeros_like(image_embeddings)\n        image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)\n        self.pipeline.set_ip_adapter_scale(ip_adapter_scale)\n        sample = self.pipeline(\n            prompt=prompt_textbox,\n            negative_prompt=negative_prompt_textbox,\n            ip_adapter_image_embeds=image_embeddings,\n            image=time_condition,\n            controlnet_conditioning_scale=float(temporal_scale),\n            num_inference_steps=sample_step_slider,\n            height=256,\n            width=1024,\n            output_type=\"pt\",\n            generator=generator,\n        )\n        name = \"output\"\n        audio_img = sample.images[0]\n        audio = denormalize_spectrogram(audio_img)\n        audio = self.vocoder.inference(audio, lengths=160000)[0]\n        audio_save_path = osp.join(self.savedir_sample, \"audio\")\n        os.makedirs(audio_save_path, exist_ok=True)\n        audio = audio[: int(duration * 16000)]\n\n        save_path = osp.join(audio_save_path, 
f\"{name}.wav\")\n        sf.write(save_path, audio, 16000)\n\n        audio = AudioFileClip(osp.join(audio_save_path, f\"{name}.wav\"))\n        video = VideoFileClip(input_video)\n        audio = audio.subclip(0, duration)\n        video.audio = audio\n        video = video.subclip(0, duration)\n        video.write_videofile(osp.join(self.savedir_sample, f\"{name}.mp4\"))\n        save_sample_path = os.path.join(self.savedir_sample, f\"{name}.mp4\")\n\n        return save_sample_path\n\n\ncontroller = FoleyController()\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nwith gr.Blocks(css=css) as demo:\n    gr.HTML(\n        '<h1 style=\"height: 136px; display: flex; align-items: center; justify-content: space-around;\"><span style=\"height: 100%; width:136px;\"><img src=\"file/assets/foleycrafter.png\" alt=\"logo\" style=\"height: 100%; width:auto; object-fit: contain; margin: 0px 0px; padding: 0px 0px;\"></span><strong style=\"font-size: 36px;\">FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</strong></h1>'\n    )\n    gr.HTML(\n        '<p id=\"authors\" style=\"text-align:center; font-size:24px;\"> \\\n        <a href=\"https://github.com/ymzhang0319\">Yiming Zhang</a><sup>1</sup>,&nbsp \\\n        <a href=\"https://github.com/VocodexElysium\">Yicheng Gu</a><sup>2</sup>,&nbsp \\\n        <a href=\"https://zengyh1900.github.io/\">Yanhong Zeng</a><sup>1 †</sup>,&nbsp \\\n        <a href=\"https://github.com/LeoXing1996/\">Zhening Xing</a><sup>1</sup>,&nbsp \\\n        <a href=\"https://github.com/HeCheng0625\">Yuancheng Wang</a><sup>2</sup>,&nbsp \\\n        <a href=\"https://drwuz.com/\">Zhizheng Wu</a><sup>2</sup>,&nbsp \\\n        <a href=\"https://chenkai.site/\">Kai Chen</a><sup>1 †</sup>\\\n        <br>\\\n        <span>\\\n            <sup>1</sup>Shanghai AI Laboratory &nbsp;&nbsp;&nbsp;\\\n            <sup>2</sup>Chinese University of Hong Kong, Shenzhen &nbsp;&nbsp;&nbsp;\\\n            †Corresponding 
author\\\n        </span>\\\n    </p>'\n    )\n    with gr.Row():\n        gr.Markdown(\n            \"<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a> &ensp;\"  # noqa\n            \"<a href='https://arxiv.org/abs/2407.01494/'>Paper</a> &ensp;\"\n            \"<a href='https://github.com/open-mmlab/foleycrafter'>Code</a> &ensp;\"\n            \"<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>\"\n        )\n\n    with gr.Column(variant=\"panel\"):\n        with gr.Row(equal_height=False):\n            with gr.Column():\n                with gr.Row():\n                    init_img = gr.Video(label=\"Input Video\")\n                with gr.Row():\n                    prompt_textbox = gr.Textbox(value=\"\", label=\"Prompt\", lines=1)\n                with gr.Row():\n                    negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label=\"Negative prompt\", lines=1)\n\n                with gr.Row():\n                    ip_adapter_scale = gr.Slider(label=\"Visual Content Scale\", value=1.0, minimum=0, maximum=1)\n                    temporal_scale = gr.Slider(label=\"Temporal Align Scale\", value=0.2, minimum=0.0, maximum=1.0)\n\n                with gr.Accordion(\"Sampling Settings\", open=False):\n                    with gr.Row():\n                        sampler_dropdown = gr.Dropdown(\n                            label=\"Sampling method\",\n                            choices=list(scheduler_dict.keys()),\n                            value=list(scheduler_dict.keys())[0],\n                        )\n                        sample_step_slider = gr.Slider(\n                            label=\"Sampling steps\", value=25, minimum=10, maximum=100, step=1\n                        )\n                    cfg_scale_slider = gr.Slider(label=\"CFG Scale\", value=7.5, minimum=0, maximum=20)\n\n                with gr.Row():\n                    seed_textbox = gr.Textbox(label=\"Seed\", 
value=42)\n                    seed_button = gr.Button(value=\"\\U0001f3b2\", elem_classes=\"toolbutton\")\n                seed_button.click(fn=lambda x: random.randint(1, 1e8), outputs=[seed_textbox], queue=False)\n\n                generate_button = gr.Button(value=\"Generate\", variant=\"primary\")\n\n            with gr.Column():\n                result_video = gr.Video(label=\"Generated Audio\", interactive=False)\n                with gr.Row():\n                    gr.Markdown(\n                        \"<div style='word-spacing: 6px;'><font size='5'><b>Tips</b>: <br> \\\n                        1. With strong temporal visual cues in input video, you can scale up the <b>Temporal Align Scale</b>. <br>\\\n                        2. <b>Visual content scale</b> is the level of semantic alignment with visual content.</font></div> \\\n                    \"\n                    )\n\n        generate_button.click(\n            fn=controller.foley,\n            inputs=[\n                init_img,\n                prompt_textbox,\n                negative_prompt_textbox,\n                ip_adapter_scale,\n                temporal_scale,\n                sampler_dropdown,\n                sample_step_slider,\n                cfg_scale_slider,\n                seed_textbox,\n            ],\n            outputs=[result_video],\n        )\n\n        gr.Examples(\n            examples=[\n                [\"examples/gen3/case1.mp4\", \"\", \"\", 1.0, 0.2, \"DDIM\", 25, 7.5, 33817921],\n                [\"examples/gen3/case3.mp4\", \"\", \"\", 1.0, 0.2, \"DDIM\", 25, 7.5, 94667578],\n                [\"examples/gen3/case5.mp4\", \"\", \"\", 0.75, 0.2, \"DDIM\", 25, 7.5, 92890876],\n                [\"examples/gen3/case6.mp4\", \"\", \"\", 1.0, 0.2, \"DDIM\", 25, 7.5, 77015909],\n            ],\n            inputs=[\n                init_img,\n                prompt_textbox,\n                negative_prompt_textbox,\n                ip_adapter_scale,\n                
temporal_scale,\n                sampler_dropdown,\n                sample_step_slider,\n                cfg_scale_slider,\n                seed_textbox,\n            ],\n            cache_examples=True,\n            outputs=[result_video],\n            fn=controller.foley,\n        )\n\n    demo.queue(10)\n    demo.launch(\n        server_name=args.server_name,\n        server_port=args.port,\n        share=args.share,\n        allowed_paths=[\"./assets/foleycrafter.png\"],\n    )\n"
  },
  {
    "path": "foleycrafter/data/__init__.py",
    "content": "from .dataset import AudioSetStrong, CPU_Unpickler, VGGSound, dynamic_range_compression, get_mel, zero_rank_print\nfrom .video_transforms import (\n    CenterCropVideo,\n    KineticsRandomCropResizeVideo,\n    NormalizeVideo,\n    RandomHorizontalFlipVideo,\n    TemporalRandomCrop,\n    ToTensorVideo,\n    UCFCenterCropVideo,\n)\n\n\n__all__ = [\n    \"zero_rank_print\",\n    \"get_mel\",\n    \"dynamic_range_compression\",\n    \"CPU_Unpickler\",\n    \"AudioSetStrong\",\n    \"VGGSound\",\n    \"UCFCenterCropVideo\",\n    \"KineticsRandomCropResizeVideo\",\n    \"CenterCropVideo\",\n    \"NormalizeVideo\",\n    \"ToTensorVideo\",\n    \"RandomHorizontalFlipVideo\",\n    \"TemporalRandomCrop\",\n]\n"
  },
  {
    "path": "foleycrafter/data/dataset.py",
    "content": "import glob\nimport io\nimport pickle\nimport random\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimport torchaudio\nimport torchvision.transforms as transforms\nfrom torch.utils.data.dataset import Dataset\n\n\ndef zero_rank_print(s):\n    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):\n        print(\"### \" + s, flush=True)\n\n\n@torch.no_grad()\ndef get_mel(audio_data, audio_cfg):\n    # mel shape: (n_mels, T)\n    mel = torchaudio.transforms.MelSpectrogram(\n        sample_rate=audio_cfg[\"sample_rate\"],\n        n_fft=audio_cfg[\"window_size\"],\n        win_length=audio_cfg[\"window_size\"],\n        hop_length=audio_cfg[\"hop_size\"],\n        center=True,\n        pad_mode=\"reflect\",\n        power=2.0,\n        norm=None,\n        onesided=True,\n        n_mels=64,\n        f_min=audio_cfg[\"fmin\"],\n        f_max=audio_cfg[\"fmax\"],\n    ).to(audio_data.device)\n    mel = mel(audio_data)\n    # we use log mel spectrogram as input\n    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)\n    return mel  # (T, n_mels)\n\n\ndef dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):\n    \"\"\"\n    PARAMS\n    ------\n    C: compression factor\n    \"\"\"\n    return normalize_fun(torch.clamp(x, min=clip_val) * C)\n\n\nclass CPU_Unpickler(pickle.Unpickler):\n    def find_class(self, module, name):\n        if module == \"torch.storage\" and name == \"_load_from_bytes\":\n            return lambda b: torch.load(io.BytesIO(b), map_location=\"cpu\")\n        else:\n            return super().find_class(module, name)\n\n\nclass AudioSetStrong(Dataset):\n    # read feature and audio\n    def __init__(\n        self,\n        data_path=\"data/AudioSetStrong/train/feature\",\n        video_path=\"data/AudioSetStrong/train/video\",\n    ):\n        super().__init__()\n        self.data_path = data_path\n        self.data_list = list(self.data_path)\n      
  self.length = len(self.data_list)\n        # get video feature\n        self.video_path = video_path\n        vision_transform_list = [\n            transforms.Resize((128, 128)),\n            transforms.CenterCrop((112, 112)),\n            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n        ]\n        self.video_transform = transforms.Compose(vision_transform_list)\n\n    def get_batch(self, idx):\n        embeds = self.data_list[idx]\n        mel = embeds[\"mel\"]\n        save_bsz = mel.shape[0]\n        audio_info = embeds[\"audio_info\"]\n        text_embeds = embeds[\"text_embeds\"]\n\n        # audio_info['label_list'] = np.array(audio_info['label_list'])\n        audio_info_array = np.array(audio_info[\"label_list\"])\n        prompts = []\n        for i in range(save_bsz):\n            prompts.append(\", \".join(audio_info_array[i, : audio_info[\"event_num\"][i]].tolist()))\n\n        return mel, audio_info, text_embeds, prompts\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, idx):\n        while True:\n            try:\n                mel, audio_info, text_embeds, prompts, videos = self.get_batch(idx)\n                break\n            except Exception:\n                zero_rank_print(\" >>> load error <<<\")\n                idx = random.randint(0, self.length - 1)\n        sample = {\n            \"mel\": mel,\n            \"audio_info\": audio_info,\n            \"text_embeds\": text_embeds,\n            \"prompts\": prompts,\n            \"videos\": videos,\n        }\n        return sample\n\n\nclass VGGSound(Dataset):\n    # read feature and audio\n    def __init__(\n        self,\n        data_path=\"data/VGGSound/train/video\",\n        visual_data_path=\"data/VGGSound/train/feature\",\n    ):\n        super().__init__()\n        self.data_path = data_path\n        self.visual_data_path = visual_data_path\n        self.embeds_list = glob.glob(f\"{self.data_path}/*.pt\")\n    
    self.visual_list = glob.glob(f\"{self.visual_data_path}/*.pt\")\n        self.length = len(self.embeds_list)\n\n    def get_batch(self, idx):\n        embeds = torch.load(self.embeds_list[idx], map_location=\"cpu\")\n        visual_embeds = torch.load(self.visual_list[idx], map_location=\"cpu\")\n\n        # audio_embeds  = embeds['audio_embeds']\n        visual_embeds = visual_embeds[\"visual_embeds\"]\n        # video_name = embeds[\"video_name\"]\n        text = embeds[\"text\"]\n        mel = embeds[\"mel\"]\n\n        audio = mel\n\n        return visual_embeds, audio, text\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, idx):\n        while True:\n            try:\n                visual_embeds, audio, text = self.get_batch(idx)\n                break\n            except Exception:\n                zero_rank_print(\"load error\")\n                idx = random.randint(0, self.length - 1)\n        sample = {\"visual_embeds\": visual_embeds, \"audio\": audio, \"text\": text}\n        return sample\n"
  },
  {
    "path": "foleycrafter/data/video_transforms.py",
    "content": "import numbers\nimport random\n\nimport torch\n\n\ndef _is_tensor_video_clip(clip):\n    if not torch.is_tensor(clip):\n        raise TypeError(\"clip should be Tensor. Got %s\" % type(clip))\n\n    if not clip.ndimension() == 4:\n        raise ValueError(\"clip should be 4D. Got %dD\" % clip.dim())\n\n    return True\n\n\ndef crop(clip, i, j, h, w):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n    \"\"\"\n    if len(clip.size()) != 4:\n        raise ValueError(\"clip should be a 4D tensor\")\n    return clip[..., i : i + h, j : j + w]\n\n\ndef resize(clip, target_size, interpolation_mode):\n    if len(target_size) != 2:\n        raise ValueError(f\"target size should be tuple (height, width), instead got {target_size}\")\n    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)\n\n\ndef resize_scale(clip, target_size, interpolation_mode):\n    if len(target_size) != 2:\n        raise ValueError(f\"target size should be tuple (height, width), instead got {target_size}\")\n    _, _, H, W = clip.shape\n    scale_ = target_size[0] / min(H, W)\n    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)\n\n\ndef resized_crop(clip, i, j, h, w, size, interpolation_mode=\"bilinear\"):\n    \"\"\"\n    Do spatial cropping and resizing to the video clip\n    Args:\n        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n        i (int): i in (i,j) i.e coordinates of the upper left corner.\n        j (int): j in (i,j) i.e coordinates of the upper left corner.\n        h (int): Height of the cropped region.\n        w (int): Width of the cropped region.\n        size (tuple(int, int)): height and width of resized clip\n    Returns:\n        clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    clip = crop(clip, i, j, h, w)\n    clip = resize(clip, size, interpolation_mode)\n    return clip\n\n\ndef center_crop(clip, crop_size):\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n    th, tw = crop_size\n    if h < th or w < tw:\n        raise ValueError(\"height and width must be no smaller than crop_size\")\n\n    i = int(round((h - th) / 2.0))\n    j = int(round((w - tw) / 2.0))\n    return crop(clip, i, j, th, tw)\n\n\ndef random_shift_crop(clip):\n    \"\"\"\n    Slide along the long edge, with the short edge as crop size\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    h, w = clip.size(-2), clip.size(-1)\n\n    if h <= w:\n        # long_edge = w\n        short_edge = h\n    else:\n        # long_edge = h\n        short_edge = w\n\n    th, tw = short_edge, short_edge\n\n    i = torch.randint(0, h - th + 1, size=(1,)).item()\n    j = torch.randint(0, w - tw + 1, size=(1,)).item()\n    return crop(clip, i, j, th, tw)\n\n\ndef to_tensor(clip):\n    \"\"\"\n    Convert tensor data type from uint8 to float, divide value by 255.0 and\n    permute the dimensions of clip tensor\n    Args:\n        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)\n    Return:\n        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n    \"\"\"\n    _is_tensor_video_clip(clip)\n    if not clip.dtype == torch.uint8:\n        raise TypeError(\"clip tensor should have data type uint8. Got %s\" % str(clip.dtype))\n    # return clip.float().permute(3, 0, 1, 2) / 255.0\n    return clip.float() / 255.0\n\n\ndef normalize(clip, mean, std, inplace=False):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be normalized. 
Size is (T, C, H, W)\n        mean (tuple): pixel RGB mean. Size is (3)\n        std (tuple): pixel standard deviation. Size is (3)\n    Returns:\n        normalized clip (torch.tensor): Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    if not inplace:\n        clip = clip.clone()\n    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)\n    print(mean)\n    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)\n    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])\n    return clip\n\n\ndef hflip(clip):\n    \"\"\"\n    Args:\n        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)\n    Returns:\n        flipped clip (torch.tensor): Size is (T, C, H, W)\n    \"\"\"\n    if not _is_tensor_video_clip(clip):\n        raise ValueError(\"clip should be a 4D torch.tensor\")\n    return clip.flip(-1)\n\n\nclass RandomCropVideo:\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: randomly cropped video clip.\n                size is (T, C, OH, OW)\n        \"\"\"\n        i, j, h, w = self.get_params(clip)\n        return crop(clip, i, j, h, w)\n\n    def get_params(self, clip):\n        h, w = clip.shape[-2:]\n        th, tw = self.size\n\n        if h < th or w < tw:\n            raise ValueError(f\"Required crop size {(th, tw)} is larger than input image size {(h, w)}\")\n\n        if w == tw and h == th:\n            return 0, 0, h, w\n\n        i = torch.randint(0, h - th + 1, size=(1,)).item()\n        j = torch.randint(0, w - tw + 1, size=(1,)).item()\n\n        return i, j, th, tw\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size})\"\n\n\nclass UCFCenterCropVideo:\n    def __init__(\n        self,\n        size,\n        interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)\n        Returns:\n            torch.tensor: scale resized / center cropped video clip.\n                size is (T, C, crop_size, crop_size)\n        \"\"\"\n        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)\n        clip_center_crop = center_crop(clip_resize, self.size)\n        return clip_center_crop\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}\"\n\n\nclass KineticsRandomCropResizeVideo:\n    \"\"\"\n    Slide along the long edge, with the short edge as crop size. 
And resize to the desired size.\n    \"\"\"\n\n    def __init__(\n        self,\n        size,\n        interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        clip_random_crop = random_shift_crop(clip)\n        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)\n        return clip_resize\n\n\nclass CenterCropVideo:\n    def __init__(\n        self,\n        size,\n        interpolation_mode=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            if len(size) != 2:\n                raise ValueError(f\"size should be tuple (height, width), instead got {size}\")\n            self.size = size\n        else:\n            self.size = (size, size)\n\n        self.interpolation_mode = interpolation_mode\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W)\n        Returns:\n            torch.tensor: center cropped video clip.\n                size is (T, C, crop_size, crop_size)\n        \"\"\"\n        clip_center_crop = center_crop(clip, self.size)\n        return clip_center_crop\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}\"\n\n\nclass NormalizeVideo:\n    \"\"\"\n    Normalize the video clip by mean subtraction and division by standard deviation\n    Args:\n        mean (3-tuple): pixel RGB mean\n        std (3-tuple): pixel RGB standard deviation\n        inplace (boolean): whether do in-place normalization\n    \"\"\"\n\n    def __init__(self, mean, std, inplace=False):\n        self.mean = mean\n        self.std = std\n        self.inplace = inplace\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)\n        \"\"\"\n        return normalize(clip, self.mean, self.std, self.inplace)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})\"\n\n\nclass ToTensorVideo:\n    \"\"\"\n    Convert tensor data type from uint8 to float, divide value by 255.0 and\n    permute the dimensions of clip tensor\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)\n        \"\"\"\n        return to_tensor(clip)\n\n    def __repr__(self) -> str:\n        return self.__class__.__name__\n\n\nclass RandomHorizontalFlipVideo:\n    \"\"\"\n    Flip the video clip along the horizontal direction with a given probability\n    Args:\n        p (float): probability of the clip being flipped. 
Default value is 0.5\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        self.p = p\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n            clip (torch.tensor): Size is (T, C, H, W)\n        Return:\n            clip (torch.tensor): Size is (T, C, H, W)\n        \"\"\"\n        if random.random() < self.p:\n            clip = hflip(clip)\n        return clip\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(p={self.p})\"\n\n\n#  ------------------------------------------------------------\n#  ---------------------  Sampling  ---------------------------\n#  ------------------------------------------------------------\nclass TemporalRandomCrop(object):\n    \"\"\"Temporally crop the given frame indices at a random location.\n\n    Args:\n            size (int): Desired length of frames will be seen in the model.\n    \"\"\"\n\n    def __init__(self, size):\n        self.size = size\n\n    def __call__(self, total_frames):\n        rand_end = max(0, total_frames - self.size - 1)\n        begin_index = random.randint(0, rand_end)\n        end_index = min(begin_index + self.size, total_frames)\n        return begin_index, end_index\n\n\nif __name__ == \"__main__\":\n    import os\n\n    import numpy as np\n    import torchvision.io as io\n    from torchvision import transforms\n    from torchvision.utils import save_image\n\n    vframes, aframes, info = io.read_video(filename=\"./v_Archery_g01_c03.avi\", pts_unit=\"sec\", output_format=\"TCHW\")\n\n    trans = transforms.Compose(\n        [\n            ToTensorVideo(),\n            RandomHorizontalFlipVideo(),\n            UCFCenterCropVideo(512),\n            # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),\n        ]\n    )\n\n    target_video_len = 32\n    frame_interval = 1\n    total_frames = len(vframes)\n    print(total_frames)\n\n    
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)\n\n    # Sampling video frames\n    start_frame_ind, end_frame_ind = temporal_sample(total_frames)\n    # print(start_frame_ind)\n    # print(end_frame_ind)\n    assert end_frame_ind - start_frame_ind >= target_video_len\n    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)\n\n    select_vframes = vframes[frame_indice]\n\n    select_vframes_trans = trans(select_vframes)\n\n    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)\n\n    io.write_video(\"./test.avi\", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)\n\n    for i in range(target_video_len):\n        save_image(\n            select_vframes_trans[i], os.path.join(\"./test000\", \"%04d.png\" % i), normalize=True, value_range=(-1, 1)\n        )\n"
  },
  {
    "path": "foleycrafter/models/adapters/attention_processor.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass AttnProcessor(nn.Module):\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, 
attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAttnProcessor(nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            the weight scale of image prompt.\n        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n   
     if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # for ip-adapter\n        ip_key = self.to_k_ip(ip_hidden_states)\n        ip_value = self.to_v_ip(ip_hidden_states)\n\n        ip_key = attn.head_to_batch_dim(ip_key)\n        ip_value = attn.head_to_batch_dim(ip_value)\n\n        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)\n        self.attn_map = ip_attention_probs\n        
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)\n        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)\n\n        hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass AttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = 
attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        
hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass AttnProcessor2_0WithProjection(torch.nn.Module):\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n        self.before_proj_size = 1024\n        self.after_proj_size = 768\n        self.visual_proj = nn.Linear(self.before_proj_size, self.after_proj_size)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n        # encoder_hidden_states = self.visual_proj(encoder_hidden_states)\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter for PyTorch 2.0.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels 
in the `encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            the weight scale of image prompt.\n        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):\n        super().__init__()\n\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, 
attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # for ip-adapter\n        ip_key = self.to_k_ip(ip_hidden_states)\n        ip_value = self.to_v_ip(ip_hidden_states)\n\n        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = 
(batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        ip_hidden_states = F.scaled_dot_product_attention(\n            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False\n        )\n        with torch.no_grad():\n            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)\n            # print(self.attn_map.shape)\n\n        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        ip_hidden_states = ip_hidden_states.to(query.dtype)\n\n        hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\n## for controlnet\nclass CNAttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(self, num_tokens=4):\n        self.num_tokens = num_tokens\n\n    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else 
encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass CNAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self, num_tokens=4):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use 
it, please upgrade PyTorch to 2.0.\")\n        self.num_tokens = num_tokens\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, 
attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n"
  },
  {
    "path": "foleycrafter/models/adapters/ip_adapter.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n\n    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):\n        super().__init__()\n        self.unet = unet\n        self.image_proj_model = image_proj_model\n        self.adapter_modules = adapter_modules\n\n        if ckpt_path is not None:\n            self.load_from_checkpoint(ckpt_path)\n\n    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):\n        ip_tokens = self.image_proj_model(image_embeds)\n        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)\n        # Predict the noise residual\n        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample\n        return noise_pred\n\n    def load_from_checkpoint(self, ckpt_path: str):\n        # Calculate original checksums\n        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n\n        # Load state dict for image_proj_model and adapter_modules\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=True)\n        self.adapter_modules.load_state_dict(state_dict[\"ip_adapter\"], strict=True)\n\n        # Calculate new checksums\n        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        # Verify if the weights have changed\n        assert orig_ip_proj_sum != new_ip_proj_sum, \"Weights of image_proj_model did not change!\"\n        assert orig_adapter_sum != new_adapter_sum, \"Weights of adapter_modules did not change!\"\n\n        print(f\"Successfully 
loaded weights from checkpoint {ckpt_path}\")\n\n\nclass ImageProjModel(torch.nn.Module):\n    \"\"\"Projection Model\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.cross_attention_dim = cross_attention_dim\n        self.clip_extra_context_tokens = clip_extra_context_tokens\n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)\n\n    def forward(self, image_embeds):\n        embeds = image_embeds\n        clip_extra_context_tokens = self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\nclass MLPProjModel(torch.nn.Module):\n    \"\"\"SD model with image prompt\"\"\"\n\n    def zero_initialize(module):\n        for param in module.parameters():\n            param.data.zero_()\n\n    def zero_initialize_last_layer(module):\n        last_layer = None\n        for module_name, layer in module.named_modules():\n            if isinstance(layer, torch.nn.Linear):\n                last_layer = layer\n\n        if last_layer is not None:\n            last_layer.weight.data.zero_()\n            last_layer.bias.data.zero_()\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):\n        super().__init__()\n\n        self.proj = torch.nn.Sequential(\n            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),\n            torch.nn.GELU(),\n            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),\n            torch.nn.LayerNorm(cross_attention_dim),\n        )\n        # zero initialize the last layer\n        # self.zero_initialize_last_layer()\n\n    def forward(self, image_embeds):\n        clip_extra_context_tokens = 
self.proj(image_embeds)\n        return clip_extra_context_tokens\n\n\nclass V2AMapperMLP(torch.nn.Module):\n    def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):\n        super().__init__()\n        self.proj = torch.nn.Sequential(\n            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),\n            torch.nn.GELU(),\n            torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),\n            torch.nn.LayerNorm(cross_attention_dim),\n        )\n\n    def forward(self, image_embeds):\n        clip_extra_context_tokens = self.proj(image_embeds)\n        return clip_extra_context_tokens\n\n\nclass TimeProjModel(torch.nn.Module):\n    def __init__(self, positive_len, out_dim, feature_type=\"text-only\", frame_nums: int = 64):\n        super().__init__()\n        self.positive_len = positive_len\n        self.out_dim = out_dim\n\n        self.position_dim = frame_nums\n\n        if isinstance(out_dim, tuple):\n            out_dim = out_dim[0]\n\n        if feature_type == \"text-only\":\n            self.linears = nn.Sequential(\n                nn.Linear(self.positive_len + self.position_dim, 512),\n                nn.SiLU(),\n                nn.Linear(512, 512),\n                nn.SiLU(),\n                nn.Linear(512, out_dim),\n            )\n            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))\n\n        elif feature_type == \"text-image\":\n            self.linears_text = nn.Sequential(\n                nn.Linear(self.positive_len + self.position_dim, 512),\n                nn.SiLU(),\n                nn.Linear(512, 512),\n                nn.SiLU(),\n                nn.Linear(512, out_dim),\n            )\n            self.linears_image = nn.Sequential(\n                nn.Linear(self.positive_len + self.position_dim, 512),\n                nn.SiLU(),\n                nn.Linear(512, 512),\n                nn.SiLU(),\n                nn.Linear(512, 
out_dim),\n            )\n            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))\n            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))\n\n        # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))\n\n    def forward(\n        self,\n        boxes,\n        masks,\n        positive_embeddings=None,\n    ):\n        masks = masks.unsqueeze(-1)\n\n        # # embedding position (it may includes padding as placeholder)\n        # xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C\n\n        # # learnable null embedding\n        # xyxy_null = self.null_position_feature.view(1, 1, -1)\n\n        # # replace padding with learnable null embedding\n        # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null\n\n        time_embeds = boxes\n\n        # positionet with text only information\n        if positive_embeddings is not None:\n            # learnable null embedding\n            positive_null = self.null_positive_feature.view(1, 1, -1)\n\n            # replace padding with learnable null embedding\n            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null\n\n            objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))\n\n        # positionet with text and image information\n        else:\n            raise NotImplementedError\n\n        return objs\n"
  },
  {
    "path": "foleycrafter/models/adapters/resampler.py",
    "content": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom einops.layers.torch import Rearrange\n\n\n# FFN\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head**-0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latent (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = 
self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass Resampler(nn.Module):\n    def __init__(\n        self,\n        dim=1024,\n        depth=8,\n        dim_head=64,\n        heads=16,\n        num_queries=8,\n        embedding_dim=768,\n        output_dim=1024,\n        ff_mult=4,\n        max_seq_len: int = 257,  # CLIP tokens + CLS token\n        apply_pos_emb: bool = False,\n        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence\n    ):\n        super().__init__()\n        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None\n\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)\n\n        self.proj_in = nn.Linear(embedding_dim, dim)\n\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.to_latents_from_mean_pooled_seq = (\n            nn.Sequential(\n                nn.LayerNorm(dim),\n                nn.Linear(dim, dim * num_latents_mean_pooled),\n                Rearrange(\"b (n d) -> b n d\", n=num_latents_mean_pooled),\n            )\n            if num_latents_mean_pooled > 0\n            else None\n        )\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n            
            FeedForward(dim=dim, mult=ff_mult),\n                    ]\n                )\n            )\n\n    def forward(self, x):\n        if self.pos_emb is not None:\n            n, device = x.shape[1], x.device\n            pos_emb = self.pos_emb(torch.arange(n, device=device))\n            x = x + pos_emb\n\n        latents = self.latents.repeat(x.size(0), 1, 1)\n\n        x = self.proj_in(x)\n\n        if self.to_latents_from_mean_pooled_seq:\n            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))\n            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)\n            latents = torch.cat((meanpooled_latents, latents), dim=-2)\n\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n\ndef masked_mean(t, *, dim, mask=None):\n    if mask is None:\n        return t.mean(dim=dim)\n\n    denom = mask.sum(dim=dim, keepdim=True)\n    mask = rearrange(mask, \"b n -> b n 1\")\n    masked_t = t.masked_fill(~mask, 0.0)\n\n    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)\n"
  },
  {
    "path": "foleycrafter/models/adapters/transformer.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\n\nclass Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):\n        super().__init__()\n        self.embed_dim = hidden_size\n        self.num_heads = num_attention_heads\n        self.head_dim = attention_head_dim\n\n        self.scale = self.head_dim**-0.5\n        self.dropout = attention_dropout\n\n        self.inner_dim = self.head_dim * self.num_heads\n\n        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)\n        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)\n        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)\n        self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, tgt_len, embed_dim = hidden_states.size()\n\n        # get query proj\n        query_states = self.q_proj(hidden_states) * self.scale\n        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = 
value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        # apply the causal_attention_mask first\n        if causal_attention_mask is not None:\n            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is\"\n                    f\" {causal_attention_mask.size()}\"\n                )\n            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        if output_attentions:\n            # this operation is a bit awkward, but it's required to\n            # make sure that attn_weights keeps its gradient.\n            # In order to do so, attn_weights have to be reshaped\n            # twice and have to be reused in the following\n            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n            attn_weights = attn_weights_reshaped.view(bsz * 
self.num_heads, tgt_len, src_len)\n        else:\n            attn_weights_reshaped = None\n\n        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights_reshaped\n\n\nclass MLP(nn.Module):\n    def __init__(self, hidden_size, intermediate_size, mult=4):\n        super().__init__()\n        self.activation_fn = nn.SiLU()\n        self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)\n        self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Transformer(nn.Module):\n    def __init__(self, depth=12):\n        super().__init__()\n        self.layers = nn.ModuleList([TransformerBlock() for _ in range(depth)])\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor = None,\n        causal_attention_mask: torch.Tensor = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            
attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n                `(config.encoder_attention_heads,)`.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n        \"\"\"\n        for layer in self.layers:\n            hidden_states = layer(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                causal_attention_mask=causal_attention_mask,\n                output_attentions=output_attentions,\n            )\n\n        return hidden_states\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size=512,\n        num_attention_heads=12,\n        attention_head_dim=64,\n        attention_dropout=0.0,\n        dropout=0.0,\n        eps=1e-5,\n    ):\n        super().__init__()\n        self.embed_dim = hidden_size\n        self.self_attn = Attention(\n            hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim\n        )\n        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)\n        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)\n        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor = None,\n        causal_attention_mask: torch.Tensor = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, 
src_len)` where padding elements are indicated by very large negative values.\n                `(config.encoder_attention_heads,)`.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs[0]\n\n\nclass DiffusionTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size=512,\n        num_attention_heads=12,\n        attention_head_dim=64,\n        attention_dropout=0.0,\n        dropout=0.0,\n        eps=1e-5,\n    ):\n        super().__init__()\n        self.embed_dim = hidden_size\n        self.self_attn = Attention(\n            hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim\n        )\n        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)\n        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)\n        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)\n        self.output_token = nn.Parameter(torch.randn(1, hidden_size))\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor = None,\n        
causal_attention_mask: torch.Tensor = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n                `(config.encoder_attention_heads,)`.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n        \"\"\"\n        output_token = self.output_token.unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)\n        hidden_states = torch.cat([output_token, hidden_states], dim=1)\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs[0][:, 0:1, ...]\n\n\nclass V2AMapperMLP(nn.Module):\n    def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):\n        super().__init__()\n        self.linear = nn.Linear(input_dim, input_dim * expansion_rate)\n        self.silu = nn.SiLU()\n        self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)\n        self.linear2 = 
nn.Linear(input_dim * expansion_rate, output_dim)\n\n    def forward(self, x):\n        x = self.linear(x)\n        x = self.silu(x)\n        x = self.layer_norm(x)\n        x = self.linear2(x)\n\n        return x\n\n\nclass ImageProjModel(torch.nn.Module):\n    \"\"\"Projection Model\"\"\"\n\n    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):\n        super().__init__()\n\n        self.cross_attention_dim = cross_attention_dim\n        self.clip_extra_context_tokens = clip_extra_context_tokens\n        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)\n        self.norm = torch.nn.LayerNorm(cross_attention_dim)\n\n        self.zero_initialize_last_layer()\n\n    def zero_initialize_last_layer(module):\n        last_layer = None\n        for module_name, layer in module.named_modules():\n            if isinstance(layer, torch.nn.Linear):\n                last_layer = layer\n\n        if last_layer is not None:\n            last_layer.weight.data.zero_()\n            last_layer.bias.data.zero_()\n\n    def forward(self, image_embeds):\n        embeds = image_embeds\n        clip_extra_context_tokens = self.proj(embeds).reshape(\n            -1, self.clip_extra_context_tokens, self.cross_attention_dim\n        )\n        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)\n        return clip_extra_context_tokens\n\n\nclass VisionAudioAdapter(torch.nn.Module):\n    def __init__(\n        self,\n        embedding_size=768,\n        expand_dim=4,\n        token_num=4,\n    ):\n        super().__init__()\n\n        self.mapper = V2AMapperMLP(\n            embedding_size,\n            embedding_size,\n            expansion_rate=expand_dim,\n        )\n\n        self.proj = ImageProjModel(\n            cross_attention_dim=embedding_size,\n            clip_embeddings_dim=embedding_size,\n            clip_extra_context_tokens=token_num,\n        )\n\n    
def forward(self, image_embeds):\n        image_embeds = self.mapper(image_embeds)\n        image_embeds = self.proj(image_embeds)\n        return image_embeds\n"
  },
  {
    "path": "foleycrafter/models/adapters/utils.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\n\n\nattn_maps = {}\n\n\ndef hook_fn(name):\n    def forward_hook(module, input, output):\n        if hasattr(module.processor, \"attn_map\"):\n            attn_maps[name] = module.processor.attn_map\n            del module.processor.attn_map\n\n    return forward_hook\n\n\ndef register_cross_attention_hook(unet):\n    for name, module in unet.named_modules():\n        if name.split(\".\")[-1].startswith(\"attn2\"):\n            module.register_forward_hook(hook_fn(name))\n\n    return unet\n\n\ndef upscale(attn_map, target_size):\n    attn_map = torch.mean(attn_map, dim=0)\n    attn_map = attn_map.permute(1, 0)\n    temp_size = None\n\n    for i in range(0, 5):\n        scale = 2**i\n        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:\n            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))\n            break\n\n    assert temp_size is not None, \"temp_size cannot is None\"\n\n    attn_map = attn_map.view(attn_map.shape[0], *temp_size)\n\n    attn_map = F.interpolate(\n        attn_map.unsqueeze(0).to(dtype=torch.float32), size=target_size, mode=\"bilinear\", align_corners=False\n    )[0]\n\n    attn_map = torch.softmax(attn_map, dim=0)\n    return attn_map\n\n\ndef get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):\n    idx = 0 if instance_or_negative else 1\n    net_attn_maps = []\n\n    for name, attn_map in attn_maps.items():\n        attn_map = attn_map.cpu() if detach else attn_map\n        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()\n        attn_map = upscale(attn_map, image_size)\n        net_attn_maps.append(attn_map)\n\n    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)\n\n    return net_attn_maps\n\n\ndef attnmaps2images(net_attn_maps):\n    # total_attn_scores = 0\n    images = []\n\n    for attn_map in 
net_attn_maps:\n        attn_map = attn_map.cpu().numpy()\n        # total_attn_scores += attn_map.mean().item()\n\n        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255\n        normalized_attn_map = normalized_attn_map.astype(np.uint8)\n        # print(\"norm: \", normalized_attn_map.shape)\n        image = Image.fromarray(normalized_attn_map)\n\n        # image = fix_save_attn_map(attn_map)\n        images.append(image)\n\n    # print(total_attn_scores)\n    return images\n\n\ndef is_torch2_available():\n    return hasattr(F, \"scaled_dot_product_attention\")\n"
  },
  {
    "path": "foleycrafter/models/auffusion/attention.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.models.activations import GEGLU, GELU, ApproximateGELU\nfrom diffusers.models.embeddings import SinusoidalPositionalEmbedding\nfrom diffusers.models.lora import LoRACompatibleLinear\nfrom diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm\nfrom diffusers.utils import USE_PEFT_BACKEND\nfrom diffusers.utils.torch_utils import maybe_allow_in_graph\nfrom foleycrafter.models.auffusion.attention_processor import Attention\n\n\ndef _chunked_feed_forward(\n    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None\n):\n    # \"feed_forward_chunk_size\" can be used to save memory\n    if hidden_states.shape[chunk_dim] % chunk_size != 0:\n        raise ValueError(\n            f\"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. 
Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.\"\n        )\n\n    num_chunks = hidden_states.shape[chunk_dim] // chunk_size\n    if lora_scale is None:\n        ff_output = torch.cat(\n            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],\n            dim=chunk_dim,\n        )\n    else:\n        # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete\n        ff_output = torch.cat(\n            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],\n            dim=chunk_dim,\n        )\n\n    return ff_output\n\n\n@maybe_allow_in_graph\nclass GatedSelfAttentionDense(nn.Module):\n    r\"\"\"\n    A gated self-attention dense layer that combines visual features and object features.\n\n    Parameters:\n        query_dim (`int`): The number of channels in the query.\n        context_dim (`int`): The number of channels in the context.\n        n_heads (`int`): The number of heads to use for attention.\n        d_head (`int`): The number of channels in each head.\n    \"\"\"\n\n    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):\n        super().__init__()\n\n        # we need a linear projection since we need cat visual feature and obj feature\n        self.linear = nn.Linear(context_dim, query_dim)\n\n        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)\n        self.ff = FeedForward(query_dim, activation_fn=\"geglu\")\n\n        self.norm1 = nn.LayerNorm(query_dim)\n        self.norm2 = nn.LayerNorm(query_dim)\n\n        self.register_parameter(\"alpha_attn\", nn.Parameter(torch.tensor(0.0)))\n        self.register_parameter(\"alpha_dense\", nn.Parameter(torch.tensor(0.0)))\n\n        self.enabled = True\n\n    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:\n        if not self.enabled:\n            return x\n\n        n_visual = 
x.shape[1]\n        objs = self.linear(objs)\n\n        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]\n        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))\n\n        return x\n\n\n@maybe_allow_in_graph\nclass BasicTransformerBlock(nn.Module):\n    r\"\"\"\n    A basic Transformer block.\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm (:\n            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.\n        attention_bias (:\n            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.\n        only_cross_attention (`bool`, *optional*):\n            Whether to use only cross-attention layers. In this case two cross attention layers are used.\n        double_self_attention (`bool`, *optional*):\n            Whether to use two self-attention layers. In this case no cross attention layers are used.\n        upcast_attention (`bool`, *optional*):\n            Whether to upcast the attention computation to float32. This is useful for mixed precision training.\n        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):\n            Whether to use learnable elementwise affine parameters for normalization.\n        norm_type (`str`, *optional*, defaults to `\"layer_norm\"`):\n            The normalization layer to use. 
Can be `\"layer_norm\"`, `\"ada_norm\"` or `\"ada_norm_zero\"`.\n        final_dropout (`bool` *optional*, defaults to False):\n            Whether to apply a final dropout after the last feed-forward layer.\n        attention_type (`str`, *optional*, defaults to `\"default\"`):\n            The type of attention to use. Can be `\"default\"` or `\"gated\"` or `\"gated-text-image\"`.\n        positional_embeddings (`str`, *optional*, defaults to `None`):\n            The type of positional embeddings to apply to.\n        num_positional_embeddings (`int`, *optional*, defaults to `None`):\n            The maximum number of positional embeddings to apply.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout=0.0,\n        cross_attention_dim: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        attention_bias: bool = False,\n        only_cross_attention: bool = False,\n        double_self_attention: bool = False,\n        upcast_attention: bool = False,\n        norm_elementwise_affine: bool = True,\n        norm_type: str = \"layer_norm\",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'\n        norm_eps: float = 1e-5,\n        final_dropout: bool = False,\n        attention_type: str = \"default\",\n        positional_embeddings: Optional[str] = None,\n        num_positional_embeddings: Optional[int] = None,\n        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,\n        ada_norm_bias: Optional[int] = None,\n        ff_inner_dim: Optional[int] = None,\n        ff_bias: bool = True,\n        attention_out_bias: bool = True,\n    ):\n        super().__init__()\n        self.only_cross_attention = only_cross_attention\n\n        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == \"ada_norm_zero\"\n        self.use_ada_layer_norm = 
(num_embeds_ada_norm is not None) and norm_type == \"ada_norm\"\n        self.use_ada_layer_norm_single = norm_type == \"ada_norm_single\"\n        self.use_layer_norm = norm_type == \"layer_norm\"\n        self.use_ada_layer_norm_continuous = norm_type == \"ada_norm_continuous\"\n\n        if norm_type in (\"ada_norm\", \"ada_norm_zero\") and num_embeds_ada_norm is None:\n            raise ValueError(\n                f\"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to\"\n                f\" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}.\"\n            )\n\n        if positional_embeddings and (num_positional_embeddings is None):\n            raise ValueError(\n                \"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined.\"\n            )\n\n        if positional_embeddings == \"sinusoidal\":\n            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)\n        else:\n            self.pos_embed = None\n\n        # Define 3 blocks. Each block has its own normalization layer.\n        # 1. 
Self-Attn\n        if self.use_ada_layer_norm:\n            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)\n        elif self.use_ada_layer_norm_zero:\n            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)\n        elif self.use_ada_layer_norm_continuous:\n            self.norm1 = AdaLayerNormContinuous(\n                dim,\n                ada_norm_continous_conditioning_embedding_dim,\n                norm_elementwise_affine,\n                norm_eps,\n                ada_norm_bias,\n                \"rms_norm\",\n            )\n        else:\n            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)\n\n        self.attn1 = Attention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            cross_attention_dim=cross_attention_dim if (only_cross_attention and not double_self_attention) else None,\n            upcast_attention=upcast_attention,\n            out_bias=attention_out_bias,\n        )\n\n        # 2. Cross-Attn\n        if cross_attention_dim is not None or double_self_attention:\n            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.\n            # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during\n            # the second cross attention block.\n            if self.use_ada_layer_norm:\n                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)\n            elif self.use_ada_layer_norm_continuous:\n                self.norm2 = AdaLayerNormContinuous(\n                    dim,\n                    ada_norm_continous_conditioning_embedding_dim,\n                    norm_elementwise_affine,\n                    norm_eps,\n                    ada_norm_bias,\n                    \"rms_norm\",\n                )\n            else:\n                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)\n\n            self.attn2 = Attention(\n                query_dim=dim,\n                cross_attention_dim=cross_attention_dim if not double_self_attention else None,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                upcast_attention=upcast_attention,\n                out_bias=attention_out_bias,\n            )  # is self-attn if encoder_hidden_states is none\n        else:\n            self.norm2 = None\n            self.attn2 = None\n\n        # 3. 
Feed-forward\n        if self.use_ada_layer_norm_continuous:\n            self.norm3 = AdaLayerNormContinuous(\n                dim,\n                ada_norm_continous_conditioning_embedding_dim,\n                norm_elementwise_affine,\n                norm_eps,\n                ada_norm_bias,\n                \"layer_norm\",\n            )\n        elif not self.use_ada_layer_norm_single:\n            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)\n\n        self.ff = FeedForward(\n            dim,\n            dropout=dropout,\n            activation_fn=activation_fn,\n            final_dropout=final_dropout,\n            inner_dim=ff_inner_dim,\n            bias=ff_bias,\n        )\n\n        # 4. Fuser\n        if attention_type == \"gated\" or attention_type == \"gated-text-image\":\n            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)\n\n        # 5. Scale-shift for PixArt-Alpha.\n        if self.use_ada_layer_norm_single:\n            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)\n\n        # let chunk size default to None\n        self._chunk_size = None\n        self._chunk_dim = 0\n\n    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):\n        # Sets chunk feed-forward\n        self._chunk_size = chunk_size\n        self._chunk_dim = dim\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        class_labels: Optional[torch.LongTensor] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> torch.FloatTensor:\n        # Notice that normalization is always 
applied before the real computation in the following blocks.\n        # 0. Self-Attention\n        batch_size = hidden_states.shape[0]\n\n        if self.use_ada_layer_norm:\n            norm_hidden_states = self.norm1(hidden_states, timestep)\n        elif self.use_ada_layer_norm_zero:\n            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n            )\n        elif self.use_layer_norm:\n            norm_hidden_states = self.norm1(hidden_states)\n        elif self.use_ada_layer_norm_continuous:\n            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs[\"pooled_text_emb\"])\n        elif self.use_ada_layer_norm_single:\n            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)\n            ).chunk(6, dim=1)\n            norm_hidden_states = self.norm1(hidden_states)\n            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa\n            norm_hidden_states = norm_hidden_states.squeeze(1)\n        else:\n            raise ValueError(\"Incorrect norm used\")\n\n        if self.pos_embed is not None:\n            norm_hidden_states = self.pos_embed(norm_hidden_states)\n\n        # 1. Retrieve lora scale.\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n\n        # 2. 
Prepare GLIGEN inputs\n        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}\n        gligen_kwargs = cross_attention_kwargs.pop(\"gligen\", None)\n\n        attn_output = self.attn1(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n            attention_mask=attention_mask,\n            **cross_attention_kwargs,\n        )\n        if self.use_ada_layer_norm_zero:\n            attn_output = gate_msa.unsqueeze(1) * attn_output\n        elif self.use_ada_layer_norm_single:\n            attn_output = gate_msa * attn_output\n\n        hidden_states = attn_output + hidden_states\n        if hidden_states.ndim == 4:\n            hidden_states = hidden_states.squeeze(1)\n\n        # 2.5 GLIGEN Control\n        if gligen_kwargs is not None:\n            hidden_states = self.fuser(hidden_states, gligen_kwargs[\"objs\"])\n\n        # 3. Cross-Attention\n        if self.attn2 is not None:\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm2(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero or self.use_layer_norm:\n                norm_hidden_states = self.norm2(hidden_states)\n            elif self.use_ada_layer_norm_single:\n                # For PixArt norm2 isn't applied here:\n                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103\n                norm_hidden_states = hidden_states\n            elif self.use_ada_layer_norm_continuous:\n                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs[\"pooled_text_emb\"])\n            else:\n                raise ValueError(\"Incorrect norm\")\n\n            if self.pos_embed is not None and self.use_ada_layer_norm_single is False:\n                norm_hidden_states = self.pos_embed(norm_hidden_states)\n\n         
   attn_output = self.attn2(\n                norm_hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=encoder_attention_mask,\n                **cross_attention_kwargs,\n            )\n            hidden_states = attn_output + hidden_states\n\n        # 4. Feed-forward\n        if self.use_ada_layer_norm_continuous:\n            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs[\"pooled_text_emb\"])\n        elif not self.use_ada_layer_norm_single:\n            norm_hidden_states = self.norm3(hidden_states)\n\n        if self.use_ada_layer_norm_zero:\n            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n        if self.use_ada_layer_norm_single:\n            norm_hidden_states = self.norm2(hidden_states)\n            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp\n\n        if self._chunk_size is not None:\n            # \"feed_forward_chunk_size\" can be used to save memory\n            ff_output = _chunked_feed_forward(\n                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale\n            )\n        else:\n            ff_output = self.ff(norm_hidden_states, scale=lora_scale)\n\n        if self.use_ada_layer_norm_zero:\n            ff_output = gate_mlp.unsqueeze(1) * ff_output\n        elif self.use_ada_layer_norm_single:\n            ff_output = gate_mlp * ff_output\n\n        hidden_states = ff_output + hidden_states\n        if hidden_states.ndim == 4:\n            hidden_states = hidden_states.squeeze(1)\n\n        return hidden_states\n\n\n@maybe_allow_in_graph\nclass TemporalBasicTransformerBlock(nn.Module):\n    r\"\"\"\n    A basic Transformer block for video like data.\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        time_mix_inner_dim (`int`): The number of channels for temporal attention.\n        num_attention_heads 
(`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        time_mix_inner_dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        cross_attention_dim: Optional[int] = None,\n    ):\n        super().__init__()\n        self.is_res = dim == time_mix_inner_dim\n\n        self.norm_in = nn.LayerNorm(dim)\n\n        # Define 3 blocks. Each block has its own normalization layer.\n        # 1. Self-Attn\n        self.norm_in = nn.LayerNorm(dim)\n        self.ff_in = FeedForward(\n            dim,\n            dim_out=time_mix_inner_dim,\n            activation_fn=\"geglu\",\n        )\n\n        self.norm1 = nn.LayerNorm(time_mix_inner_dim)\n        self.attn1 = Attention(\n            query_dim=time_mix_inner_dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            cross_attention_dim=None,\n        )\n\n        # 2. Cross-Attn\n        if cross_attention_dim is not None:\n            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.\n            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during\n            # the second cross attention block.\n            self.norm2 = nn.LayerNorm(time_mix_inner_dim)\n            self.attn2 = Attention(\n                query_dim=time_mix_inner_dim,\n                cross_attention_dim=cross_attention_dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n            )  # is self-attn if encoder_hidden_states is none\n        else:\n            self.norm2 = None\n            self.attn2 = None\n\n        # 3. 
Feed-forward\n        self.norm3 = nn.LayerNorm(time_mix_inner_dim)\n        self.ff = FeedForward(time_mix_inner_dim, activation_fn=\"geglu\")\n\n        # let chunk size default to None\n        self._chunk_size = None\n        self._chunk_dim = None\n\n    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):\n        # Sets chunk feed-forward\n        self._chunk_size = chunk_size\n        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off\n        self._chunk_dim = 1\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        num_frames: int,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        # Notice that normalization is always applied before the real computation in the following blocks.\n        # 0. Self-Attention\n        batch_size = hidden_states.shape[0]\n\n        batch_frames, seq_length, channels = hidden_states.shape\n        batch_size = batch_frames // num_frames\n\n        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)\n        hidden_states = hidden_states.permute(0, 2, 1, 3)\n        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)\n\n        residual = hidden_states\n        hidden_states = self.norm_in(hidden_states)\n\n        if self._chunk_size is not None:\n            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)\n        else:\n            hidden_states = self.ff_in(hidden_states)\n\n        if self.is_res:\n            hidden_states = hidden_states + residual\n\n        norm_hidden_states = self.norm1(hidden_states)\n        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)\n        hidden_states = attn_output + hidden_states\n\n        # 3. 
Cross-Attention\n        if self.attn2 is not None:\n            norm_hidden_states = self.norm2(hidden_states)\n            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)\n            hidden_states = attn_output + hidden_states\n\n        # 4. Feed-forward\n        norm_hidden_states = self.norm3(hidden_states)\n\n        if self._chunk_size is not None:\n            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)\n        else:\n            ff_output = self.ff(norm_hidden_states)\n\n        if self.is_res:\n            hidden_states = ff_output + hidden_states\n        else:\n            hidden_states = ff_output\n\n        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)\n        hidden_states = hidden_states.permute(0, 2, 1, 3)\n        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)\n\n        return hidden_states\n\n\nclass SkipFFTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        kv_input_dim: int,\n        kv_input_dim_proj_use_bias: bool,\n        dropout=0.0,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        attention_out_bias: bool = True,\n    ):\n        super().__init__()\n        if kv_input_dim != dim:\n            self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)\n        else:\n            self.kv_mapper = None\n\n        self.norm1 = RMSNorm(dim, 1e-06)\n\n        self.attn1 = Attention(\n            query_dim=dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            cross_attention_dim=cross_attention_dim,\n            out_bias=attention_out_bias,\n        )\n\n        self.norm2 = 
RMSNorm(dim, 1e-06)\n\n        self.attn2 = Attention(\n            query_dim=dim,\n            cross_attention_dim=cross_attention_dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            out_bias=attention_out_bias,\n        )\n\n    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):\n        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}\n\n        if self.kv_mapper is not None:\n            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))\n\n        norm_hidden_states = self.norm1(hidden_states)\n\n        attn_output = self.attn1(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            **cross_attention_kwargs,\n        )\n\n        hidden_states = attn_output + hidden_states\n\n        norm_hidden_states = self.norm2(hidden_states)\n\n        attn_output = self.attn2(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            **cross_attention_kwargs,\n        )\n\n        hidden_states = attn_output + hidden_states\n\n        return hidden_states\n\n\nclass FeedForward(nn.Module):\n    r\"\"\"\n    A feed-forward layer.\n\n    Parameters:\n        dim (`int`): The number of channels in the input.\n        dim_out (`int`, *optional*): The number of channels in the output. 
If not given, defaults to `dim`.\n        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.\n        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dim_out: Optional[int] = None,\n        mult: int = 4,\n        dropout: float = 0.0,\n        activation_fn: str = \"geglu\",\n        final_dropout: bool = False,\n        inner_dim=None,\n        bias: bool = True,\n    ):\n        super().__init__()\n        if inner_dim is None:\n            inner_dim = int(dim * mult)\n        dim_out = dim_out if dim_out is not None else dim\n        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear\n\n        if activation_fn == \"gelu\":\n            act_fn = GELU(dim, inner_dim, bias=bias)\n        if activation_fn == \"gelu-approximate\":\n            act_fn = GELU(dim, inner_dim, approximate=\"tanh\", bias=bias)\n        elif activation_fn == \"geglu\":\n            act_fn = GEGLU(dim, inner_dim, bias=bias)\n        elif activation_fn == \"geglu-approximate\":\n            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)\n\n        self.net = nn.ModuleList([])\n        # project in\n        self.net.append(act_fn)\n        # project dropout\n        self.net.append(nn.Dropout(dropout))\n        # project out\n        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))\n        # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout\n        if final_dropout:\n            self.net.append(nn.Dropout(dropout))\n\n    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)\n        for module in self.net:\n            if isinstance(module, compatible_cls):\n                hidden_states = module(hidden_states, scale)\n            else:\n                hidden_states = module(hidden_states)\n        return hidden_states\n"
  },
  {
    "path": "foleycrafter/models/auffusion/attention_processor.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom importlib import import_module\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch import nn\n\nfrom diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer\nfrom diffusers.utils import USE_PEFT_BACKEND, deprecate, logging\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import maybe_allow_in_graph\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nif is_xformers_available():\n    import xformers\n    import xformers.ops\nelse:\n    xformers = None\n\n\n@maybe_allow_in_graph\nclass Attention(nn.Module):\n    r\"\"\"\n    A cross attention layer.\n\n    Parameters:\n        query_dim (`int`):\n            The number of channels in the query.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the encoder_hidden_states. 
If not given, defaults to `query_dim`.\n        heads (`int`,  *optional*, defaults to 8):\n            The number of heads to use for multi-head attention.\n        dim_head (`int`,  *optional*, defaults to 64):\n            The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability to use.\n        bias (`bool`, *optional*, defaults to False):\n            Set to `True` for the query, key, and value linear layers to contain a bias parameter.\n        upcast_attention (`bool`, *optional*, defaults to False):\n            Set to `True` to upcast the attention computation to `float32`.\n        upcast_softmax (`bool`, *optional*, defaults to False):\n            Set to `True` to upcast the softmax computation to `float32`.\n        cross_attention_norm (`str`, *optional*, defaults to `None`):\n            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.\n        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use for the group norm in the cross attention.\n        added_kv_proj_dim (`int`, *optional*, defaults to `None`):\n            The number of channels to use for the added key and value projections. 
If `None`, no projection is used.\n        norm_num_groups (`int`, *optional*, defaults to `None`):\n            The number of groups to use for the group norm in the attention.\n        spatial_norm_dim (`int`, *optional*, defaults to `None`):\n            The number of channels to use for the spatial normalization.\n        out_bias (`bool`, *optional*, defaults to `True`):\n            Set to `True` to use a bias in the output linear layer.\n        scale_qk (`bool`, *optional*, defaults to `True`):\n            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.\n        only_cross_attention (`bool`, *optional*, defaults to `False`):\n            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if\n            `added_kv_proj_dim` is not `None`.\n        eps (`float`, *optional*, defaults to 1e-5):\n            An additional value added to the denominator in group normalization that is used for numerical stability.\n        rescale_output_factor (`float`, *optional*, defaults to 1.0):\n            A factor to rescale the output by dividing it with this value.\n        residual_connection (`bool`, *optional*, defaults to `False`):\n            Set to `True` to add the residual connection to the output.\n        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):\n            Set to `True` if the attention block is loaded from a deprecated state dict.\n        processor (`AttnProcessor`, *optional*, defaults to `None`):\n            The attention processor to use. 
If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and\n            `AttnProcessor` otherwise.\n    \"\"\"\n\n    def __init__(\n        self,\n        query_dim: int,\n        cross_attention_dim: Optional[int] = None,\n        heads: int = 8,\n        dim_head: int = 64,\n        dropout: float = 0.0,\n        bias: bool = False,\n        upcast_attention: bool = False,\n        upcast_softmax: bool = False,\n        cross_attention_norm: Optional[str] = None,\n        cross_attention_norm_num_groups: int = 32,\n        added_kv_proj_dim: Optional[int] = None,\n        norm_num_groups: Optional[int] = None,\n        spatial_norm_dim: Optional[int] = None,\n        out_bias: bool = True,\n        scale_qk: bool = True,\n        only_cross_attention: bool = False,\n        eps: float = 1e-5,\n        rescale_output_factor: float = 1.0,\n        residual_connection: bool = False,\n        _from_deprecated_attn_block: bool = False,\n        processor: Optional[\"AttnProcessor\"] = None,\n        out_dim: int = None,\n    ):\n        super().__init__()\n        self.inner_dim = out_dim if out_dim is not None else dim_head * heads\n        self.query_dim = query_dim\n        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim\n        self.upcast_attention = upcast_attention\n        self.upcast_softmax = upcast_softmax\n        self.rescale_output_factor = rescale_output_factor\n        self.residual_connection = residual_connection\n        self.dropout = dropout\n        self.fused_projections = False\n        self.out_dim = out_dim if out_dim is not None else query_dim\n\n        # we make use of this private variable to know whether this class is loaded\n        # with an deprecated state dict so that we can convert it on the fly\n        self._from_deprecated_attn_block = _from_deprecated_attn_block\n\n        self.scale_qk = scale_qk\n        self.scale = dim_head**-0.5 if self.scale_qk else 1.0\n\n      
  self.heads = out_dim // dim_head if out_dim is not None else heads\n        # for slice_size > 0 the attention score computation\n        # is split across the batch axis to save memory\n        # You can set slice_size with `set_attention_slice`\n        self.sliceable_head_dim = heads\n\n        self.added_kv_proj_dim = added_kv_proj_dim\n        self.only_cross_attention = only_cross_attention\n\n        if self.added_kv_proj_dim is None and self.only_cross_attention:\n            raise ValueError(\n                \"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`.\"\n            )\n\n        if norm_num_groups is not None:\n            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)\n        else:\n            self.group_norm = None\n\n        if spatial_norm_dim is not None:\n            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)\n        else:\n            self.spatial_norm = None\n\n        if cross_attention_norm is None:\n            self.norm_cross = None\n        elif cross_attention_norm == \"layer_norm\":\n            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)\n        elif cross_attention_norm == \"group_norm\":\n            if self.added_kv_proj_dim is not None:\n                # The given `encoder_hidden_states` are initially of shape\n                # (batch_size, seq_len, added_kv_proj_dim) before being projected\n                # to (batch_size, seq_len, cross_attention_dim). 
The norm is applied\n                # before the projection, so we need to use `added_kv_proj_dim` as\n                # the number of channels for the group norm.\n                norm_cross_num_channels = added_kv_proj_dim\n            else:\n                norm_cross_num_channels = self.cross_attention_dim\n\n            self.norm_cross = nn.GroupNorm(\n                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True\n            )\n        else:\n            raise ValueError(\n                f\"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'\"\n            )\n\n        if USE_PEFT_BACKEND:\n            linear_cls = nn.Linear\n        else:\n            linear_cls = LoRACompatibleLinear\n\n        self.linear_cls = linear_cls\n        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)\n\n        if not self.only_cross_attention:\n            # only relevant for the `AddedKVProcessor` classes\n            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)\n            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)\n        else:\n            self.to_k = None\n            self.to_v = None\n\n        if self.added_kv_proj_dim is not None:\n            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)\n            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)\n\n        self.to_out = nn.ModuleList([])\n        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))\n        self.to_out.append(nn.Dropout(dropout))\n\n        # set attention processor\n        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n        # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1\n        if processor is None:\n            processor = (\n                AttnProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk else AttnProcessor()\n            )\n        self.set_processor(processor)\n\n    def set_use_memory_efficient_attention_xformers(\n        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None\n    ) -> None:\n        r\"\"\"\n        Set whether to use memory efficient attention from `xformers` or not.\n\n        Args:\n            use_memory_efficient_attention_xformers (`bool`):\n                Whether to use memory efficient attention from `xformers` or not.\n            attention_op (`Callable`, *optional*):\n                The attention operation to use. Defaults to `None` which uses the default attention operation from\n                `xformers`.\n        \"\"\"\n        is_lora = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            LORA_ATTENTION_PROCESSORS,\n        )\n        is_custom_diffusion = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),\n        )\n        is_added_kv_processor = hasattr(self, \"processor\") and isinstance(\n            self.processor,\n            (\n                AttnAddedKVProcessor,\n                AttnAddedKVProcessor2_0,\n                SlicedAttnAddedKVProcessor,\n                XFormersAttnAddedKVProcessor,\n                LoRAAttnAddedKVProcessor,\n            ),\n        )\n\n        if use_memory_efficient_attention_xformers:\n            if is_added_kv_processor and (is_lora or is_custom_diffusion):\n                raise NotImplementedError(\n                    f\"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type 
{self.processor}\"\n                )\n            if not is_xformers_available():\n                raise ModuleNotFoundError(\n                    (\n                        \"Refer to https://github.com/facebookresearch/xformers for more information on how to install\"\n                        \" xformers\"\n                    ),\n                    name=\"xformers\",\n                )\n            elif not torch.cuda.is_available():\n                raise ValueError(\n                    \"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is\"\n                    \" only available for GPU \"\n                )\n            else:\n                try:\n                    # Make sure we can run the memory efficient attention\n                    _ = xformers.ops.memory_efficient_attention(\n                        torch.randn((1, 2, 40), device=\"cuda\"),\n                        torch.randn((1, 2, 40), device=\"cuda\"),\n                        torch.randn((1, 2, 40), device=\"cuda\"),\n                    )\n                except Exception as e:\n                    raise e\n\n            if is_lora:\n                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers\n                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?\n                processor = LoRAXFormersAttnProcessor(\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                    rank=self.processor.rank,\n                    attention_op=attention_op,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                processor.to(self.processor.to_q_lora.up.weight.device)\n            elif is_custom_diffusion:\n                processor = CustomDiffusionXFormersAttnProcessor(\n                    train_kv=self.processor.train_kv,\n                    
train_q_out=self.processor.train_q_out,\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                    attention_op=attention_op,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                if hasattr(self.processor, \"to_k_custom_diffusion\"):\n                    processor.to(self.processor.to_k_custom_diffusion.weight.device)\n            elif is_added_kv_processor:\n                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP\n                # which uses this type of cross attention ONLY because the attention mask of format\n                # [0, ..., -10.000, ..., 0, ...,] is not supported\n                # throw warning\n                logger.info(\n                    \"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation.\"\n                )\n                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)\n            else:\n                processor = XFormersAttnProcessor(attention_op=attention_op)\n        else:\n            if is_lora:\n                attn_processor_class = (\n                    LoRAAttnProcessor2_0 if hasattr(F, \"scaled_dot_product_attention\") else LoRAAttnProcessor\n                )\n                processor = attn_processor_class(\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                    rank=self.processor.rank,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                processor.to(self.processor.to_q_lora.up.weight.device)\n            elif is_custom_diffusion:\n                attn_processor_class = (\n                    CustomDiffusionAttnProcessor2_0\n                    if hasattr(F, 
\"scaled_dot_product_attention\")\n                    else CustomDiffusionAttnProcessor\n                )\n                processor = attn_processor_class(\n                    train_kv=self.processor.train_kv,\n                    train_q_out=self.processor.train_q_out,\n                    hidden_size=self.processor.hidden_size,\n                    cross_attention_dim=self.processor.cross_attention_dim,\n                )\n                processor.load_state_dict(self.processor.state_dict())\n                if hasattr(self.processor, \"to_k_custom_diffusion\"):\n                    processor.to(self.processor.to_k_custom_diffusion.weight.device)\n            else:\n                # set attention processor\n                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1\n                processor = (\n                    AttnProcessor2_0()\n                    if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk\n                    else AttnProcessor()\n                )\n\n        self.set_processor(processor)\n\n    def set_attention_slice(self, slice_size: int) -> None:\n        r\"\"\"\n        Set the slice size for attention computation.\n\n        Args:\n            slice_size (`int`):\n                The slice size for attention computation.\n        \"\"\"\n        if slice_size is not None and slice_size > self.sliceable_head_dim:\n            raise ValueError(f\"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.\")\n\n        if slice_size is not None and self.added_kv_proj_dim is not None:\n            processor = SlicedAttnAddedKVProcessor(slice_size)\n        elif slice_size is not None:\n            processor = 
SlicedAttnProcessor(slice_size)\n        elif self.added_kv_proj_dim is not None:\n            processor = AttnAddedKVProcessor()\n        else:\n            # set attention processor\n            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses\n            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention\n            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1\n            processor = (\n                AttnProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") and self.scale_qk else AttnProcessor()\n            )\n\n        self.set_processor(processor)\n\n    def set_processor(self, processor: \"AttnProcessor\", _remove_lora: bool = False) -> None:\n        r\"\"\"\n        Set the attention processor to use.\n\n        Args:\n            processor (`AttnProcessor`):\n                The attention processor to use.\n            _remove_lora (`bool`, *optional*, defaults to `False`):\n                Set to `True` to remove LoRA layers from the model.\n        \"\"\"\n        if not USE_PEFT_BACKEND and hasattr(self, \"processor\") and _remove_lora and self.to_q.lora_layer is not None:\n            deprecate(\n                \"set_processor to offload LoRA\",\n                \"0.26.0\",\n                \"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. 
Please make sure to call `pipe.unload_lora_weights()` instead.\",\n            )\n            # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete\n            # We need to remove all LoRA layers\n            # Don't forget to remove ALL `_remove_lora` from the codebase\n            for module in self.modules():\n                if hasattr(module, \"set_lora_layer\"):\n                    module.set_lora_layer(None)\n\n        # if current processor is in `self._modules` and if passed `processor` is not, we need to\n        # pop `processor` from `self._modules`\n        if (\n            hasattr(self, \"processor\")\n            and isinstance(self.processor, torch.nn.Module)\n            and not isinstance(processor, torch.nn.Module)\n        ):\n            logger.info(f\"You are removing possibly trained weights of {self.processor} with {processor}\")\n            self._modules.pop(\"processor\")\n\n        self.processor = processor\n\n    def get_processor(self, return_deprecated_lora: bool = False) -> \"AttentionProcessor\":\n        r\"\"\"\n        Get the attention processor in use.\n\n        Args:\n            return_deprecated_lora (`bool`, *optional*, defaults to `False`):\n                Set to `True` to return the deprecated LoRA attention processor.\n\n        Returns:\n            \"AttentionProcessor\": The attention processor in use.\n        \"\"\"\n        if not return_deprecated_lora:\n            return self.processor\n\n        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible\n        # serialization format for LoRA Attention Processors. It should be deleted once the integration\n        # with PEFT is completed.\n        is_lora_activated = {\n            name: module.lora_layer is not None\n            for name, module in self.named_modules()\n            if hasattr(module, \"lora_layer\")\n        }\n\n        # 1. 
if no layer has a LoRA activated we can return the processor as usual\n        if not any(is_lora_activated.values()):\n            return self.processor\n\n        # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`\n        is_lora_activated.pop(\"add_k_proj\", None)\n        is_lora_activated.pop(\"add_v_proj\", None)\n        # 2. else it is not possible that only some layers have LoRA activated\n        if not all(is_lora_activated.values()):\n            raise ValueError(\n                f\"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}\"\n            )\n\n        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor\n        non_lora_processor_cls_name = self.processor.__class__.__name__\n        lora_processor_cls = getattr(import_module(__name__), \"LoRA\" + non_lora_processor_cls_name)\n\n        hidden_size = self.inner_dim\n\n        # now create a LoRA attention processor from the LoRA layers\n        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:\n            kwargs = {\n                \"cross_attention_dim\": self.cross_attention_dim,\n                \"rank\": self.to_q.lora_layer.rank,\n                \"network_alpha\": self.to_q.lora_layer.network_alpha,\n                \"q_rank\": self.to_q.lora_layer.rank,\n                \"q_hidden_size\": self.to_q.lora_layer.out_features,\n                \"k_rank\": self.to_k.lora_layer.rank,\n                \"k_hidden_size\": self.to_k.lora_layer.out_features,\n                \"v_rank\": self.to_v.lora_layer.rank,\n                \"v_hidden_size\": self.to_v.lora_layer.out_features,\n                \"out_rank\": self.to_out[0].lora_layer.rank,\n                \"out_hidden_size\": self.to_out[0].lora_layer.out_features,\n            }\n\n            if hasattr(self.processor, \"attention_op\"):\n                kwargs[\"attention_op\"] = 
self.processor.attention_op\n\n            lora_processor = lora_processor_cls(hidden_size, **kwargs)\n            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())\n            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())\n            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())\n            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())\n        elif lora_processor_cls == LoRAAttnAddedKVProcessor:\n            lora_processor = lora_processor_cls(\n                hidden_size,\n                cross_attention_dim=self.add_k_proj.weight.shape[0],\n                rank=self.to_q.lora_layer.rank,\n                network_alpha=self.to_q.lora_layer.network_alpha,\n            )\n            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())\n            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())\n            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())\n            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())\n\n            # only save if used\n            if self.add_k_proj.lora_layer is not None:\n                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())\n                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())\n            else:\n                lora_processor.add_k_proj_lora = None\n                lora_processor.add_v_proj_lora = None\n        else:\n            raise ValueError(f\"{lora_processor_cls} does not exist.\")\n\n        return lora_processor\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        **cross_attention_kwargs,\n    ) -> torch.Tensor:\n        
r\"\"\"\n        The forward method of the `Attention` class.\n\n        Args:\n            hidden_states (`torch.Tensor`):\n                The hidden states of the query.\n            encoder_hidden_states (`torch.Tensor`, *optional*):\n                The hidden states of the encoder.\n            attention_mask (`torch.Tensor`, *optional*):\n                The attention mask to use. If `None`, no mask is applied.\n            **cross_attention_kwargs:\n                Additional keyword arguments to pass along to the cross attention.\n\n        Returns:\n            `torch.Tensor`: The output of the attention layer.\n        \"\"\"\n        # The `Attention` class can call different attention processors / attention functions\n        # here we simply pass along all tensors to the selected processor class\n        # For standard processors that are defined here, `**cross_attention_kwargs` is empty\n        return self.processor(\n            self,\n            hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            attention_mask=attention_mask,\n            **cross_attention_kwargs,\n        )\n\n    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. 
`heads`\n        is the number of heads initialized while constructing the `Attention` class.\n\n        Args:\n            tensor (`torch.Tensor`): The tensor to reshape.\n\n        Returns:\n            `torch.Tensor`: The reshaped tensor.\n        \"\"\"\n        head_size = self.heads\n        batch_size, seq_len, dim = tensor.shape\n        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)\n        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)\n        return tensor\n\n    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:\n        r\"\"\"\n        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is\n        the number of heads initialized while constructing the `Attention` class.\n\n        Args:\n            tensor (`torch.Tensor`): The tensor to reshape.\n            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is\n                reshaped to `[batch_size * heads, seq_len, dim // heads]`.\n\n        Returns:\n            `torch.Tensor`: The reshaped tensor.\n        \"\"\"\n        head_size = self.heads\n        batch_size, seq_len, dim = tensor.shape\n        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)\n        tensor = tensor.permute(0, 2, 1, 3)\n\n        if out_dim == 3:\n            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)\n\n        return tensor\n\n    def get_attention_scores(\n        self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None\n    ) -> torch.Tensor:\n        r\"\"\"\n        Compute the attention scores.\n\n        Args:\n            query (`torch.Tensor`): The query tensor.\n            key (`torch.Tensor`): The key tensor.\n            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. 
If `None`, no mask is applied.\n\n        Returns:\n            `torch.Tensor`: The attention probabilities/scores.\n        \"\"\"\n        dtype = query.dtype\n        if self.upcast_attention:\n            query = query.float()\n            key = key.float()\n\n        if attention_mask is None:\n            baddbmm_input = torch.empty(\n                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device\n            )\n            beta = 0\n        else:\n            baddbmm_input = attention_mask\n            beta = 1\n\n        attention_scores = torch.baddbmm(\n            baddbmm_input,\n            query,\n            key.transpose(-1, -2),\n            beta=beta,\n            alpha=self.scale,\n        )\n        del baddbmm_input\n\n        if self.upcast_softmax:\n            attention_scores = attention_scores.float()\n\n        attention_probs = attention_scores.softmax(dim=-1)\n        del attention_scores\n\n        attention_probs = attention_probs.to(dtype)\n\n        return attention_probs\n\n    def prepare_attention_mask(\n        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepare the attention mask for the attention computation.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                The attention mask to prepare.\n            target_length (`int`):\n                The target length of the attention mask. This is the length of the attention mask after padding.\n            batch_size (`int`):\n                The batch size, which is used to repeat the attention mask.\n            out_dim (`int`, *optional*, defaults to `3`):\n                The output dimension of the attention mask. 
Can be either `3` or `4`.\n\n        Returns:\n            `torch.Tensor`: The prepared attention mask.\n        \"\"\"\n        head_size = self.heads\n        if attention_mask is None:\n            return attention_mask\n\n        current_length: int = attention_mask.shape[-1]\n        if current_length != target_length:\n            if attention_mask.device.type == \"mps\":\n                # HACK: MPS: Does not support padding by greater than dimension of input tensor.\n                # Instead, we can manually construct the padding tensor.\n                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)\n                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)\n                attention_mask = torch.cat([attention_mask, padding], dim=2)\n            else:\n                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:\n                #       we want to instead pad by (0, remaining_length), where remaining_length is:\n                #       remaining_length: int = target_length - current_length\n                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding\n                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)\n\n        if out_dim == 3:\n            if attention_mask.shape[0] < batch_size * head_size:\n                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)\n        elif out_dim == 4:\n            attention_mask = attention_mask.unsqueeze(1)\n            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)\n\n        return attention_mask\n\n    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Normalize the encoder hidden states. 
Requires `self.norm_cross` to be specified when constructing the\n        `Attention` class.\n\n        Args:\n            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.\n\n        Returns:\n            `torch.Tensor`: The normalized encoder hidden states.\n        \"\"\"\n        assert self.norm_cross is not None, \"self.norm_cross must be defined to call self.norm_encoder_hidden_states\"\n\n        if isinstance(self.norm_cross, nn.LayerNorm):\n            encoder_hidden_states = self.norm_cross(encoder_hidden_states)\n        elif isinstance(self.norm_cross, nn.GroupNorm):\n            # Group norm norms along the channels dimension and expects\n            # input to be in the shape of (N, C, *). In this case, we want\n            # to norm along the hidden dimension, so we need to move\n            # (batch_size, sequence_length, hidden_size) ->\n            # (batch_size, hidden_size, sequence_length)\n            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)\n            encoder_hidden_states = self.norm_cross(encoder_hidden_states)\n            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)\n        else:\n            assert False\n\n        return encoder_hidden_states\n\n    @torch.no_grad()\n    def fuse_projections(self, fuse=True):\n        is_cross_attention = self.cross_attention_dim != self.query_dim\n        device = self.to_q.weight.data.device\n        dtype = self.to_q.weight.data.dtype\n\n        if not is_cross_attention:\n            # fetch weight matrices.\n            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])\n            in_features = concatenated_weights.shape[1]\n            out_features = concatenated_weights.shape[0]\n\n            # create a new single projection layer and copy over the weights.\n            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)\n            
self.to_qkv.weight.copy_(concatenated_weights)\n\n        else:\n            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])\n            in_features = concatenated_weights.shape[1]\n            out_features = concatenated_weights.shape[0]\n\n            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)\n            self.to_kv.weight.copy_(concatenated_weights)\n\n        self.fused_projections = fuse\n\n\nclass AttnProcessor:\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states, *args)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = 
attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states, *args)\n        value = attn.to_v(encoder_hidden_states, *args)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass CustomDiffusionAttnProcessor(nn.Module):\n    r\"\"\"\n    Processor for implementing attention for the Custom Diffusion method.\n\n    Args:\n        train_kv (`bool`, defaults to `True`):\n            Whether to newly train the key and value matrices corresponding to the text features.\n        train_q_out (`bool`, defaults to `True`):\n            Whether to newly train query matrices corresponding to the latent image features.\n        hidden_size (`int`, *optional*, defaults to `None`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`, *optional*, defaults to `None`):\n            The number of channels in the `encoder_hidden_states`.\n        out_bias (`bool`, defaults to `True`):\n            Whether to include the bias parameter in `train_q_out`.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        train_kv: bool = 
True,\n        train_q_out: bool = True,\n        hidden_size: Optional[int] = None,\n        cross_attention_dim: Optional[int] = None,\n        out_bias: bool = True,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n        self.train_kv = train_kv\n        self.train_q_out = train_q_out\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n\n        # `_custom_diffusion` id for easy serialization and loading.\n        if self.train_kv:\n            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        if self.train_q_out:\n            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)\n            self.to_out_custom_diffusion = nn.ModuleList([])\n            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))\n            self.to_out_custom_diffusion.append(nn.Dropout(dropout))\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.Tensor:\n        batch_size, sequence_length, _ = hidden_states.shape\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n        if self.train_q_out:\n            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)\n        else:\n            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))\n\n        if encoder_hidden_states is None:\n            crossattn = False\n            encoder_hidden_states = hidden_states\n        else:\n            crossattn = True\n            if attn.norm_cross:\n                encoder_hidden_states = 
attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        if self.train_kv:\n            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))\n            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))\n            key = key.to(attn.to_q.weight.dtype)\n            value = value.to(attn.to_q.weight.dtype)\n        else:\n            key = attn.to_k(encoder_hidden_states)\n            value = attn.to_v(encoder_hidden_states)\n\n        if crossattn:\n            detach = torch.ones_like(key)\n            detach[:, :1, :] = detach[:, :1, :] * 0.0\n            key = detach * key + (1 - detach) * key.detach()\n            value = detach * value + (1 - detach) * value.detach()\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        if self.train_q_out:\n            # linear proj\n            hidden_states = self.to_out_custom_diffusion[0](hidden_states)\n            # dropout\n            hidden_states = self.to_out_custom_diffusion[1](hidden_states)\n        else:\n            # linear proj\n            hidden_states = attn.to_out[0](hidden_states)\n            # dropout\n            hidden_states = attn.to_out[1](hidden_states)\n\n        return hidden_states\n\n\nclass AttnAddedKVProcessor:\n    r\"\"\"\n    Processor for performing attention-related computations with extra learnable key and value matrices for the text\n    encoder.\n    \"\"\"\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] 
= None,\n        scale: float = 1.0,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n\n        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states, *args)\n        query = attn.head_to_batch_dim(query)\n\n        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)\n        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)\n        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)\n        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)\n\n        if not attn.only_cross_attention:\n            key = attn.to_k(hidden_states, *args)\n            value = attn.to_v(hidden_states, *args)\n            key = attn.head_to_batch_dim(key)\n            value = attn.head_to_batch_dim(value)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)\n        else:\n            key = encoder_hidden_states_key_proj\n            value = encoder_hidden_states_value_proj\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = 
attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass AttnAddedKVProcessor2_0:\n    r\"\"\"\n    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra\n    learnable key and value matrices for the text encoder.\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\n                \"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\"\n            )\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n\n        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states, *args)\n        query = attn.head_to_batch_dim(query, out_dim=4)\n\n        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)\n        encoder_hidden_states_value_proj = 
attn.add_v_proj(encoder_hidden_states)\n        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)\n        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)\n\n        if not attn.only_cross_attention:\n            key = attn.to_k(hidden_states, *args)\n            value = attn.to_v(hidden_states, *args)\n            key = attn.head_to_batch_dim(key, out_dim=4)\n            value = attn.head_to_batch_dim(value, out_dim=4)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)\n        else:\n            key = encoder_hidden_states_key_proj\n            value = encoder_hidden_states_value_proj\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass XFormersAttnAddedKVProcessor:\n    r\"\"\"\n    Processor for implementing memory efficient attention using xFormers.\n\n    Args:\n        attention_op (`Callable`, *optional*, defaults to `None`):\n            The base\n            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to\n            use as the attention operator. 
It is recommended to set to `None`, and allow xFormers to choose the best\n            operator.\n    \"\"\"\n\n    def __init__(self, attention_op: Optional[Callable] = None):\n        self.attention_op = attention_op\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.Tensor:\n        residual = hidden_states\n        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n        query = attn.head_to_batch_dim(query)\n\n        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)\n        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)\n        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)\n        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)\n\n        if not attn.only_cross_attention:\n            key = attn.to_k(hidden_states)\n            value = attn.to_v(hidden_states)\n            key = attn.head_to_batch_dim(key)\n            value = attn.head_to_batch_dim(value)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)\n        else:\n            key = 
encoder_hidden_states_key_proj\n            value = encoder_hidden_states_value_proj\n\n        hidden_states = xformers.ops.memory_efficient_attention(\n            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale\n        )\n        hidden_states = hidden_states.to(query.dtype)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass XFormersAttnProcessor:\n    r\"\"\"\n    Processor for implementing memory efficient attention using xFormers.\n\n    Args:\n        attention_op (`Callable`, *optional*, defaults to `None`):\n            The base\n            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to\n            use as the attention operator. 
It is recommended to set to `None`, and allow xFormers to choose the best\n            operator.\n    \"\"\"\n\n    def __init__(self, attention_op: Optional[Callable] = None):\n        self.attention_op = attention_op\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        residual = hidden_states\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, key_tokens, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)\n        if attention_mask is not None:\n            # expand our mask's singleton query_tokens dimension:\n            #   [batch*heads,            1, key_tokens] ->\n            #   [batch*heads, query_tokens, key_tokens]\n            # so that it can be added as a bias onto the attention scores that xformers computes:\n            #   [batch*heads, query_tokens, key_tokens]\n            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.\n            _, query_tokens, _ = hidden_states.shape\n            attention_mask = attention_mask.expand(-1, query_tokens, -1)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 
2)\n\n        query = attn.to_q(hidden_states, *args)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states, *args)\n        value = attn.to_v(encoder_hidden_states, *args)\n\n        query = attn.head_to_batch_dim(query).contiguous()\n        key = attn.head_to_batch_dim(key).contiguous()\n        value = attn.head_to_batch_dim(value).contiguous()\n\n        hidden_states = xformers.ops.memory_efficient_attention(\n            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale\n        )\n        hidden_states = hidden_states.to(query.dtype)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass AttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] 
= None,\n        scale: float = 1.0,\n        **kwargs,\n    ) -> torch.FloatTensor:\n        residual = hidden_states\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n        query = attn.to_q(hidden_states, *args)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states, *args)\n        value = attn.to_v(encoder_hidden_states, *args)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # 
TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass FusedAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,\n    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.\n\n    <Tip warning={true}>\n\n    This API is currently 🧪 experimental in nature and can change in future.\n\n    </Tip>\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\n                \"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. 
Please upgrade PyTorch to > 2.0.\"\n            )\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        residual = hidden_states\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n        if encoder_hidden_states is None:\n            qkv = attn.to_qkv(hidden_states, *args)\n            split_size = qkv.shape[-1] // 3\n            query, key, value = torch.split(qkv, split_size, dim=-1)\n        else:\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n            query = attn.to_q(hidden_states, *args)\n\n            kv = attn.to_kv(encoder_hidden_states, *args)\n            
split_size = kv.shape[-1] // 2\n            key, value = torch.split(kv, split_size, dim=-1)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass CustomDiffusionXFormersAttnProcessor(nn.Module):\n    r\"\"\"\n    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.\n\n    Args:\n    train_kv (`bool`, defaults to `True`):\n        Whether to newly train the key and value matrices corresponding to the text features.\n    train_q_out (`bool`, defaults to `True`):\n        Whether to newly train query matrices corresponding to the latent image features.\n    hidden_size (`int`, *optional*, defaults to `None`):\n        The hidden size of the attention layer.\n    cross_attention_dim (`int`, *optional*, defaults 
to `None`):\n        The number of channels in the `encoder_hidden_states`.\n    out_bias (`bool`, defaults to `True`):\n        Whether to include the bias parameter in `train_q_out`.\n    dropout (`float`, *optional*, defaults to 0.0):\n        The dropout probability to use.\n    attention_op (`Callable`, *optional*, defaults to `None`):\n        The base\n        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use\n        as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.\n    \"\"\"\n\n    def __init__(\n        self,\n        train_kv: bool = True,\n        train_q_out: bool = False,\n        hidden_size: Optional[int] = None,\n        cross_attention_dim: Optional[int] = None,\n        out_bias: bool = True,\n        dropout: float = 0.0,\n        attention_op: Optional[Callable] = None,\n    ):\n        super().__init__()\n        self.train_kv = train_kv\n        self.train_q_out = train_q_out\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.attention_op = attention_op\n\n        # `_custom_diffusion` id for easy serialization and loading.\n        if self.train_kv:\n            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        if self.train_q_out:\n            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)\n            self.to_out_custom_diffusion = nn.ModuleList([])\n            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))\n            self.to_out_custom_diffusion.append(nn.Dropout(dropout))\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: 
Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if self.train_q_out:\n            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)\n        else:\n            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))\n\n        if encoder_hidden_states is None:\n            crossattn = False\n            encoder_hidden_states = hidden_states\n        else:\n            crossattn = True\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        if self.train_kv:\n            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))\n            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))\n            key = key.to(attn.to_q.weight.dtype)\n            value = value.to(attn.to_q.weight.dtype)\n        else:\n            key = attn.to_k(encoder_hidden_states)\n            value = attn.to_v(encoder_hidden_states)\n\n        if crossattn:\n            detach = torch.ones_like(key)\n            detach[:, :1, :] = detach[:, :1, :] * 0.0\n            key = detach * key + (1 - detach) * key.detach()\n            value = detach * value + (1 - detach) * value.detach()\n\n        query = attn.head_to_batch_dim(query).contiguous()\n        key = attn.head_to_batch_dim(key).contiguous()\n        value = attn.head_to_batch_dim(value).contiguous()\n\n        hidden_states = xformers.ops.memory_efficient_attention(\n            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale\n        )\n        
hidden_states = hidden_states.to(query.dtype)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        if self.train_q_out:\n            # linear proj\n            hidden_states = self.to_out_custom_diffusion[0](hidden_states)\n            # dropout\n            hidden_states = self.to_out_custom_diffusion[1](hidden_states)\n        else:\n            # linear proj\n            hidden_states = attn.to_out[0](hidden_states)\n            # dropout\n            hidden_states = attn.to_out[1](hidden_states)\n\n        return hidden_states\n\n\nclass CustomDiffusionAttnProcessor2_0(nn.Module):\n    r\"\"\"\n    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled\n    dot-product attention.\n\n    Args:\n        train_kv (`bool`, defaults to `True`):\n            Whether to newly train the key and value matrices corresponding to the text features.\n        train_q_out (`bool`, defaults to `True`):\n            Whether to newly train query matrices corresponding to the latent image features.\n        hidden_size (`int`, *optional*, defaults to `None`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`, *optional*, defaults to `None`):\n            The number of channels in the `encoder_hidden_states`.\n        out_bias (`bool`, defaults to `True`):\n            Whether to include the bias parameter in `train_q_out`.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        train_kv: bool = True,\n        train_q_out: bool = True,\n        hidden_size: Optional[int] = None,\n        cross_attention_dim: Optional[int] = None,\n        out_bias: bool = True,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n        self.train_kv = train_kv\n        self.train_q_out = train_q_out\n\n        self.hidden_size = hidden_size\n        
self.cross_attention_dim = cross_attention_dim\n\n        # `_custom_diffusion` id for easy serialization and loading.\n        if self.train_kv:\n            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        if self.train_q_out:\n            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)\n            self.to_out_custom_diffusion = nn.ModuleList([])\n            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))\n            self.to_out_custom_diffusion.append(nn.Dropout(dropout))\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        batch_size, sequence_length, _ = hidden_states.shape\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n        if self.train_q_out:\n            query = self.to_q_custom_diffusion(hidden_states)\n        else:\n            query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            crossattn = False\n            encoder_hidden_states = hidden_states\n        else:\n            crossattn = True\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        if self.train_kv:\n            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))\n            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))\n            key = key.to(attn.to_q.weight.dtype)\n            value = value.to(attn.to_q.weight.dtype)\n\n        else:\n            key 
= attn.to_k(encoder_hidden_states)\n            value = attn.to_v(encoder_hidden_states)\n\n        if crossattn:\n            detach = torch.ones_like(key)\n            detach[:, :1, :] = detach[:, :1, :] * 0.0\n            key = detach * key + (1 - detach) * key.detach()\n            value = detach * value + (1 - detach) * value.detach()\n\n        inner_dim = hidden_states.shape[-1]\n\n        head_dim = inner_dim // attn.heads\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if self.train_q_out:\n            # linear proj\n            hidden_states = self.to_out_custom_diffusion[0](hidden_states)\n            # dropout\n            hidden_states = self.to_out_custom_diffusion[1](hidden_states)\n        else:\n            # linear proj\n            hidden_states = attn.to_out[0](hidden_states)\n            # dropout\n            hidden_states = attn.to_out[1](hidden_states)\n\n        return hidden_states\n\n\nclass SlicedAttnProcessor:\n    r\"\"\"\n    Processor for implementing sliced attention.\n\n    Args:\n        slice_size (`int`, *optional*):\n            The number of steps to compute attention. 
Uses as many slices as `attention_head_dim // slice_size`, and\n            `attention_head_dim` must be a multiple of the `slice_size`.\n    \"\"\"\n\n    def __init__(self, slice_size: int):\n        self.slice_size = slice_size\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        residual = hidden_states\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n        dim = query.shape[-1]\n        query = attn.head_to_batch_dim(query)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        batch_size_attention, query_tokens, _ = query.shape\n        hidden_states = torch.zeros(\n            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype\n        )\n\n        for i in range(batch_size_attention // self.slice_size):\n            
start_idx = i * self.slice_size\n            end_idx = (i + 1) * self.slice_size\n\n            query_slice = query[start_idx:end_idx]\n            key_slice = key[start_idx:end_idx]\n            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None\n\n            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)\n\n            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])\n\n            hidden_states[start_idx:end_idx] = attn_slice\n\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass SlicedAttnAddedKVProcessor:\n    r\"\"\"\n    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.\n\n    Args:\n        slice_size (`int`, *optional*):\n            The number of steps to compute attention. 
Uses as many slices as `attention_head_dim // slice_size`, and\n            `attention_head_dim` must be a multiple of the `slice_size`.\n    \"\"\"\n\n    def __init__(self, slice_size):\n        self.slice_size = slice_size\n\n    def __call__(\n        self,\n        attn: \"Attention\",\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)\n\n        batch_size, sequence_length, _ = hidden_states.shape\n\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n        dim = query.shape[-1]\n        query = attn.head_to_batch_dim(query)\n\n        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)\n        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)\n\n        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)\n        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)\n\n        if not attn.only_cross_attention:\n            key = attn.to_k(hidden_states)\n            value = attn.to_v(hidden_states)\n            key = attn.head_to_batch_dim(key)\n            value = 
attn.head_to_batch_dim(value)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)\n        else:\n            key = encoder_hidden_states_key_proj\n            value = encoder_hidden_states_value_proj\n\n        batch_size_attention, query_tokens, _ = query.shape\n        hidden_states = torch.zeros(\n            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype\n        )\n\n        for i in range(batch_size_attention // self.slice_size):\n            start_idx = i * self.slice_size\n            end_idx = (i + 1) * self.slice_size\n\n            query_slice = query[start_idx:end_idx]\n            key_slice = key[start_idx:end_idx]\n            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None\n\n            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)\n\n            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])\n\n            hidden_states[start_idx:end_idx] = attn_slice\n\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass SpatialNorm(nn.Module):\n    \"\"\"\n    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.\n\n    Args:\n        f_channels (`int`):\n            The number of channels for input to group normalization layer, and output of the spatial norm layer.\n        zq_channels (`int`):\n            The number of channels for the quantized vector as described in the paper.\n    \"\"\"\n\n    def __init__(\n        self,\n        f_channels: int,\n    
    zq_channels: int,\n    ):\n        super().__init__()\n        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)\n        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)\n        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:\n        f_size = f.shape[-2:]\n        zq = F.interpolate(zq, size=f_size, mode=\"nearest\")\n        norm_f = self.norm_layer(f)\n        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)\n        return new_f\n\n\n## Deprecated\nclass LoRAAttnProcessor(nn.Module):\n    r\"\"\"\n    Processor for implementing the LoRA attention mechanism.\n\n    Args:\n        hidden_size (`int`, *optional*):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the `encoder_hidden_states`.\n        rank (`int`, defaults to 4):\n            The dimension of the LoRA update matrices.\n        network_alpha (`int`, *optional*):\n            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.\n        kwargs (`dict`):\n            Additional keyword arguments to pass to the `LoRALinearLayer` layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        cross_attention_dim: Optional[int] = None,\n        rank: int = 4,\n        network_alpha: Optional[int] = None,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.rank = rank\n\n        q_rank = kwargs.pop(\"q_rank\", None)\n        q_hidden_size = kwargs.pop(\"q_hidden_size\", None)\n        q_rank = q_rank if q_rank is not None else rank\n        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size\n\n        v_rank = 
kwargs.pop(\"v_rank\", None)\n        v_hidden_size = kwargs.pop(\"v_hidden_size\", None)\n        v_rank = v_rank if v_rank is not None else rank\n        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size\n\n        out_rank = kwargs.pop(\"out_rank\", None)\n        out_hidden_size = kwargs.pop(\"out_hidden_size\", None)\n        out_rank = out_rank if out_rank is not None else rank\n        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size\n\n        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)\n        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)\n        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)\n\n    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:\n        self_cls_name = self.__class__.__name__\n        deprecate(\n            self_cls_name,\n            \"0.26.0\",\n            (\n                f\"Make sure use {self_cls_name[4:]} instead by setting\"\n                \"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. 
This will be done automatically when using\"\n                \" `LoraLoaderMixin.load_lora_weights`\"\n            ),\n        )\n        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)\n        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)\n        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)\n        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)\n\n        attn._modules.pop(\"processor\")\n        attn.processor = AttnProcessor()\n        return attn.processor(attn, hidden_states, *args, **kwargs)\n\n\nclass LoRAAttnProcessor2_0(nn.Module):\n    r\"\"\"\n    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product\n    attention.\n\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the `encoder_hidden_states`.\n        rank (`int`, defaults to 4):\n            The dimension of the LoRA update matrices.\n        network_alpha (`int`, *optional*):\n            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.\n        kwargs (`dict`):\n            Additional keyword arguments to pass to the `LoRALinearLayer` layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        cross_attention_dim: Optional[int] = None,\n        rank: int = 4,\n        network_alpha: Optional[int] = None,\n        **kwargs,\n    ):\n        super().__init__()\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.rank = rank\n\n        q_rank = kwargs.pop(\"q_rank\", None)\n        q_hidden_size = kwargs.pop(\"q_hidden_size\", None)\n        
q_rank = q_rank if q_rank is not None else rank\n        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size\n\n        v_rank = kwargs.pop(\"v_rank\", None)\n        v_hidden_size = kwargs.pop(\"v_hidden_size\", None)\n        v_rank = v_rank if v_rank is not None else rank\n        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size\n\n        out_rank = kwargs.pop(\"out_rank\", None)\n        out_hidden_size = kwargs.pop(\"out_hidden_size\", None)\n        out_rank = out_rank if out_rank is not None else rank\n        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size\n\n        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)\n        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)\n        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)\n\n    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:\n        self_cls_name = self.__class__.__name__\n        deprecate(\n            self_cls_name,\n            \"0.26.0\",\n            (\n                f\"Make sure use {self_cls_name[4:]} instead by setting\"\n                \"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. 
This will be done automatically when using\"\n                \" `LoraLoaderMixin.load_lora_weights`\"\n            ),\n        )\n        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)\n        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)\n        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)\n        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)\n\n        attn._modules.pop(\"processor\")\n        attn.processor = AttnProcessor2_0()\n        return attn.processor(attn, hidden_states, *args, **kwargs)\n\n\nclass LoRAXFormersAttnProcessor(nn.Module):\n    r\"\"\"\n    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.\n\n    Args:\n        hidden_size (`int`, *optional*):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`, *optional*):\n            The number of channels in the `encoder_hidden_states`.\n        rank (`int`, defaults to 4):\n            The dimension of the LoRA update matrices.\n        attention_op (`Callable`, *optional*, defaults to `None`):\n            The base\n            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to\n            use as the attention operator. 
It is recommended to set to `None`, and allow xFormers to choose the best\n            operator.\n        network_alpha (`int`, *optional*):\n            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.\n        kwargs (`dict`):\n            Additional keyword arguments to pass to the `LoRALinearLayer` layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        cross_attention_dim: int,\n        rank: int = 4,\n        attention_op: Optional[Callable] = None,\n        network_alpha: Optional[int] = None,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.rank = rank\n        self.attention_op = attention_op\n\n        q_rank = kwargs.pop(\"q_rank\", None)\n        q_hidden_size = kwargs.pop(\"q_hidden_size\", None)\n        q_rank = q_rank if q_rank is not None else rank\n        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size\n\n        v_rank = kwargs.pop(\"v_rank\", None)\n        v_hidden_size = kwargs.pop(\"v_hidden_size\", None)\n        v_rank = v_rank if v_rank is not None else rank\n        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size\n\n        out_rank = kwargs.pop(\"out_rank\", None)\n        out_hidden_size = kwargs.pop(\"out_hidden_size\", None)\n        out_rank = out_rank if out_rank is not None else rank\n        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size\n\n        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)\n        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)\n        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, 
network_alpha)\n\n    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:\n        self_cls_name = self.__class__.__name__\n        deprecate(\n            self_cls_name,\n            \"0.26.0\",\n            (\n                f\"Make sure use {self_cls_name[4:]} instead by setting\"\n                \"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using\"\n                \" `LoraLoaderMixin.load_lora_weights`\"\n            ),\n        )\n        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)\n        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)\n        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)\n        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)\n\n        attn._modules.pop(\"processor\")\n        attn.processor = XFormersAttnProcessor()\n        return attn.processor(attn, hidden_states, *args, **kwargs)\n\n\nclass LoRAAttnAddedKVProcessor(nn.Module):\n    r\"\"\"\n    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text\n    encoder.\n\n    Args:\n        hidden_size (`int`, *optional*):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`, *optional*, defaults to `None`):\n            The number of channels in the `encoder_hidden_states`.\n        rank (`int`, defaults to 4):\n            The dimension of the LoRA update matrices.\n        network_alpha (`int`, *optional*):\n            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.\n        kwargs (`dict`):\n            Additional keyword arguments to pass to the `LoRALinearLayer` layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        cross_attention_dim: Optional[int] = None,\n        rank: int = 4,\n        network_alpha: 
Optional[int] = None,\n    ):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.rank = rank\n\n        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)\n        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)\n\n    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:\n        self_cls_name = self.__class__.__name__\n        deprecate(\n            self_cls_name,\n            \"0.26.0\",\n            (\n                f\"Make sure use {self_cls_name[4:]} instead by setting\"\n                \"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. 
This will be done automatically when using\"\n                \" `LoraLoaderMixin.load_lora_weights`\"\n            ),\n        )\n        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)\n        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)\n        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)\n        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)\n\n        attn._modules.pop(\"processor\")\n        attn.processor = AttnAddedKVProcessor()\n        return attn.processor(attn, hidden_states, *args, **kwargs)\n\n\nclass IPAdapterAttnProcessor(nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapater.\n\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        num_tokens (`int`, defaults to 4):\n            The context length of the image features.\n        scale (`float`, defaults to 1.0):\n            the weight scale of image prompt.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.num_tokens = num_tokens\n        self.scale = scale\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        scale=1.0,\n    ):\n        if scale != 1.0:\n            logger.warning(\"`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.\")\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = 
attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        # split hidden states\n        end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n        encoder_hidden_states, ip_hidden_states = (\n            encoder_hidden_states[:, :end_pos, :],\n            encoder_hidden_states[:, end_pos:, :],\n        )\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # for ip-adapter\n        ip_key = self.to_k_ip(ip_hidden_states)\n        ip_value = self.to_v_ip(ip_hidden_states)\n\n        ip_key = attn.head_to_batch_dim(ip_key)\n        ip_value = attn.head_to_batch_dim(ip_value)\n\n        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)\n        ip_hidden_states = 
torch.bmm(ip_attention_probs, ip_value)\n        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)\n\n        hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass VPTemporalAdapterAttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter for PyTorch 2.0.\n\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):\n            The context length of the image features.\n        scale (`float` or `List[float]`, defaults to 1.0):\n            the weight scale of image prompt.\n    \"\"\"\n\n    \"\"\"\n    Support frame-wise VP-Adapter\n    encoder_hidden_states : I(num of ip_adapters), B, N * T(num of time condition), C\n    ip_adapter_masks(bool): (I, B, N * T, C) == encoder_hidden_states.shape\n\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):\n        super().__init__()\n\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\n                f\"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\"\n            )\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n\n        if not isinstance(num_tokens, (tuple, list)):\n            
num_tokens = [num_tokens]\n        self.num_tokens = num_tokens\n\n        if not isinstance(scale, list):\n            scale = [scale] * len(num_tokens)\n        if len(scale) != len(num_tokens):\n            raise ValueError(\"`scale` should be a list of integers with the same length as `num_tokens`.\")\n        self.scale = scale\n\n        self.to_k_ip = nn.ModuleList(\n            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]\n        )\n        self.to_v_ip = nn.ModuleList(\n            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]\n        )\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n        ip_adapter_masks: Optional[torch.FloatTensor] = None,\n        time_conditions: Optional[list] = None,\n        audio_length_in_s: Optional[int] = None,\n    ):\n        residual = hidden_states\n\n        # separate ip_hidden_states from encoder_hidden_states\n        if encoder_hidden_states is not None:\n            if isinstance(encoder_hidden_states, tuple):\n                encoder_hidden_states, ip_hidden_states = encoder_hidden_states\n            else:\n                deprecation_message = (\n                    \"You have passed a tensor as `encoder_hidden_states`. 
This is deprecated and will be removed in a future release.\"\n                    \" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning.\"\n                )\n                deprecate(\"encoder_hidden_states not a tuple\", \"1.0.0\", deprecation_message, standard_warn=False)\n                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]\n                encoder_hidden_states, ip_hidden_states = (\n                    encoder_hidden_states[:, :end_pos, :],\n                    [encoder_hidden_states[:, end_pos:, :]],\n                )\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = 
attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if ip_adapter_masks is not None:\n            if not isinstance(ip_adapter_masks, List):\n                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]\n                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))\n            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):\n                raise ValueError(\n                    f\"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match \"\n                    f\"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states \"\n                    f\"({len(ip_hidden_states)})\"\n                )\n            else:\n                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):\n                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:\n                        raise ValueError(\n                            \"Each element of the ip_adapter_masks array should be a tensor with shape \"\n                            \"[1, num_images_for_ip_adapter, height, width].\"\n                            \" 
Please use `IPAdapterMaskProcessor` to preprocess your mask\"\n                        )\n                    if mask.shape[1] != ip_state.shape[1]:\n                        raise ValueError(\n                            f\"Number of masks ({mask.shape[1]}) does not match \"\n                            f\"number of ip images ({ip_state.shape[1]}) at index {index}\"\n                        )\n                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:\n                        raise ValueError(\n                            f\"Number of masks ({mask.shape[1]}) does not match \"\n                            f\"number of scales ({len(scale)}) at index {index}\"\n                        )\n        else:\n            ip_adapter_masks = [None] * len(self.scale)\n        # for ip-adapter\n        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(\n            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks\n        ):\n            skip = False\n            if isinstance(scale, list):\n                if all(s == 0 for s in scale):\n                    skip = True\n            elif scale == 0:\n                skip = True\n            if not skip:\n                time_condition_masks = None\n                for time_condition in time_conditions:\n                    # hard code\n                    time_condition_mask = (\n                        torch.zeros(\n                            (\n                                batch_size,\n                                int(math.sqrt(hidden_states.shape[1]) // 2),\n                                int(2 * math.sqrt(hidden_states.shape[1])),\n                            )\n                        )\n                        .bool()\n                        .to(device=hidden_states.device)\n                    )\n                    mel_latent_length = time_condition_mask.shape[-1]\n                    time_start, time_end = (\n                        
int(time_condition[0] // audio_length_in_s * mel_latent_length),\n                        int(time_condition[1] // audio_length_in_s * mel_latent_length),\n                    )\n\n                    time_condition_mask[:, :, time_start:time_end] = True\n                    time_condition_mask = time_condition_mask.flatten(-2).unsqueeze(-1).repeat(1, 1, 4)\n                    if time_condition_masks is None:\n                        time_condition_masks = time_condition_mask\n                    else:\n                        time_condition_masks = torch.cat([time_condition_masks, time_condition_mask], dim=-1)\n\n                current_ip_hidden_states = rearrange(current_ip_hidden_states, \"L B N C -> B (L N) C\")\n                ip_key = to_k_ip(current_ip_hidden_states)\n                ip_value = to_v_ip(current_ip_hidden_states)\n\n                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n                time_condition_masks = time_condition_masks.unsqueeze(1).repeat(1, attn.heads, 1, 1)\n\n                # the output of sdp = (batch, num_heads, seq_len, head_dim)\n                # TODO: add support for attn.scale when we move to Torch 2.1\n                current_ip_hidden_states = F.scaled_dot_product_attention(\n                    query, ip_key, ip_value, attn_mask=time_condition_masks, dropout_p=0.0, is_causal=False\n                )\n\n                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(\n                    batch_size, -1, attn.heads * head_dim\n                )\n                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)\n\n                hidden_states = hidden_states + scale * current_ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = 
attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAdapterAttnProcessor2_0(torch.nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter for PyTorch 2.0.\n\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the `encoder_hidden_states`.\n        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):\n            The context length of the image features.\n        scale (`float` or `List[float]`, defaults to 1.0):\n            the weight scale of image prompt.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):\n        super().__init__()\n\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\n                f\"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\"\n            )\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n\n        if not isinstance(num_tokens, (tuple, list)):\n            num_tokens = [num_tokens]\n        self.num_tokens = num_tokens\n\n        if not isinstance(scale, list):\n            scale = [scale] * len(num_tokens)\n        if len(scale) != len(num_tokens):\n            raise ValueError(\"`scale` should be a list of integers with the same length as `num_tokens`.\")\n        self.scale = scale\n        self.to_k_ip = nn.ModuleList(\n            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]\n        )\n        self.to_v_ip = nn.ModuleList(\n            
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]\n        )\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        temb: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n        ip_adapter_masks: Optional[torch.FloatTensor] = None,\n    ):\n        residual = hidden_states\n\n        # separate ip_hidden_states from encoder_hidden_states\n        if encoder_hidden_states is not None:\n            if isinstance(encoder_hidden_states, tuple):\n                encoder_hidden_states, ip_hidden_states = encoder_hidden_states\n            else:\n                deprecation_message = (\n                    \"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release.\"\n                    \" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning.\"\n                )\n                deprecate(\"encoder_hidden_states not a tuple\", \"1.0.0\", deprecation_message, standard_warn=False)\n                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]\n                encoder_hidden_states, ip_hidden_states = (\n                    encoder_hidden_states[:, :end_pos, :],\n                    [encoder_hidden_states[:, end_pos:, :]],\n                )\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else 
encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if ip_adapter_masks is not None:\n            if not isinstance(ip_adapter_masks, List):\n                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]\n                ip_adapter_masks = 
list(ip_adapter_masks.unsqueeze(1))\n            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):\n                raise ValueError(\n                    f\"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match \"\n                    f\"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states \"\n                    f\"({len(ip_hidden_states)})\"\n                )\n            else:\n                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):\n                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:\n                        raise ValueError(\n                            \"Each element of the ip_adapter_masks array should be a tensor with shape \"\n                            \"[1, num_images_for_ip_adapter, height, width].\"\n                            \" Please use `IPAdapterMaskProcessor` to preprocess your mask\"\n                        )\n                    if mask.shape[1] != ip_state.shape[1]:\n                        raise ValueError(\n                            f\"Number of masks ({mask.shape[1]}) does not match \"\n                            f\"number of ip images ({ip_state.shape[1]}) at index {index}\"\n                        )\n                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:\n                        raise ValueError(\n                            f\"Number of masks ({mask.shape[1]}) does not match \"\n                            f\"number of scales ({len(scale)}) at index {index}\"\n                        )\n        else:\n            ip_adapter_masks = [None] * len(self.scale)\n\n        # for ip-adapter\n        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(\n            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks\n        ):\n            skip = False\n            if isinstance(scale, list):\n                if all(s 
== 0 for s in scale):\n                    skip = True\n            elif scale == 0:\n                skip = True\n            if not skip:\n                ip_key = to_k_ip(current_ip_hidden_states)\n                ip_value = to_v_ip(current_ip_hidden_states)\n\n                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n                # the output of sdp = (batch, num_heads, seq_len, head_dim)\n                # TODO: add support for attn.scale when we move to Torch 2.1\n                current_ip_hidden_states = F.scaled_dot_product_attention(\n                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False\n                )\n\n                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(\n                    batch_size, -1, attn.heads * head_dim\n                )\n                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)\n\n                hidden_states = hidden_states + scale * current_ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nLORA_ATTENTION_PROCESSORS = (\n    LoRAAttnProcessor,\n    LoRAAttnProcessor2_0,\n    LoRAXFormersAttnProcessor,\n    LoRAAttnAddedKVProcessor,\n)\n\nADDED_KV_ATTENTION_PROCESSORS = (\n    AttnAddedKVProcessor,\n    SlicedAttnAddedKVProcessor,\n    AttnAddedKVProcessor2_0,\n    XFormersAttnAddedKVProcessor,\n    
LoRAAttnAddedKVProcessor,\n)\n\nCROSS_ATTENTION_PROCESSORS = (\n    AttnProcessor,\n    AttnProcessor2_0,\n    XFormersAttnProcessor,\n    SlicedAttnProcessor,\n    LoRAAttnProcessor,\n    LoRAAttnProcessor2_0,\n    LoRAXFormersAttnProcessor,\n    IPAdapterAttnProcessor,\n    IPAdapterAttnProcessor2_0,\n)\n\nAttentionProcessor = Union[\n    AttnProcessor,\n    AttnProcessor2_0,\n    FusedAttnProcessor2_0,\n    XFormersAttnProcessor,\n    SlicedAttnProcessor,\n    AttnAddedKVProcessor,\n    SlicedAttnAddedKVProcessor,\n    AttnAddedKVProcessor2_0,\n    XFormersAttnAddedKVProcessor,\n    CustomDiffusionAttnProcessor,\n    CustomDiffusionXFormersAttnProcessor,\n    CustomDiffusionAttnProcessor2_0,\n    # deprecated\n    LoRAAttnProcessor,\n    LoRAAttnProcessor2_0,\n    LoRAXFormersAttnProcessor,\n    LoRAAttnAddedKVProcessor,\n]\n"
  },
  {
    "path": "foleycrafter/models/auffusion/dual_transformer_2d.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional\n\nfrom torch import nn\n\nfrom foleycrafter.models.auffusion.transformer_2d import Transformer2DModel, Transformer2DModelOutput\n\n\nclass DualTransformer2DModel(nn.Module):\n    \"\"\"\n    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            Pass if the input is continuous. The number of channels in the input and output.\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.\n        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.\n            Note that this is fixed at training time as it is used for learning a number of position embeddings. See\n            `ImagePositionalEmbeddings`.\n        num_vector_embeds (`int`, *optional*):\n            Pass if the input is discrete. 
The number of classes of the vector embeddings of the latent pixels.\n            Includes the class for the masked latent pixel.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.\n            The number of diffusion steps used during training. Note that this is fixed at training time as it is used\n            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for\n            up to but not more than steps than `num_embeds_ada_norm`.\n        attention_bias (`bool`, *optional*):\n            Configure if the TransformerBlocks' attention should contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        sample_size: Optional[int] = None,\n        num_vector_embeds: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n    ):\n        super().__init__()\n        self.transformers = nn.ModuleList(\n            [\n                Transformer2DModel(\n                    num_attention_heads=num_attention_heads,\n                    attention_head_dim=attention_head_dim,\n                    in_channels=in_channels,\n                    num_layers=num_layers,\n                    dropout=dropout,\n                    norm_num_groups=norm_num_groups,\n                    cross_attention_dim=cross_attention_dim,\n                    attention_bias=attention_bias,\n                    sample_size=sample_size,\n                    
num_vector_embeds=num_vector_embeds,\n                    activation_fn=activation_fn,\n                    num_embeds_ada_norm=num_embeds_ada_norm,\n                )\n                for _ in range(2)\n            ]\n        )\n\n        # Variables that can be set by a pipeline:\n\n        # The ratio of transformer1 to transformer2's output states to be combined during inference\n        self.mix_ratio = 0.5\n\n        # The shape of `encoder_hidden_states` is expected to be\n        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`\n        self.condition_lengths = [77, 257]\n\n        # Which transformer to use to encode which condition.\n        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`\n        self.transformer_index_for_condition = [1, 0]\n\n    def forward(\n        self,\n        hidden_states,\n        encoder_hidden_states,\n        timestep=None,\n        attention_mask=None,\n        cross_attention_kwargs=None,\n        return_dict: bool = True,\n    ):\n        \"\"\"\n        Args:\n            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.\n                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input\n                hidden_states.\n            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.long`, *optional*):\n                Optional timestep to be applied as an embedding in AdaLayerNorm's. 
Used to indicate denoising step.\n            attention_mask (`torch.FloatTensor`, *optional*):\n                Optional attention mask to be applied in Attention.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:\n            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When\n            returning a tuple, the first element is the sample tensor.\n        \"\"\"\n        input_states = hidden_states\n\n        encoded_states = []\n        tokens_start = 0\n        # attention_mask is not used yet\n        for i in range(2):\n            # for each of the two transformers, pass the corresponding condition tokens\n            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]\n            transformer_index = self.transformer_index_for_condition[i]\n            encoded_state = self.transformers[transformer_index](\n                input_states,\n                encoder_hidden_states=condition_state,\n                timestep=timestep,\n                cross_attention_kwargs=cross_attention_kwargs,\n                return_dict=False,\n            )[0]\n            encoded_states.append(encoded_state - input_states)\n            tokens_start += self.condition_lengths[i]\n\n        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)\n       
 output_states = output_states + input_states\n\n        if not return_dict:\n            return (output_states,)\n\n        return Transformer2DModelOutput(sample=output_states)\n"
  },
  {
    "path": "foleycrafter/models/auffusion/loaders/ip_adapter.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Union\n\nimport torch\nfrom huggingface_hub.utils import validate_hf_hub_args\nfrom safetensors import safe_open\n\nfrom diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT\nfrom diffusers.utils import (\n    _get_model_file,\n    is_accelerate_available,\n    is_torch_version,\n    is_transformers_available,\n    logging,\n)\n\n\nif is_transformers_available():\n    from transformers import (\n        CLIPImageProcessor,\n        CLIPVisionModelWithProjection,\n    )\n\n    from diffusers.models.attention_processor import (\n        IPAdapterAttnProcessor,\n    )\n\nfrom foleycrafter.models.auffusion.attention_processor import (\n    IPAdapterAttnProcessor2_0,\n    VPTemporalAdapterAttnProcessor2_0,\n)\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass IPAdapterMixin:\n    \"\"\"Mixin for handling IP Adapters.\"\"\"\n\n    @validate_hf_hub_args\n    def load_ip_adapter(\n        self,\n        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],\n        subfolder: Union[str, List[str]],\n        weight_name: Union[str, List[str]],\n        image_encoder_folder: Optional[str] = \"image_encoder\",\n        **kwargs,\n    ):\n        \"\"\"\n        Parameters:\n            pretrained_model_name_or_path_or_dict (`str` or 
`List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):\n                Can be either:\n\n                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved\n                      with [`ModelMixin.save_pretrained`].\n                    - A [torch state\n                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).\n            subfolder (`str` or `List[str]`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n                If a list is passed, it should have the same length as `weight_name`.\n            weight_name (`str` or `List[str]`):\n                The name of the weight file to load. If a list is passed, it should have the same length as\n                `subfolder`.\n            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):\n                The subfolder location of the image encoder within a larger model repository on the Hub or locally.\n                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,\n                you only need to pass the name of the folder that contains image encoder weights, e.g. 
`image_encoder_folder=\"image_encoder\"`.\n                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,\n                for example, `image_encoder_folder=\"different_subfolder/image_encoder\"`.\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. 
It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n        \"\"\"\n\n        # handle the list inputs for multiple IP Adapters\n        if not isinstance(weight_name, list):\n            weight_name = [weight_name]\n\n        if not isinstance(pretrained_model_name_or_path_or_dict, list):\n            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]\n        if len(pretrained_model_name_or_path_or_dict) == 1:\n            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)\n\n        if not isinstance(subfolder, list):\n            subfolder = [subfolder]\n        if len(subfolder) == 1:\n            subfolder = subfolder * len(weight_name)\n\n        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):\n            raise ValueError(\"`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.\")\n\n        if len(weight_name) != len(subfolder):\n            raise ValueError(\"`weight_name` and `subfolder` must have the same length.\")\n\n        # Load the main state dict first.\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        force_download = kwargs.pop(\"force_download\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", None)\n        token = 
kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        low_cpu_mem_usage = kwargs.pop(\"low_cpu_mem_usage\", _LOW_CPU_MEM_USAGE_DEFAULT)\n\n        if low_cpu_mem_usage and not is_accelerate_available():\n            low_cpu_mem_usage = False\n            logger.warning(\n                \"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the\"\n                \" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install\"\n                \" `accelerate` for faster and less memory-intense model loading. You can do so with: \\n```\\npip\"\n                \" install accelerate\\n```\\n.\"\n            )\n\n        if low_cpu_mem_usage is True and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set\"\n                \" `low_cpu_mem_usage=False`.\"\n            )\n\n        user_agent = {\n            \"file_type\": \"attn_procs_weights\",\n            \"framework\": \"pytorch\",\n        }\n        state_dicts = []\n        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(\n            pretrained_model_name_or_path_or_dict, weight_name, subfolder\n        ):\n            if not isinstance(pretrained_model_name_or_path_or_dict, dict):\n                model_file = _get_model_file(\n                    pretrained_model_name_or_path_or_dict,\n                    weights_name=weight_name,\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    token=token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n              
  )\n                if weight_name.endswith(\".safetensors\"):\n                    state_dict = {\"image_proj\": {}, \"ip_adapter\": {}}\n                    with safe_open(model_file, framework=\"pt\", device=\"cpu\") as f:\n                        for key in f.keys():\n                            if key.startswith(\"image_proj.\"):\n                                state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                            elif key.startswith(\"ip_adapter.\"):\n                                state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n                else:\n                    state_dict = torch.load(model_file, map_location=\"cpu\")\n            else:\n                state_dict = pretrained_model_name_or_path_or_dict\n\n            keys = list(state_dict.keys())\n            if keys != [\"image_proj\", \"ip_adapter\"]:\n                raise ValueError(\"Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.\")\n\n            state_dicts.append(state_dict)\n\n            # load CLIP image encoder here if it has not been registered to the pipeline yet\n            if hasattr(self, \"image_encoder\") and getattr(self, \"image_encoder\", None) is None:\n                if image_encoder_folder is not None:\n                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):\n                        logger.info(f\"loading image_encoder from {pretrained_model_name_or_path_or_dict}\")\n                        if image_encoder_folder.count(\"/\") == 0:\n                            image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()\n                        else:\n                            image_encoder_subfolder = Path(image_encoder_folder).as_posix()\n\n                        image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n                            pretrained_model_name_or_path_or_dict,\n     
                       subfolder=image_encoder_subfolder,\n                            low_cpu_mem_usage=low_cpu_mem_usage,\n                        ).to(self.device, dtype=self.dtype)\n                        self.register_modules(image_encoder=image_encoder)\n                    else:\n                        raise ValueError(\n                            \"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict.\"\n                        )\n                else:\n                    logger.warning(\n                        \"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter.\"\n                        \"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead.\"\n                    )\n\n            # create feature extractor if it has not been registered to the pipeline yet\n            if hasattr(self, \"feature_extractor\") and getattr(self, \"feature_extractor\", None) is None:\n                feature_extractor = CLIPImageProcessor()\n                self.register_modules(feature_extractor=feature_extractor)\n\n        # load ip-adapter into unet\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        unet._load_ip_adapter_weights(state_dicts)\n\n    def set_ip_adapter_scale(self, scale):\n        \"\"\"\n        Sets the conditioning scale between text and image.\n\n        Example:\n\n        ```py\n        pipeline.set_ip_adapter_scale(0.5)\n        ```\n        \"\"\"\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        for attn_processor in unet.attn_processors.values():\n            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):\n                if not isinstance(scale, list):\n                    scale = [scale] * len(attn_processor.scale)\n                
if len(attn_processor.scale) != len(scale):\n                    raise ValueError(\n                        f\"`scale` should be a list of same length as the number of ip-adapters \"\n                        f\"Expected {len(attn_processor.scale)} but got {len(scale)}.\"\n                    )\n                attn_processor.scale = scale\n\n    def unload_ip_adapter(self):\n        \"\"\"\n        Unloads the IP Adapter weights\n\n        Examples:\n\n        ```python\n        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.\n        >>> pipeline.unload_ip_adapter()\n        >>> ...\n        ```\n        \"\"\"\n        # remove CLIP image encoder\n        if hasattr(self, \"image_encoder\") and getattr(self, \"image_encoder\", None) is not None:\n            self.image_encoder = None\n            self.register_to_config(image_encoder=[None, None])\n\n        # remove feature extractor only when safety_checker is None as safety_checker uses\n        # the feature_extractor later\n        if not hasattr(self, \"safety_checker\"):\n            if hasattr(self, \"feature_extractor\") and getattr(self, \"feature_extractor\", None) is not None:\n                self.feature_extractor = None\n                self.register_to_config(feature_extractor=[None, None])\n\n        # remove hidden encoder\n        self.unet.encoder_hid_proj = None\n        self.config.encoder_hid_dim_type = None\n\n        # restore original Unet attention processors layers\n        self.unet.set_default_attn_processor()\n\n\nclass VPAdapterMixin:\n    \"\"\"Mixin for handling IP Adapters.\"\"\"\n\n    @validate_hf_hub_args\n    def load_ip_adapter(\n        self,\n        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],\n        subfolder: Union[str, List[str]],\n        weight_name: Union[str, List[str]],\n        image_encoder_folder: Optional[str] = \"image_encoder\",\n        **kwargs,\n    ):\n        \"\"\"\n        
Parameters:\n            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):\n                Can be either:\n\n                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved\n                      with [`ModelMixin.save_pretrained`].\n                    - A [torch state\n                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).\n            subfolder (`str` or `List[str]`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n                If a list is passed, it should have the same length as `weight_name`.\n            weight_name (`str` or `List[str]`):\n                The name of the weight file to load. If a list is passed, it should have the same length as\n                `subfolder`.\n            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):\n                The subfolder location of the image encoder within a larger model repository on the Hub or locally.\n                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,\n                you only need to pass the name of the folder that contains image encoder weights, e.g. 
`image_encoder_folder=\"image_encoder\"`.\n                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,\n                for example, `image_encoder_folder=\"different_subfolder/image_encoder\"`.\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. 
It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n        \"\"\"\n\n        # handle the list inputs for multiple IP Adapters\n        if not isinstance(weight_name, list):\n            weight_name = [weight_name]\n\n        if not isinstance(pretrained_model_name_or_path_or_dict, list):\n            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]\n        if len(pretrained_model_name_or_path_or_dict) == 1:\n            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)\n\n        if not isinstance(subfolder, list):\n            subfolder = [subfolder]\n        if len(subfolder) == 1:\n            subfolder = subfolder * len(weight_name)\n\n        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):\n            raise ValueError(\"`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.\")\n\n        if len(weight_name) != len(subfolder):\n            raise ValueError(\"`weight_name` and `subfolder` must have the same length.\")\n\n        # Load the main state dict first.\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        force_download = kwargs.pop(\"force_download\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", None)\n        token = 
kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        low_cpu_mem_usage = kwargs.pop(\"low_cpu_mem_usage\", _LOW_CPU_MEM_USAGE_DEFAULT)\n\n        if low_cpu_mem_usage and not is_accelerate_available():\n            low_cpu_mem_usage = False\n            logger.warning(\n                \"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the\"\n                \" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install\"\n                \" `accelerate` for faster and less memory-intense model loading. You can do so with: \\n```\\npip\"\n                \" install accelerate\\n```\\n.\"\n            )\n\n        if low_cpu_mem_usage is True and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set\"\n                \" `low_cpu_mem_usage=False`.\"\n            )\n\n        user_agent = {\n            \"file_type\": \"attn_procs_weights\",\n            \"framework\": \"pytorch\",\n        }\n        state_dicts = []\n        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(\n            pretrained_model_name_or_path_or_dict, weight_name, subfolder\n        ):\n            if not isinstance(pretrained_model_name_or_path_or_dict, dict):\n                model_file = _get_model_file(\n                    pretrained_model_name_or_path_or_dict,\n                    weights_name=weight_name,\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    token=token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n              
  )\n                if weight_name.endswith(\".safetensors\"):\n                    state_dict = {\"image_proj\": {}, \"ip_adapter\": {}}\n                    with safe_open(model_file, framework=\"pt\", device=\"cpu\") as f:\n                        for key in f.keys():\n                            if key.startswith(\"image_proj.\"):\n                                state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                            elif key.startswith(\"ip_adapter.\"):\n                                state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n                else:\n                    state_dict = torch.load(model_file, map_location=\"cpu\")\n            else:\n                state_dict = pretrained_model_name_or_path_or_dict\n\n            keys = list(state_dict.keys())\n            if keys != [\"image_proj\", \"ip_adapter\"]:\n                raise ValueError(\"Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.\")\n\n            state_dicts.append(state_dict)\n\n            # load CLIP image encoder here if it has not been registered to the pipeline yet\n            if hasattr(self, \"image_encoder\") and getattr(self, \"image_encoder\", None) is None:\n                if image_encoder_folder is not None:\n                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):\n                        logger.info(f\"loading image_encoder from {pretrained_model_name_or_path_or_dict}\")\n                        if image_encoder_folder.count(\"/\") == 0:\n                            image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()\n                        else:\n                            image_encoder_subfolder = Path(image_encoder_folder).as_posix()\n\n                        image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n                            pretrained_model_name_or_path_or_dict,\n     
                       subfolder=image_encoder_subfolder,\n                            low_cpu_mem_usage=low_cpu_mem_usage,\n                        ).to(self.device, dtype=self.dtype)\n                        self.register_modules(image_encoder=image_encoder)\n                    else:\n                        raise ValueError(\n                            \"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict.\"\n                        )\n                else:\n                    logger.warning(\n                        \"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter.\"\n                        \"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead.\"\n                    )\n\n            # create feature extractor if it has not been registered to the pipeline yet\n            if hasattr(self, \"feature_extractor\") and getattr(self, \"feature_extractor\", None) is None:\n                feature_extractor = CLIPImageProcessor()\n                self.register_modules(feature_extractor=feature_extractor)\n\n        # load ip-adapter into unet\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        unet._load_ip_adapter_weights_VPAdapter(state_dicts)\n\n    def set_ip_adapter_scale(self, scale):\n        \"\"\"\n        Sets the conditioning scale between text and image.\n\n        Example:\n\n        ```py\n        pipeline.set_ip_adapter_scale(0.5)\n        ```\n        \"\"\"\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        for attn_processor in unet.attn_processors.values():\n            if isinstance(attn_processor, (IPAdapterAttnProcessor, VPTemporalAdapterAttnProcessor2_0)):\n                if not isinstance(scale, list):\n                    scale = [scale] * 
len(attn_processor.scale)\n                if len(attn_processor.scale) != len(scale):\n                    raise ValueError(\n                        f\"`scale` should be a list of same length as the number of ip-adapters \"\n                        f\"Expected {len(attn_processor.scale)} but got {len(scale)}.\"\n                    )\n                attn_processor.scale = scale\n\n    def unload_ip_adapter(self):\n        \"\"\"\n        Unloads the IP Adapter weights\n\n        Examples:\n\n        ```python\n        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.\n        >>> pipeline.unload_ip_adapter()\n        >>> ...\n        ```\n        \"\"\"\n        # remove CLIP image encoder\n        if hasattr(self, \"image_encoder\") and getattr(self, \"image_encoder\", None) is not None:\n            self.image_encoder = None\n            self.register_to_config(image_encoder=[None, None])\n\n        # remove feature extractor only when safety_checker is None as safety_checker uses\n        # the feature_extractor later\n        if not hasattr(self, \"safety_checker\"):\n            if hasattr(self, \"feature_extractor\") and getattr(self, \"feature_extractor\", None) is not None:\n                self.feature_extractor = None\n                self.register_to_config(feature_extractor=[None, None])\n\n        # remove hidden encoder\n        self.unet.encoder_hid_proj = None\n        self.config.encoder_hid_dim_type = None\n\n        # restore original Unet attention processors layers\n        self.unet.set_default_attn_processor()\n"
  },
  {
    "path": "foleycrafter/models/auffusion/loaders/unet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom collections import defaultdict\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport safetensors\nimport torch\nimport torch.nn.functional as F\nfrom huggingface_hub.utils import validate_hf_hub_args\nfrom torch import nn\n\nfrom diffusers.loaders.utils import AttnProcsLayers\nfrom diffusers.models.embeddings import ImageProjection\nfrom diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    _get_model_file,\n    delete_adapter_layers,\n    is_accelerate_available,\n    is_torch_version,\n    logging,\n    set_adapter_layers,\n    set_weights_and_activate_adapters,\n)\nfrom foleycrafter.models.auffusion.attention_processor import (\n    AttnProcessor2_0,\n    IPAdapterAttnProcessor2_0,\n    VPTemporalAdapterAttnProcessor2_0,\n)\n\n\nif is_accelerate_available():\n    from accelerate import init_empty_weights\n    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module\n\nlogger = logging.get_logger(__name__)\n\n\nclass VPAdapterImageProjection(nn.Module):\n    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):\n        super().__init__()\n        
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)\n\n    def forward(self, image_embeds: List[torch.FloatTensor]):\n        projected_image_embeds = []\n\n        # currently, we accept `image_embeds` as\n        #  1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]\n        #  2. list of `n` tensors where `n` is number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]\n        if not isinstance(image_embeds, list):\n            # deprecation_message = (\n            #     \"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release.\"\n            #     \" Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning.\"\n            # )\n            image_embeds = [image_embeds.unsqueeze(1)]\n\n        if len(image_embeds) != len(self.image_projection_layers):\n            raise ValueError(\n                f\"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}\"\n            )\n\n        for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):\n            image_embed = image_embed.squeeze(1)\n            batch_size, num_images = image_embed.shape[0], image_embed.shape[1]\n            image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])\n            image_embed = image_projection_layer(image_embed)\n            image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])\n\n            projected_image_embeds.append(image_embed)\n\n        return projected_image_embeds\n\n\nclass MultiIPAdapterImageProjection(nn.Module):\n    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):\n        super().__init__()\n   
     self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)\n\n    def forward(self, image_embeds: List[torch.FloatTensor]):\n        projected_image_embeds = []\n\n        # currently, we accept `image_embeds` as\n        #  1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]\n        #  2. list of `n` tensors where `n` is number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]\n        if not isinstance(image_embeds, list):\n            # deprecation_message = (\n            #     \"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release.\"\n            #     \" Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning.\"\n            # )\n            image_embeds = [image_embeds.unsqueeze(1)]\n\n        if len(image_embeds) != len(self.image_projection_layers):\n            raise ValueError(\n                f\"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}\"\n            )\n\n        for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):\n            batch_size, num_images = image_embed.shape[0], image_embed.shape[1]\n            image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])\n            image_embed = image_projection_layer(image_embed)\n            image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])\n\n            projected_image_embeds.append(image_embed)\n\n        return projected_image_embeds\n\n\nTEXT_ENCODER_NAME = \"text_encoder\"\nUNET_NAME = \"unet\"\n\nLORA_WEIGHT_NAME = \"pytorch_lora_weights.bin\"\nLORA_WEIGHT_NAME_SAFE = \"pytorch_lora_weights.safetensors\"\n\nCUSTOM_DIFFUSION_WEIGHT_NAME = 
\"pytorch_custom_diffusion_weights.bin\"\nCUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = \"pytorch_custom_diffusion_weights.safetensors\"\n\n\nclass UNet2DConditionLoadersMixin:\n    \"\"\"\n    Load LoRA layers into a [`UNet2DConditionModel`].\n    \"\"\"\n\n    text_encoder_name = TEXT_ENCODER_NAME\n    unet_name = UNET_NAME\n\n    @validate_hf_hub_args\n    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):\n        r\"\"\"\n        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be\n        defined in\n        [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)\n        and be a `torch.nn.Module` class.\n\n        Parameters:\n            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):\n                Can be either:\n\n                    - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on\n                      the Hub.\n                    - A path to a directory (for example `./my_model_directory`) containing the model weights saved\n                      with [`ModelMixin.save_pretrained`].\n                    - A [torch state\n                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).\n\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            resume_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to resume 
downloading the model weights and configuration files. If set to `False`, any\n                incompletely downloaded files are deleted.\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            subfolder (`str`, *optional*, defaults to `\"\"`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you’re downloading a model in China. 
We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n\n        Example:\n\n        ```py\n        from diffusers import AutoPipelineForText2Image\n        import torch\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n        ).to(\"cuda\")\n        pipeline.unet.load_attn_procs(\n            \"jbilcke-hf/sdxl-cinematic-1\", weight_name=\"pytorch_lora_weights.safetensors\", adapter_name=\"cinematic\"\n        )\n        ```\n        \"\"\"\n        from diffusers.models.attention_processor import CustomDiffusionAttnProcessor\n        from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer\n\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        force_download = kwargs.pop(\"force_download\", False)\n        resume_download = kwargs.pop(\"resume_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", None)\n        token = kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n        weight_name = kwargs.pop(\"weight_name\", None)\n        use_safetensors = kwargs.pop(\"use_safetensors\", None)\n        low_cpu_mem_usage = kwargs.pop(\"low_cpu_mem_usage\", _LOW_CPU_MEM_USAGE_DEFAULT)\n        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.\n        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning\n        network_alphas = kwargs.pop(\"network_alphas\", None)\n\n        _pipeline = kwargs.pop(\"_pipeline\", None)\n\n        is_network_alphas_none = network_alphas is None\n\n        allow_pickle = False\n\n        if use_safetensors is None:\n         
   use_safetensors = True\n            allow_pickle = True\n\n        user_agent = {\n            \"file_type\": \"attn_procs_weights\",\n            \"framework\": \"pytorch\",\n        }\n\n        if low_cpu_mem_usage and not is_accelerate_available():\n            low_cpu_mem_usage = False\n            logger.warning(\n                \"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the\"\n                \" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install\"\n                \" `accelerate` for faster and less memory-intense model loading. You can do so with: \\n```\\npip\"\n                \" install accelerate\\n```\\n.\"\n            )\n\n        model_file = None\n        if not isinstance(pretrained_model_name_or_path_or_dict, dict):\n            # Let's first try to load .safetensors weights\n            if (use_safetensors and weight_name is None) or (\n                weight_name is not None and weight_name.endswith(\".safetensors\")\n            ):\n                try:\n                    model_file = _get_model_file(\n                        pretrained_model_name_or_path_or_dict,\n                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,\n                        cache_dir=cache_dir,\n                        force_download=force_download,\n                        resume_download=resume_download,\n                        proxies=proxies,\n                        local_files_only=local_files_only,\n                        token=token,\n                        revision=revision,\n                        subfolder=subfolder,\n                        user_agent=user_agent,\n                    )\n                    state_dict = safetensors.torch.load_file(model_file, device=\"cpu\")\n                except IOError as e:\n                    if not allow_pickle:\n                        raise e\n                    # try loading non-safetensors weights\n   
                 pass\n            if model_file is None:\n                model_file = _get_model_file(\n                    pretrained_model_name_or_path_or_dict,\n                    weights_name=weight_name or LORA_WEIGHT_NAME,\n                    cache_dir=cache_dir,\n                    force_download=force_download,\n                    resume_download=resume_download,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    token=token,\n                    revision=revision,\n                    subfolder=subfolder,\n                    user_agent=user_agent,\n                )\n                state_dict = torch.load(model_file, map_location=\"cpu\")\n        else:\n            state_dict = pretrained_model_name_or_path_or_dict\n\n        # fill attn processors\n        lora_layers_list = []\n\n        is_lora = all((\"lora\" in k or k.endswith(\".alpha\")) for k in state_dict.keys()) and not USE_PEFT_BACKEND\n        is_custom_diffusion = any(\"custom_diffusion\" in k for k in state_dict.keys())\n\n        if is_lora:\n            # correct keys\n            state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)\n\n            if network_alphas is not None:\n                network_alphas_keys = list(network_alphas.keys())\n                used_network_alphas_keys = set()\n\n            lora_grouped_dict = defaultdict(dict)\n            mapped_network_alphas = {}\n\n            all_keys = list(state_dict.keys())\n            for key in all_keys:\n                value = state_dict.pop(key)\n                attn_processor_key, sub_key = \".\".join(key.split(\".\")[:-3]), \".\".join(key.split(\".\")[-3:])\n                lora_grouped_dict[attn_processor_key][sub_key] = value\n\n                # Create another `mapped_network_alphas` dictionary so that we can properly map them.\n                if network_alphas is not None:\n                    for k 
in network_alphas_keys:\n                        if k.replace(\".alpha\", \"\") in key:\n                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})\n                            used_network_alphas_keys.add(k)\n\n            if not is_network_alphas_none:\n                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:\n                    raise ValueError(\n                        f\"The `network_alphas` has to be empty at this point but has the following keys \\n\\n {', '.join(network_alphas.keys())}\"\n                    )\n\n            if len(state_dict) > 0:\n                raise ValueError(\n                    f\"The `state_dict` has to be empty at this point but has the following keys \\n\\n {', '.join(state_dict.keys())}\"\n                )\n\n            for key, value_dict in lora_grouped_dict.items():\n                attn_processor = self\n                for sub_key in key.split(\".\"):\n                    attn_processor = getattr(attn_processor, sub_key)\n\n                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers\n                # or add_{k,v,q,out_proj}_proj_lora layers.\n                rank = value_dict[\"lora.down.weight\"].shape[0]\n\n                if isinstance(attn_processor, LoRACompatibleConv):\n                    in_features = attn_processor.in_channels\n                    out_features = attn_processor.out_channels\n                    kernel_size = attn_processor.kernel_size\n\n                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext\n                    with ctx():\n                        lora = LoRAConv2dLayer(\n                            in_features=in_features,\n                            out_features=out_features,\n                            rank=rank,\n                            kernel_size=kernel_size,\n                            stride=attn_processor.stride,\n                            
padding=attn_processor.padding,\n                            network_alpha=mapped_network_alphas.get(key),\n                        )\n                elif isinstance(attn_processor, LoRACompatibleLinear):\n                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext\n                    with ctx():\n                        lora = LoRALinearLayer(\n                            attn_processor.in_features,\n                            attn_processor.out_features,\n                            rank,\n                            mapped_network_alphas.get(key),\n                        )\n                else:\n                    raise ValueError(f\"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.\")\n\n                value_dict = {k.replace(\"lora.\", \"\"): v for k, v in value_dict.items()}\n                lora_layers_list.append((attn_processor, lora))\n\n                if low_cpu_mem_usage:\n                    device = next(iter(value_dict.values())).device\n                    dtype = next(iter(value_dict.values())).dtype\n                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)\n                else:\n                    lora.load_state_dict(value_dict)\n\n        elif is_custom_diffusion:\n            attn_processors = {}\n            custom_diffusion_grouped_dict = defaultdict(dict)\n            for key, value in state_dict.items():\n                if len(value) == 0:\n                    custom_diffusion_grouped_dict[key] = {}\n                else:\n                    if \"to_out\" in key:\n                        attn_processor_key, sub_key = \".\".join(key.split(\".\")[:-3]), \".\".join(key.split(\".\")[-3:])\n                    else:\n                        attn_processor_key, sub_key = \".\".join(key.split(\".\")[:-2]), \".\".join(key.split(\".\")[-2:])\n                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value\n\n            for key, 
value_dict in custom_diffusion_grouped_dict.items():\n                if len(value_dict) == 0:\n                    attn_processors[key] = CustomDiffusionAttnProcessor(\n                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None\n                    )\n                else:\n                    cross_attention_dim = value_dict[\"to_k_custom_diffusion.weight\"].shape[1]\n                    hidden_size = value_dict[\"to_k_custom_diffusion.weight\"].shape[0]\n                    train_q_out = True if \"to_q_custom_diffusion.weight\" in value_dict else False\n                    attn_processors[key] = CustomDiffusionAttnProcessor(\n                        train_kv=True,\n                        train_q_out=train_q_out,\n                        hidden_size=hidden_size,\n                        cross_attention_dim=cross_attention_dim,\n                    )\n                    attn_processors[key].load_state_dict(value_dict)\n        elif USE_PEFT_BACKEND:\n            # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`\n            # on the Unet\n            pass\n        else:\n            raise ValueError(\n                f\"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training.\"\n            )\n\n        # <Unsafe code\n        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype\n        # Now we remove any existing hooks to\n        is_model_cpu_offload = False\n        is_sequential_cpu_offload = False\n\n        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`\n        if not USE_PEFT_BACKEND:\n            if _pipeline is not None:\n                for _, component in _pipeline.components.items():\n                    if isinstance(component, nn.Module) and 
hasattr(component, \"_hf_hook\"):\n                        is_model_cpu_offload = isinstance(getattr(component, \"_hf_hook\"), CpuOffload)\n                        is_sequential_cpu_offload = isinstance(getattr(component, \"_hf_hook\"), AlignDevicesHook)\n\n                        logger.info(\n                            \"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again.\"\n                        )\n                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)\n\n            # only custom diffusion needs to set attn processors\n            if is_custom_diffusion:\n                self.set_attn_processor(attn_processors)\n\n            # set lora layers\n            for target_module, lora_layer in lora_layers_list:\n                target_module.set_lora_layer(lora_layer)\n\n            self.to(dtype=self.dtype, device=self.device)\n\n            # Offload back.\n            if is_model_cpu_offload:\n                _pipeline.enable_model_cpu_offload()\n            elif is_sequential_cpu_offload:\n                _pipeline.enable_sequential_cpu_offload()\n            # Unsafe code />\n\n    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):\n        is_new_lora_format = all(\n            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()\n        )\n        if is_new_lora_format:\n            # Strip the `\"unet\"` prefix.\n            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())\n            if is_text_encoder_present:\n                warn_message = \"The state_dict contains LoRA params corresponding to the text encoder which are not being used here. 
To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights).\"\n                logger.warn(warn_message)\n            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]\n            state_dict = {k.replace(f\"{self.unet_name}.\", \"\"): v for k, v in state_dict.items() if k in unet_keys}\n\n        # change processor format to 'pure' LoRACompatibleLinear format\n        if any(\"processor\" in k.split(\".\") for k in state_dict.keys()):\n\n            def format_to_lora_compatible(key):\n                if \"processor\" not in key.split(\".\"):\n                    return key\n                return key.replace(\".processor\", \"\").replace(\"to_out_lora\", \"to_out.0.lora\").replace(\"_lora\", \".lora\")\n\n            state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}\n\n            if network_alphas is not None:\n                network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}\n        return state_dict, network_alphas\n\n    def save_attn_procs(\n        self,\n        save_directory: Union[str, os.PathLike],\n        is_main_process: bool = True,\n        weight_name: str = None,\n        save_function: Callable = None,\n        safe_serialization: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Save attention processor layers to a directory so that it can be reloaded with the\n        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.\n\n        Arguments:\n            save_directory (`str` or `os.PathLike`):\n                Directory to save an attention processor to (will be created if it doesn't exist).\n            is_main_process (`bool`, *optional*, defaults to `True`):\n                Whether the process calling this is the main process or not. 
Useful during distributed training and you\n                need to call this function on all processes. In this case, set `is_main_process=True` only on the main\n                process to avoid race conditions.\n            save_function (`Callable`):\n                The function to use to save the state dictionary. Useful during distributed training when you need to\n                replace `torch.save` with another method. Can be configured with the environment variable\n                `DIFFUSERS_SAVE_MODE`.\n            safe_serialization (`bool`, *optional*, defaults to `True`):\n                Whether to save the model using `safetensors` or with `pickle`.\n\n        Example:\n\n        ```py\n        import torch\n        from diffusers import DiffusionPipeline\n\n        pipeline = DiffusionPipeline.from_pretrained(\n            \"CompVis/stable-diffusion-v1-4\",\n            torch_dtype=torch.float16,\n        ).to(\"cuda\")\n        pipeline.unet.load_attn_procs(\"path-to-save-model\", weight_name=\"pytorch_custom_diffusion_weights.bin\")\n        pipeline.unet.save_attn_procs(\"path-to-save-model\", weight_name=\"pytorch_custom_diffusion_weights.bin\")\n        ```\n        \"\"\"\n        from diffusers.models.attention_processor import (\n            CustomDiffusionAttnProcessor,\n            CustomDiffusionAttnProcessor2_0,\n            CustomDiffusionXFormersAttnProcessor,\n        )\n\n        if os.path.isfile(save_directory):\n            logger.error(f\"Provided path ({save_directory}) should be a directory, not a file\")\n            return\n\n        if save_function is None:\n            if safe_serialization:\n\n                def save_function(weights, filename):\n                    return safetensors.torch.save_file(weights, filename, metadata={\"format\": \"pt\"})\n\n            else:\n                save_function = torch.save\n\n        os.makedirs(save_directory, exist_ok=True)\n\n        is_custom_diffusion = any(\n            
isinstance(\n                x,\n                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),\n            )\n            for (_, x) in self.attn_processors.items()\n        )\n        if is_custom_diffusion:\n            model_to_save = AttnProcsLayers(\n                {\n                    y: x\n                    for (y, x) in self.attn_processors.items()\n                    if isinstance(\n                        x,\n                        (\n                            CustomDiffusionAttnProcessor,\n                            CustomDiffusionAttnProcessor2_0,\n                            CustomDiffusionXFormersAttnProcessor,\n                        ),\n                    )\n                }\n            )\n            state_dict = model_to_save.state_dict()\n            for name, attn in self.attn_processors.items():\n                if len(attn.state_dict()) == 0:\n                    state_dict[name] = {}\n        else:\n            model_to_save = AttnProcsLayers(self.attn_processors)\n            state_dict = model_to_save.state_dict()\n\n        if weight_name is None:\n            if safe_serialization:\n                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE\n            else:\n                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME\n\n        # Save the model\n        save_function(state_dict, os.path.join(save_directory, weight_name))\n        logger.info(f\"Model weights saved in {os.path.join(save_directory, weight_name)}\")\n\n    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):\n        self.lora_scale = lora_scale\n        self._safe_fusing = safe_fusing\n        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))\n\n    def _fuse_lora_apply(self, module, adapter_names=None):\n        if not USE_PEFT_BACKEND:\n            if 
hasattr(module, \"_fuse_lora\"):\n                module._fuse_lora(self.lora_scale, self._safe_fusing)\n\n            if adapter_names is not None:\n                raise ValueError(\n                    \"The `adapter_names` argument is not supported in your environment. Please switch\"\n                    \" to PEFT backend to use this argument by installing latest PEFT and transformers.\"\n                    \" `pip install -U peft transformers`\"\n                )\n\n    def unfuse_lora(self):\n        self.apply(self._unfuse_lora_apply)\n\n    def _unfuse_lora_apply(self, module):\n        if not USE_PEFT_BACKEND:\n            if hasattr(module, \"_unfuse_lora\"):\n                module._unfuse_lora()\n\n    def set_adapters(\n        self,\n        adapter_names: Union[List[str], str],\n        weights: Optional[Union[List[float], float]] = None,\n    ):\n        \"\"\"\n        Set the currently active adapters for use in the UNet.\n\n        Args:\n            adapter_names (`List[str]` or `str`):\n                The names of the adapters to use.\n            adapter_weights (`Union[List[float], float]`, *optional*):\n                The adapter(s) weights to use with the UNet. 
If `None`, the weights are set to `1.0` for all the\n                adapters.\n\n        Example:\n\n        ```py\n        from diffusers import AutoPipelineForText2Image\n        import torch\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n        ).to(\"cuda\")\n        pipeline.load_lora_weights(\n            \"jbilcke-hf/sdxl-cinematic-1\", weight_name=\"pytorch_lora_weights.safetensors\", adapter_name=\"cinematic\"\n        )\n        pipeline.load_lora_weights(\"nerijs/pixel-art-xl\", weight_name=\"pixel-art-xl.safetensors\", adapter_name=\"pixel\")\n        pipeline.set_adapters([\"cinematic\", \"pixel\"], adapter_weights=[0.5, 0.5])\n        ```\n        \"\"\"\n        if not USE_PEFT_BACKEND:\n            raise ValueError(\"PEFT backend is required for `set_adapters()`.\")\n\n        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names\n\n        if weights is None:\n            weights = [1.0] * len(adapter_names)\n        elif isinstance(weights, float):\n            weights = [weights] * len(adapter_names)\n\n        if len(adapter_names) != len(weights):\n            raise ValueError(\n                f\"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}.\"\n            )\n\n        set_weights_and_activate_adapters(self, adapter_names, weights)\n\n    def disable_lora(self):\n        \"\"\"\n        Disable the UNet's active LoRA layers.\n\n        Example:\n\n        ```py\n        from diffusers import AutoPipelineForText2Image\n        import torch\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n        ).to(\"cuda\")\n        pipeline.load_lora_weights(\n            \"jbilcke-hf/sdxl-cinematic-1\", weight_name=\"pytorch_lora_weights.safetensors\", 
adapter_name=\"cinematic\"\n        )\n        pipeline.disable_lora()\n        ```\n        \"\"\"\n        if not USE_PEFT_BACKEND:\n            raise ValueError(\"PEFT backend is required for this method.\")\n        set_adapter_layers(self, enabled=False)\n\n    def enable_lora(self):\n        \"\"\"\n        Enable the UNet's active LoRA layers.\n\n        Example:\n\n        ```py\n        from diffusers import AutoPipelineForText2Image\n        import torch\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n        ).to(\"cuda\")\n        pipeline.load_lora_weights(\n            \"jbilcke-hf/sdxl-cinematic-1\", weight_name=\"pytorch_lora_weights.safetensors\", adapter_name=\"cinematic\"\n        )\n        pipeline.enable_lora()\n        ```\n        \"\"\"\n        if not USE_PEFT_BACKEND:\n            raise ValueError(\"PEFT backend is required for this method.\")\n        set_adapter_layers(self, enabled=True)\n\n    def delete_adapters(self, adapter_names: Union[List[str], str]):\n        \"\"\"\n        Delete an adapter's LoRA layers from the UNet.\n\n        Args:\n            adapter_names (`Union[List[str], str]`):\n                The names (single string or list of strings) of the adapter to delete.\n\n        Example:\n\n        ```py\n        from diffusers import AutoPipelineForText2Image\n        import torch\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n        ).to(\"cuda\")\n        pipeline.load_lora_weights(\n            \"jbilcke-hf/sdxl-cinematic-1\", weight_name=\"pytorch_lora_weights.safetensors\", adapter_names=\"cinematic\"\n        )\n        pipeline.delete_adapters(\"cinematic\")\n        ```\n        \"\"\"\n        if not USE_PEFT_BACKEND:\n            raise ValueError(\"PEFT backend is required for this method.\")\n\n        
if isinstance(adapter_names, str):\n            adapter_names = [adapter_names]\n\n        for adapter_name in adapter_names:\n            delete_adapter_layers(self, adapter_name)\n\n            # Pop also the corresponding adapter from the config\n            if hasattr(self, \"peft_config\"):\n                self.peft_config.pop(adapter_name, None)\n\n    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):\n        if low_cpu_mem_usage:\n            if is_accelerate_available():\n                from accelerate import init_empty_weights\n\n            else:\n                low_cpu_mem_usage = False\n                logger.warning(\n                    \"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the\"\n                    \" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install\"\n                    \" `accelerate` for faster and less memory-intense model loading. You can do so with: \\n```\\npip\"\n                    \" install accelerate\\n```\\n.\"\n                )\n\n        if low_cpu_mem_usage is True and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Low memory initialization requires torch >= 1.9.0. 
Please either update your PyTorch version or set\"\n                \" `low_cpu_mem_usage=False`.\"\n            )\n\n        updated_state_dict = {}\n        image_projection = None\n        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext\n\n        if \"proj.weight\" in state_dict:\n            # IP-Adapter\n            num_image_text_embeds = 4\n            clip_embeddings_dim = state_dict[\"proj.weight\"].shape[-1]\n            cross_attention_dim = state_dict[\"proj.weight\"].shape[0] // num_image_text_embeds\n\n            with init_context():\n                image_projection = ImageProjection(\n                    cross_attention_dim=cross_attention_dim,\n                    image_embed_dim=clip_embeddings_dim,\n                    num_image_text_embeds=num_image_text_embeds,\n                )\n\n            for key, value in state_dict.items():\n                diffusers_name = key.replace(\"proj\", \"image_embeds\")\n                updated_state_dict[diffusers_name] = value\n\n        if not low_cpu_mem_usage:\n            image_projection.load_state_dict(updated_state_dict)\n        else:\n            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)\n\n        return image_projection\n\n    def _convert_ip_adapter_attn_to_diffusers_VPAdapter(self, state_dicts, low_cpu_mem_usage=False):\n        from diffusers.models.attention_processor import (\n            AttnProcessor,\n            IPAdapterAttnProcessor,\n        )\n\n        if low_cpu_mem_usage:\n            if is_accelerate_available():\n                from accelerate import init_empty_weights\n\n            else:\n                low_cpu_mem_usage = False\n                logger.warning(\n                    \"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the\"\n                    \" environment. Defaulting to `low_cpu_mem_usage=False`. 
It is strongly recommended to install\"\n                    \" `accelerate` for faster and less memory-intense model loading. You can do so with: \\n```\\npip\"\n                    \" install accelerate\\n```\\n.\"\n                )\n\n        if low_cpu_mem_usage is True and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set\"\n                \" `low_cpu_mem_usage=False`.\"\n            )\n\n        # set ip-adapter cross-attention processors & load state_dict\n        attn_procs = {}\n        key_id = 1\n        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext\n        for name in self.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else self.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = self.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(self.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = self.config.block_out_channels[block_id]\n\n            if cross_attention_dim is None or \"motion_modules\" in name or \"fuser\" in name:\n                attn_processor_class = (\n                    AttnProcessor2_0 if hasattr(F, \"scaled_dot_product_attention\") else AttnProcessor\n                )\n                attn_procs[name] = attn_processor_class()\n            else:\n                attn_processor_class = (\n                    VPTemporalAdapterAttnProcessor2_0\n                    if hasattr(F, \"scaled_dot_product_attention\")\n                    else IPAdapterAttnProcessor\n                )\n                
num_image_text_embeds = []\n                for state_dict in state_dicts:\n                    if \"proj.weight\" in state_dict[\"image_proj\"]:\n                        # IP-Adapter\n                        num_image_text_embeds += [4]\n                    elif \"proj.3.weight\" in state_dict[\"image_proj\"]:\n                        # IP-Adapter Full Face\n                        num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token\n                    else:\n                        # IP-Adapter Plus\n                        num_image_text_embeds += [state_dict[\"image_proj\"][\"latents\"].shape[1]]\n\n                with init_context():\n                    attn_procs[name] = attn_processor_class(\n                        hidden_size=hidden_size,\n                        cross_attention_dim=cross_attention_dim,\n                        scale=1.0,\n                        num_tokens=num_image_text_embeds,\n                    )\n\n                value_dict = {}\n                for i, state_dict in enumerate(state_dicts):\n                    value_dict.update({f\"to_k_ip.{i}.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_ip.weight\"]})\n                    value_dict.update({f\"to_v_ip.{i}.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_ip.weight\"]})\n\n                if not low_cpu_mem_usage:\n                    attn_procs[name].load_state_dict(value_dict)\n                else:\n                    device = next(iter(value_dict.values())).device\n                    dtype = next(iter(value_dict.values())).dtype\n                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)\n\n                key_id += 2\n\n        return attn_procs\n\n    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):\n        from diffusers.models.attention_processor import (\n            AttnProcessor,\n            IPAdapterAttnProcessor,\n        )\n\n        if 
low_cpu_mem_usage:\n            if is_accelerate_available():\n                from accelerate import init_empty_weights\n\n            else:\n                low_cpu_mem_usage = False\n                logger.warning(\n                    \"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the\"\n                    \" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install\"\n                    \" `accelerate` for faster and less memory-intense model loading. You can do so with: \\n```\\npip\"\n                    \" install accelerate\\n```\\n.\"\n                )\n\n        if low_cpu_mem_usage is True and not is_torch_version(\">=\", \"1.9.0\"):\n            raise NotImplementedError(\n                \"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set\"\n                \" `low_cpu_mem_usage=False`.\"\n            )\n\n        # set ip-adapter cross-attention processors & load state_dict\n        attn_procs = {}\n        key_id = 1\n        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext\n        for name in self.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else self.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = self.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(self.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = self.config.block_out_channels[block_id]\n\n            if cross_attention_dim is None or \"motion_modules\" in name or \"fuser\" in name:\n                attn_processor_class = (\n                    AttnProcessor2_0 if 
hasattr(F, \"scaled_dot_product_attention\") else AttnProcessor\n                )\n                attn_procs[name] = attn_processor_class()\n            else:\n                attn_processor_class = (\n                    IPAdapterAttnProcessor2_0 if hasattr(F, \"scaled_dot_product_attention\") else IPAdapterAttnProcessor\n                )\n                num_image_text_embeds = []\n                for state_dict in state_dicts:\n                    if \"proj.weight\" in state_dict[\"image_proj\"]:\n                        # IP-Adapter\n                        num_image_text_embeds += [4]\n                    elif \"proj.3.weight\" in state_dict[\"image_proj\"]:\n                        # IP-Adapter Full Face\n                        num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token\n                    else:\n                        # IP-Adapter Plus\n                        num_image_text_embeds += [state_dict[\"image_proj\"][\"latents\"].shape[1]]\n\n                with init_context():\n                    attn_procs[name] = attn_processor_class(\n                        hidden_size=hidden_size,\n                        cross_attention_dim=cross_attention_dim,\n                        scale=1.0,\n                        num_tokens=num_image_text_embeds,\n                    )\n\n                value_dict = {}\n                for i, state_dict in enumerate(state_dicts):\n                    value_dict.update({f\"to_k_ip.{i}.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_ip.weight\"]})\n                    value_dict.update({f\"to_v_ip.{i}.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_ip.weight\"]})\n\n                if not low_cpu_mem_usage:\n                    attn_procs[name].load_state_dict(value_dict)\n                else:\n                    device = next(iter(value_dict.values())).device\n                    dtype = next(iter(value_dict.values())).dtype\n                    load_model_dict_into_meta(attn_procs[name], 
value_dict, device=device, dtype=dtype)\n\n                key_id += 2\n\n        return attn_procs\n\n    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):\n        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)\n        self.set_attn_processor(attn_procs)\n\n        # convert IP-Adapter Image Projection layers to diffusers\n        image_projection_layers = []\n        for state_dict in state_dicts:\n            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(\n                state_dict[\"image_proj\"], low_cpu_mem_usage=low_cpu_mem_usage\n            )\n            image_projection_layers.append(image_projection_layer)\n\n        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)\n        self.config.encoder_hid_dim_type = \"ip_image_proj\"\n\n        self.to(dtype=self.dtype, device=self.device)\n\n    def _load_ip_adapter_weights_VPAdapter(self, state_dicts, low_cpu_mem_usage=False):\n        attn_procs = self._convert_ip_adapter_attn_to_diffusers_VPAdapter(\n            state_dicts, low_cpu_mem_usage=low_cpu_mem_usage\n        )\n        self.set_attn_processor(attn_procs)\n\n        # convert IP-Adapter Image Projection layers to diffusers\n        image_projection_layers = []\n        for state_dict in state_dicts:\n            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(\n                state_dict[\"image_proj\"], low_cpu_mem_usage=low_cpu_mem_usage\n            )\n            image_projection_layers.append(image_projection_layer)\n\n        self.encoder_hid_proj = VPAdapterImageProjection(image_projection_layers)\n        self.config.encoder_hid_dim_type = \"ip_image_proj\"\n\n        self.to(dtype=self.dtype, device=self.device)\n"
  },
  {
    "path": "foleycrafter/models/auffusion/resnet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom functools import partial\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.downsampling import (  # noqa\n    Downsample1D,\n    Downsample2D,\n    FirDownsample2D,\n    KDownsample2D,\n    downsample_2d,\n)\nfrom diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear\nfrom diffusers.models.normalization import AdaGroupNorm\nfrom diffusers.models.upsampling import (  # noqa\n    FirUpsample2D,\n    KUpsample2D,\n    Upsample1D,\n    Upsample2D,\n    upfirdn2d_native,\n    upsample_2d,\n)\nfrom diffusers.utils import USE_PEFT_BACKEND\nfrom foleycrafter.models.auffusion.attention_processor import SpatialNorm\n\n\nclass ResnetBlock2D(nn.Module):\n    r\"\"\"\n    A Resnet block.\n\n    Parameters:\n        in_channels (`int`): The number of channels in the input.\n        out_channels (`int`, *optional*, default to be `None`):\n            The number of output channels for the first conv2d layer. 
If None, same as `in_channels`.\n        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.\n        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.\n        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.\n        groups_out (`int`, *optional*, default to None):\n            The number of groups to use for the second normalization layer. if set to None, same as `groups`.\n        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.\n        non_linearity (`str`, *optional*, default to `\"swish\"`): the activation function to use.\n        time_embedding_norm (`str`, *optional*, default to `\"default\"` ): Time scale shift config.\n            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose \"scale_shift\" or\n            \"ada_group\" for a stronger conditioning with scale and shift.\n        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see\n            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].\n        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.\n        use_in_shortcut (`bool`, *optional*, default to `True`):\n            If `True`, add a 1x1 nn.conv2d layer for skip-connection.\n        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.\n        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.\n        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the\n            `conv_shortcut` output.\n        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.\n            If None, same as `out_channels`.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        in_channels: 
int,\n        out_channels: Optional[int] = None,\n        conv_shortcut: bool = False,\n        dropout: float = 0.0,\n        temb_channels: int = 512,\n        groups: int = 32,\n        groups_out: Optional[int] = None,\n        pre_norm: bool = True,\n        eps: float = 1e-6,\n        non_linearity: str = \"swish\",\n        skip_time_act: bool = False,\n        time_embedding_norm: str = \"default\",  # default, scale_shift, ada_group, spatial\n        kernel: Optional[torch.FloatTensor] = None,\n        output_scale_factor: float = 1.0,\n        use_in_shortcut: Optional[bool] = None,\n        up: bool = False,\n        down: bool = False,\n        conv_shortcut_bias: bool = True,\n        conv_2d_out_channels: Optional[int] = None,\n    ):\n        super().__init__()\n        self.pre_norm = pre_norm\n        self.pre_norm = True\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n        self.up = up\n        self.down = down\n        self.output_scale_factor = output_scale_factor\n        self.time_embedding_norm = time_embedding_norm\n        self.skip_time_act = skip_time_act\n\n        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear\n        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv\n\n        if groups_out is None:\n            groups_out = groups\n\n        if self.time_embedding_norm == \"ada_group\":\n            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)\n        elif self.time_embedding_norm == \"spatial\":\n            self.norm1 = SpatialNorm(in_channels, temb_channels)\n        else:\n            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)\n\n        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)\n\n        if temb_channels 
is not None:\n            if self.time_embedding_norm == \"default\":\n                self.time_emb_proj = linear_cls(temb_channels, out_channels)\n            elif self.time_embedding_norm == \"scale_shift\":\n                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)\n            elif self.time_embedding_norm == \"ada_group\" or self.time_embedding_norm == \"spatial\":\n                self.time_emb_proj = None\n            else:\n                raise ValueError(f\"unknown time_embedding_norm : {self.time_embedding_norm} \")\n        else:\n            self.time_emb_proj = None\n\n        if self.time_embedding_norm == \"ada_group\":\n            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)\n        elif self.time_embedding_norm == \"spatial\":\n            self.norm2 = SpatialNorm(out_channels, temb_channels)\n        else:\n            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)\n\n        self.dropout = torch.nn.Dropout(dropout)\n        conv_2d_out_channels = conv_2d_out_channels or out_channels\n        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)\n\n        self.nonlinearity = get_activation(non_linearity)\n\n        self.upsample = self.downsample = None\n        if self.up:\n            if kernel == \"fir\":\n                fir_kernel = (1, 3, 3, 1)\n                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)\n            elif kernel == \"sde_vp\":\n                self.upsample = partial(F.interpolate, scale_factor=2.0, mode=\"nearest\")\n            else:\n                self.upsample = Upsample2D(in_channels, use_conv=False)\n        elif self.down:\n            if kernel == \"fir\":\n                fir_kernel = (1, 3, 3, 1)\n                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)\n            elif kernel == \"sde_vp\":\n                
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)\n            else:\n                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name=\"op\")\n\n        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut\n\n        self.conv_shortcut = None\n        if self.use_in_shortcut:\n            self.conv_shortcut = conv_cls(\n                in_channels,\n                conv_2d_out_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=conv_shortcut_bias,\n            )\n\n    def forward(\n        self,\n        input_tensor: torch.FloatTensor,\n        temb: torch.FloatTensor,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        hidden_states = input_tensor\n\n        if self.time_embedding_norm == \"ada_group\" or self.time_embedding_norm == \"spatial\":\n            hidden_states = self.norm1(hidden_states, temb)\n        else:\n            hidden_states = self.norm1(hidden_states)\n\n        hidden_states = self.nonlinearity(hidden_states)\n\n        if self.upsample is not None:\n            # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984\n            if hidden_states.shape[0] >= 64:\n                input_tensor = input_tensor.contiguous()\n                hidden_states = hidden_states.contiguous()\n            input_tensor = (\n                self.upsample(input_tensor, scale=scale)\n                if isinstance(self.upsample, Upsample2D)\n                else self.upsample(input_tensor)\n            )\n            hidden_states = (\n                self.upsample(hidden_states, scale=scale)\n                if isinstance(self.upsample, Upsample2D)\n                else self.upsample(hidden_states)\n            )\n        elif self.downsample is not None:\n            input_tensor = (\n                self.downsample(input_tensor, scale=scale)\n                if isinstance(self.downsample, Downsample2D)\n                else self.downsample(input_tensor)\n            )\n            hidden_states = (\n                self.downsample(hidden_states, scale=scale)\n                if isinstance(self.downsample, Downsample2D)\n                else self.downsample(hidden_states)\n            )\n\n        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)\n\n        if self.time_emb_proj is not None:\n            if not self.skip_time_act:\n                temb = self.nonlinearity(temb)\n            temb = (\n                self.time_emb_proj(temb, scale)[:, :, None, None]\n                if not USE_PEFT_BACKEND\n                # NOTE: Maybe we can use different prompt in different time\n                else self.time_emb_proj(temb)[:, :, None, None]\n            )\n\n        if temb is not None and self.time_embedding_norm == \"default\":\n            hidden_states = hidden_states + temb\n\n        if self.time_embedding_norm == \"ada_group\" or self.time_embedding_norm == \"spatial\":\n            hidden_states = self.norm2(hidden_states, temb)\n        else:\n            hidden_states = 
self.norm2(hidden_states)\n\n        if temb is not None and self.time_embedding_norm == \"scale_shift\":\n            scale, shift = torch.chunk(temb, 2, dim=1)\n            hidden_states = hidden_states * (1 + scale) + shift\n\n        hidden_states = self.nonlinearity(hidden_states)\n\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            input_tensor = (\n                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)\n            )\n\n        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor\n\n        return output_tensor\n\n\n# unet_rl.py\ndef rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:\n    if len(tensor.shape) == 2:\n        return tensor[:, :, None]\n    if len(tensor.shape) == 3:\n        return tensor[:, :, None, :]\n    elif len(tensor.shape) == 4:\n        return tensor[:, :, 0, :]\n    else:\n        raise ValueError(f\"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.\")\n\n\nclass Conv1dBlock(nn.Module):\n    \"\"\"\n    Conv1d --> GroupNorm --> Mish\n\n    Parameters:\n        inp_channels (`int`): Number of input channels.\n        out_channels (`int`): Number of output channels.\n        kernel_size (`int` or `tuple`): Size of the convolving kernel.\n        n_groups (`int`, default `8`): Number of groups to separate the channels into.\n        activation (`str`, defaults to `mish`): Name of the activation function.\n    \"\"\"\n\n    def __init__(\n        self,\n        inp_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, Tuple[int, int]],\n        n_groups: int = 8,\n        activation: str = \"mish\",\n    ):\n        super().__init__()\n\n        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)\n        self.group_norm 
= nn.GroupNorm(n_groups, out_channels)\n        self.mish = get_activation(activation)\n\n    def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n        intermediate_repr = self.conv1d(inputs)\n        intermediate_repr = rearrange_dims(intermediate_repr)\n        intermediate_repr = self.group_norm(intermediate_repr)\n        intermediate_repr = rearrange_dims(intermediate_repr)\n        output = self.mish(intermediate_repr)\n        return output\n\n\n# unet_rl.py\nclass ResidualTemporalBlock1D(nn.Module):\n    \"\"\"\n    Residual 1D block with temporal convolutions.\n\n    Parameters:\n        inp_channels (`int`): Number of input channels.\n        out_channels (`int`): Number of output channels.\n        embed_dim (`int`): Embedding dimension.\n        kernel_size (`int` or `tuple`): Size of the convolving kernel.\n        activation (`str`, defaults `mish`): It is possible to choose the right activation function.\n    \"\"\"\n\n    def __init__(\n        self,\n        inp_channels: int,\n        out_channels: int,\n        embed_dim: int,\n        kernel_size: Union[int, Tuple[int, int]] = 5,\n        activation: str = \"mish\",\n    ):\n        super().__init__()\n        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)\n        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)\n\n        self.time_emb_act = get_activation(activation)\n        self.time_emb = nn.Linear(embed_dim, out_channels)\n\n        self.residual_conv = (\n            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()\n        )\n\n    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            inputs : [ batch_size x inp_channels x horizon ]\n            t : [ batch_size x embed_dim ]\n\n        returns:\n            out : [ batch_size x out_channels x horizon ]\n        \"\"\"\n        t = self.time_emb_act(t)\n        t = self.time_emb(t)\n 
       out = self.conv_in(inputs) + rearrange_dims(t)\n        out = self.conv_out(out)\n        return out + self.residual_conv(inputs)\n\n\nclass TemporalConvLayer(nn.Module):\n    \"\"\"\n    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:\n    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016\n\n    Parameters:\n        in_dim (`int`): Number of input channels.\n        out_dim (`int`): Number of output channels.\n        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: Optional[int] = None,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n    ):\n        super().__init__()\n        out_dim = out_dim or in_dim\n        self.in_dim = in_dim\n        self.out_dim = out_dim\n\n        # conv layers\n        self.conv1 = nn.Sequential(\n            nn.GroupNorm(norm_num_groups, in_dim),\n            nn.SiLU(),\n            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),\n        )\n        self.conv2 = nn.Sequential(\n            nn.GroupNorm(norm_num_groups, out_dim),\n            nn.SiLU(),\n            nn.Dropout(dropout),\n            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),\n        )\n        self.conv3 = nn.Sequential(\n            nn.GroupNorm(norm_num_groups, out_dim),\n            nn.SiLU(),\n            nn.Dropout(dropout),\n            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),\n        )\n        self.conv4 = nn.Sequential(\n            nn.GroupNorm(norm_num_groups, out_dim),\n            nn.SiLU(),\n            nn.Dropout(dropout),\n            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),\n        )\n\n        # zero out the last layer params,so the conv block is identity\n        
nn.init.zeros_(self.conv4[-1].weight)\n        nn.init.zeros_(self.conv4[-1].bias)\n\n    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:\n        hidden_states = (\n            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)\n        )\n\n        identity = hidden_states\n        hidden_states = self.conv1(hidden_states)\n        hidden_states = self.conv2(hidden_states)\n        hidden_states = self.conv3(hidden_states)\n        hidden_states = self.conv4(hidden_states)\n\n        hidden_states = identity + hidden_states\n\n        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(\n            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]\n        )\n        return hidden_states\n\n\nclass TemporalResnetBlock(nn.Module):\n    r\"\"\"\n    A Resnet block.\n\n    Parameters:\n        in_channels (`int`): The number of channels in the input.\n        out_channels (`int`, *optional*, default to be `None`):\n            The number of output channels for the first conv2d layer. 
If None, same as `in_channels`.\n        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.\n        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        temb_channels: int = 512,\n        eps: float = 1e-6,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n\n        kernel_size = (3, 1, 1)\n        padding = [k // 2 for k in kernel_size]\n\n        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)\n        self.conv1 = nn.Conv3d(\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            padding=padding,\n        )\n\n        if temb_channels is not None:\n            self.time_emb_proj = nn.Linear(temb_channels, out_channels)\n        else:\n            self.time_emb_proj = None\n\n        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)\n\n        self.dropout = torch.nn.Dropout(0.0)\n        self.conv2 = nn.Conv3d(\n            out_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            padding=padding,\n        )\n\n        self.nonlinearity = get_activation(\"silu\")\n\n        self.use_in_shortcut = self.in_channels != out_channels\n\n        self.conv_shortcut = None\n        if self.use_in_shortcut:\n            self.conv_shortcut = nn.Conv3d(\n                in_channels,\n                out_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            )\n\n    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> 
torch.FloatTensor:\n        hidden_states = input_tensor\n\n        hidden_states = self.norm1(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.conv1(hidden_states)\n\n        if self.time_emb_proj is not None:\n            temb = self.nonlinearity(temb)\n            temb = self.time_emb_proj(temb)[:, :, :, None, None]\n            temb = temb.permute(0, 2, 1, 3, 4)\n            hidden_states = hidden_states + temb\n\n        hidden_states = self.norm2(hidden_states)\n        hidden_states = self.nonlinearity(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            input_tensor = self.conv_shortcut(input_tensor)\n\n        output_tensor = input_tensor + hidden_states\n\n        return output_tensor\n\n\n# VideoResBlock\nclass SpatioTemporalResBlock(nn.Module):\n    r\"\"\"\n    A SpatioTemporal Resnet block.\n\n    Parameters:\n        in_channels (`int`): The number of channels in the input.\n        out_channels (`int`, *optional*, default to be `None`):\n            The number of output channels for the first conv2d layer. 
If None, same as `in_channels`.\n        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.\n        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.\n        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.\n        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.\n        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):\n            The merge strategy to use for the temporal mixing.\n        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):\n            If `True`, switch the spatial and temporal mixing.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: Optional[int] = None,\n        temb_channels: int = 512,\n        eps: float = 1e-6,\n        temporal_eps: Optional[float] = None,\n        merge_factor: float = 0.5,\n        merge_strategy=\"learned_with_images\",\n        switch_spatial_to_temporal_mix: bool = False,\n    ):\n        super().__init__()\n\n        self.spatial_res_block = ResnetBlock2D(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            eps=eps,\n        )\n\n        self.temporal_res_block = TemporalResnetBlock(\n            in_channels=out_channels if out_channels is not None else in_channels,\n            out_channels=out_channels if out_channels is not None else in_channels,\n            temb_channels=temb_channels,\n            eps=temporal_eps if temporal_eps is not None else eps,\n        )\n\n        self.time_mixer = AlphaBlender(\n            alpha=merge_factor,\n            merge_strategy=merge_strategy,\n            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n    
    temb: Optional[torch.FloatTensor] = None,\n        image_only_indicator: Optional[torch.Tensor] = None,\n    ):\n        num_frames = image_only_indicator.shape[-1]\n        hidden_states = self.spatial_res_block(hidden_states, temb)\n\n        batch_frames, channels, height, width = hidden_states.shape\n        batch_size = batch_frames // num_frames\n\n        hidden_states_mix = (\n            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)\n        )\n        hidden_states = (\n            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)\n        )\n\n        if temb is not None:\n            temb = temb.reshape(batch_size, num_frames, -1)\n\n        hidden_states = self.temporal_res_block(hidden_states, temb)\n        hidden_states = self.time_mixer(\n            x_spatial=hidden_states_mix,\n            x_temporal=hidden_states,\n            image_only_indicator=image_only_indicator,\n        )\n\n        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)\n        return hidden_states\n\n\nclass AlphaBlender(nn.Module):\n    r\"\"\"\n    A module to blend spatial and temporal features.\n\n    Parameters:\n        alpha (`float`): The initial value of the blending factor.\n        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):\n            The merge strategy to use for the temporal mixing.\n        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):\n            If `True`, switch the spatial and temporal mixing.\n    \"\"\"\n\n    strategies = [\"learned\", \"fixed\", \"learned_with_images\"]\n\n    def __init__(\n        self,\n        alpha: float,\n        merge_strategy: str = \"learned_with_images\",\n        switch_spatial_to_temporal_mix: bool = False,\n    ):\n        super().__init__()\n        self.merge_strategy = merge_strategy\n        
self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE\n\n        if merge_strategy not in self.strategies:\n            raise ValueError(f\"merge_strategy needs to be in {self.strategies}\")\n\n        if self.merge_strategy == \"fixed\":\n            self.register_buffer(\"mix_factor\", torch.Tensor([alpha]))\n        elif self.merge_strategy == \"learned\" or self.merge_strategy == \"learned_with_images\":\n            self.register_parameter(\"mix_factor\", torch.nn.Parameter(torch.Tensor([alpha])))\n        else:\n            raise ValueError(f\"Unknown merge strategy {self.merge_strategy}\")\n\n    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:\n        if self.merge_strategy == \"fixed\":\n            alpha = self.mix_factor\n\n        elif self.merge_strategy == \"learned\":\n            alpha = torch.sigmoid(self.mix_factor)\n\n        elif self.merge_strategy == \"learned_with_images\":\n            if image_only_indicator is None:\n                raise ValueError(\"Please provide image_only_indicator to use learned_with_images merge strategy\")\n\n            alpha = torch.where(\n                image_only_indicator.bool(),\n                torch.ones(1, 1, device=image_only_indicator.device),\n                torch.sigmoid(self.mix_factor)[..., None],\n            )\n\n            # (batch, channel, frames, height, width)\n            if ndims == 5:\n                alpha = alpha[:, None, :, None, None]\n            # (batch*frames, height*width, channels)\n            elif ndims == 3:\n                alpha = alpha.reshape(-1)[:, None, None]\n            else:\n                raise ValueError(f\"Unexpected ndims {ndims}. 
Dimensions should be 3 or 5\")\n\n        else:\n            raise NotImplementedError\n\n        return alpha\n\n    def forward(\n        self,\n        x_spatial: torch.Tensor,\n        x_temporal: torch.Tensor,\n        image_only_indicator: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)\n        alpha = alpha.to(x_spatial.dtype)\n\n        if self.switch_spatial_to_temporal_mix:\n            alpha = 1.0 - alpha\n\n        x = alpha * x_spatial + (1.0 - alpha) * x_temporal\n        return x\n"
  },
  {
    "path": "foleycrafter/models/auffusion/transformer_2d.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection\nfrom diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.normalization import AdaLayerNormSingle\nfrom diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version\nfrom foleycrafter.models.auffusion.attention import BasicTransformerBlock\n\n\n@dataclass\nclass Transformer2DModelOutput(BaseOutput):\n    \"\"\"\n    The output of [`Transformer2DModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):\n            The hidden states output conditioned on the `encoder_hidden_states` input. 
If discrete, returns probability\n            distributions for the unnoised latent pixels.\n    \"\"\"\n\n    sample: torch.FloatTensor\n\n\nclass Transformer2DModel(ModelMixin, ConfigMixin):\n    \"\"\"\n    A 2D Transformer model for image-like data.\n\n    Parameters:\n        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.\n        in_channels (`int`, *optional*):\n            The number of channels in the input and output (specify if the input is **continuous**).\n        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.\n        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).\n            This is fixed during training since it is used to learn a number of position embeddings.\n        num_vector_embeds (`int`, *optional*):\n            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).\n            Includes the class for the masked latent pixel.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to use in feed-forward.\n        num_embeds_ada_norm ( `int`, *optional*):\n            The number of diffusion steps used during training. Pass if at least one of the norm_layers is\n            `AdaLayerNorm`. 
This is fixed during training since it is used to learn a number of embeddings that are\n            added to the hidden states.\n\n            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.\n        attention_bias (`bool`, *optional*):\n            Configure if the `TransformerBlocks` attention should contain a bias parameter.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        out_channels: Optional[int] = None,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        sample_size: Optional[int] = None,\n        num_vector_embeds: Optional[int] = None,\n        patch_size: Optional[int] = None,\n        activation_fn: str = \"geglu\",\n        num_embeds_ada_norm: Optional[int] = None,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        double_self_attention: bool = False,\n        upcast_attention: bool = False,\n        norm_type: str = \"layer_norm\",\n        norm_elementwise_affine: bool = True,\n        norm_eps: float = 1e-5,\n        attention_type: str = \"default\",\n        caption_channels: int = None,\n    ):\n        super().__init__()\n        self.use_linear_projection = use_linear_projection\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n        inner_dim = num_attention_heads * attention_head_dim\n\n        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv\n        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear\n\n        # 1. 
Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`\n        # Define whether input is continuous or discrete depending on configuration\n        self.is_input_continuous = (in_channels is not None) and (patch_size is None)\n        self.is_input_vectorized = num_vector_embeds is not None\n        self.is_input_patches = in_channels is not None and patch_size is not None\n\n        if norm_type == \"layer_norm\" and num_embeds_ada_norm is not None:\n            deprecation_message = (\n                f\"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or\"\n                \" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config.\"\n                \" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect\"\n                \" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it\"\n                \" would be very nice if you could open a Pull request for the `transformer/config.json` file\"\n            )\n            deprecate(\"norm_type!=num_embeds_ada_norm\", \"1.0.0\", deprecation_message, standard_warn=False)\n            norm_type = \"ada_norm\"\n\n        if self.is_input_continuous and self.is_input_vectorized:\n            raise ValueError(\n                f\"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make\"\n                \" sure that either `in_channels` or `num_vector_embeds` is None.\"\n            )\n        elif self.is_input_vectorized and self.is_input_patches:\n            raise ValueError(\n                f\"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. 
Make\"\n                \" sure that either `num_vector_embeds` or `num_patches` is None.\"\n            )\n        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:\n            raise ValueError(\n                f\"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:\"\n                f\" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None.\"\n            )\n\n        # 2. Define input layers\n        if self.is_input_continuous:\n            self.in_channels = in_channels\n\n            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n            if use_linear_projection:\n                self.proj_in = linear_cls(in_channels, inner_dim)\n            else:\n                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)\n        elif self.is_input_vectorized:\n            assert sample_size is not None, \"Transformer2DModel over discrete input must provide sample_size\"\n            assert num_vector_embeds is not None, \"Transformer2DModel over discrete input must provide num_embed\"\n\n            self.height = sample_size\n            self.width = sample_size\n            self.num_vector_embeds = num_vector_embeds\n            self.num_latent_pixels = self.height * self.width\n\n            self.latent_image_embedding = ImagePositionalEmbeddings(\n                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width\n            )\n        elif self.is_input_patches:\n            assert sample_size is not None, \"Transformer2DModel over patched input must provide sample_size\"\n\n            self.height = sample_size\n            self.width = sample_size\n\n            self.patch_size = patch_size\n            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has 
interpolation scale 1\n            interpolation_scale = max(interpolation_scale, 1)\n            self.pos_embed = PatchEmbed(\n                height=sample_size,\n                width=sample_size,\n                patch_size=patch_size,\n                in_channels=in_channels,\n                embed_dim=inner_dim,\n                interpolation_scale=interpolation_scale,\n            )\n\n        # 3. Define transformers blocks\n        self.transformer_blocks = nn.ModuleList(\n            [\n                # NOTE: remember to change\n                BasicTransformerBlock(\n                    inner_dim,\n                    num_attention_heads,\n                    attention_head_dim,\n                    dropout=dropout,\n                    cross_attention_dim=cross_attention_dim,\n                    activation_fn=activation_fn,\n                    num_embeds_ada_norm=num_embeds_ada_norm,\n                    attention_bias=attention_bias,\n                    only_cross_attention=only_cross_attention,\n                    double_self_attention=double_self_attention,\n                    upcast_attention=upcast_attention,\n                    norm_type=norm_type,\n                    norm_elementwise_affine=norm_elementwise_affine,\n                    norm_eps=norm_eps,\n                    attention_type=attention_type,\n                )\n                for d in range(num_layers)\n            ]\n        )\n\n        # 4. 
Define output layers\n        self.out_channels = in_channels if out_channels is None else out_channels\n        if self.is_input_continuous:\n            # TODO: should use out_channels for continuous projections\n            if use_linear_projection:\n                self.proj_out = linear_cls(inner_dim, in_channels)\n            else:\n                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)\n        elif self.is_input_vectorized:\n            self.norm_out = nn.LayerNorm(inner_dim)\n            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)\n        elif self.is_input_patches and norm_type != \"ada_norm_single\":\n            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)\n            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)\n            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)\n        elif self.is_input_patches and norm_type == \"ada_norm_single\":\n            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)\n            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)\n            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)\n\n        # 5. 
PixArt-Alpha blocks.\n        self.adaln_single = None\n        self.use_additional_conditions = False\n        if norm_type == \"ada_norm_single\":\n            self.use_additional_conditions = self.config.sample_size == 128\n            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use\n            # additional conditions until we find better name\n            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)\n\n        self.caption_projection = None\n        if caption_channels is not None:\n            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)\n\n        self.gradient_checkpointing = False\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        added_cond_kwargs: Dict[str, torch.Tensor] = None,\n        class_labels: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ):\n        \"\"\"\n        The [`Transformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):\n                Input `hidden_states`.\n            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):\n                Conditional embeddings for cross attention layer. 
If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.LongTensor`, *optional*):\n                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.\n            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):\n                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in\n                `AdaLayerZeroNorm`.\n            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            attention_mask ( `torch.Tensor`, *optional*):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            encoder_attention_mask ( `torch.Tensor`, *optional*):\n                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:\n\n                    * Mask `(batch, sequence_length)` True = keep, False = discard.\n                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.\n\n                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format\n                above. 
This bias will be added to the cross-attention scores.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a\n            `tuple` where the first element is the sample tensor.\n        \"\"\"\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.\n        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.\n        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn)\n        if attention_mask is not None and attention_mask.ndim == 2:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # Retrieve lora scale.\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n\n        # 1. Input\n        if self.is_input_continuous:\n            batch, _, height, width = hidden_states.shape\n            inner_dim = hidden_states.shape[1]\n            residual = hidden_states\n\n            hidden_states = self.norm(hidden_states)\n            if not self.use_linear_projection:\n                hidden_states = (\n                    self.proj_in(hidden_states, scale=lora_scale)\n                    if not USE_PEFT_BACKEND\n                    else self.proj_in(hidden_states)\n                )\n                inner_dim = hidden_states.shape[1]\n                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)\n            else:\n                inner_dim = hidden_states.shape[1]\n                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)\n                hidden_states = (\n                    self.proj_in(hidden_states, scale=lora_scale)\n                    if not USE_PEFT_BACKEND\n             
       else self.proj_in(hidden_states)\n                )\n\n        elif self.is_input_vectorized:\n            hidden_states = self.latent_image_embedding(hidden_states)\n        elif self.is_input_patches:\n            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size\n            self.height, self.width = height, width\n            hidden_states = self.pos_embed(hidden_states)\n\n            if self.adaln_single is not None:\n                if self.use_additional_conditions and added_cond_kwargs is None:\n                    raise ValueError(\n                        \"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.\"\n                    )\n                batch_size = hidden_states.shape[0]\n                timestep, embedded_timestep = self.adaln_single(\n                    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype\n                )\n\n        if self.caption_projection is not None:\n            batch_size = hidden_states.shape[0]\n            encoder_hidden_states = self.caption_projection(encoder_hidden_states)\n            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])\n        # 2. 
Blocks\n        for block in self.transformer_blocks:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    hidden_states,\n                    attention_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    timestep,\n                    cross_attention_kwargs,\n                    class_labels,\n                    **ckpt_kwargs,\n                )\n            else:\n                hidden_states = block(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    timestep=timestep,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    class_labels=class_labels,\n                )\n\n        # 3. 
Output\n        if self.is_input_continuous:\n            if not self.use_linear_projection:\n                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n                hidden_states = (\n                    self.proj_out(hidden_states, scale=lora_scale)\n                    if not USE_PEFT_BACKEND\n                    else self.proj_out(hidden_states)\n                )\n            else:\n                hidden_states = (\n                    self.proj_out(hidden_states, scale=lora_scale)\n                    if not USE_PEFT_BACKEND\n                    else self.proj_out(hidden_states)\n                )\n                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n\n            output = hidden_states + residual\n        elif self.is_input_vectorized:\n            hidden_states = self.norm_out(hidden_states)\n            logits = self.out(hidden_states)\n            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)\n            logits = logits.permute(0, 2, 1)\n\n            # log(p(x_0))\n            output = F.log_softmax(logits.double(), dim=1).float()\n\n        if self.is_input_patches:\n            if self.config.norm_type != \"ada_norm_single\":\n                conditioning = self.transformer_blocks[0].norm1.emb(\n                    timestep, class_labels, hidden_dtype=hidden_states.dtype\n                )\n                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)\n                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]\n                hidden_states = self.proj_out_2(hidden_states)\n            elif self.config.norm_type == \"ada_norm_single\":\n                shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)\n                hidden_states = self.norm_out(hidden_states)\n                # Modulation\n        
        hidden_states = hidden_states * (1 + scale) + shift\n                hidden_states = self.proj_out(hidden_states)\n                hidden_states = hidden_states.squeeze(1)\n\n            # unpatchify\n            if self.adaln_single is None:\n                height = width = int(hidden_states.shape[1] ** 0.5)\n            hidden_states = hidden_states.reshape(\n                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)\n            )\n            hidden_states = torch.einsum(\"nhwpqc->nchpwq\", hidden_states)\n            output = hidden_states.reshape(\n                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)\n            )\n\n        if not return_dict:\n            return (output,)\n\n        return Transformer2DModelOutput(sample=output)\n"
  },
  {
    "path": "foleycrafter/models/auffusion/unet_2d_blocks.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.normalization import AdaGroupNorm\nfrom diffusers.utils import is_torch_version, logging\nfrom diffusers.utils.torch_utils import apply_freeu\nfrom foleycrafter.models.auffusion.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0\nfrom foleycrafter.models.auffusion.dual_transformer_2d import DualTransformer2DModel\nfrom foleycrafter.models.auffusion.resnet import (\n    Downsample2D,\n    FirDownsample2D,\n    FirUpsample2D,\n    KDownsample2D,\n    KUpsample2D,\n    ResnetBlock2D,\n    Upsample2D,\n)\nfrom foleycrafter.models.auffusion.transformer_2d import Transformer2DModel\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef get_down_block(\n    down_block_type: str,\n    num_layers: int,\n    in_channels: int,\n    out_channels: int,\n    temb_channels: int,\n    add_downsample: bool,\n    resnet_eps: float,\n    resnet_act_fn: str,\n    transformer_layers_per_block: int = 1,\n    num_attention_heads: Optional[int] = None,\n    resnet_groups: Optional[int] = None,\n    cross_attention_dim: Optional[int] = None,\n    downsample_padding: Optional[int] 
= None,\n    dual_cross_attention: bool = False,\n    use_linear_projection: bool = False,\n    only_cross_attention: bool = False,\n    upcast_attention: bool = False,\n    resnet_time_scale_shift: str = \"default\",\n    attention_type: str = \"default\",\n    resnet_skip_time_act: bool = False,\n    resnet_out_scale_factor: float = 1.0,\n    cross_attention_norm: Optional[str] = None,\n    attention_head_dim: Optional[int] = None,\n    downsample_type: Optional[str] = None,\n    dropout: float = 0.0,\n):\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warn(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    down_block_type = down_block_type[7:] if down_block_type.startswith(\"UNetRes\") else down_block_type\n    if down_block_type == \"DownBlock2D\":\n        return DownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"ResnetDownsampleBlock2D\":\n        return ResnetDownsampleBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            
resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n        )\n    elif down_block_type == \"AttnDownBlock2D\":\n        if add_downsample is False:\n            downsample_type = None\n        else:\n            downsample_type = downsample_type or \"conv\"  # default to 'conv'\n        return AttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            downsample_type=downsample_type,\n        )\n    elif down_block_type == \"CrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnDownBlock2D\")\n        return CrossAttnDownBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            
resnet_time_scale_shift=resnet_time_scale_shift,\n            attention_type=attention_type,\n        )\n    elif down_block_type == \"SimpleCrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D\")\n        return SimpleCrossAttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n            only_cross_attention=only_cross_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n    elif down_block_type == \"SkipDownBlock2D\":\n        return SkipDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"AttnSkipDownBlock2D\":\n        return AttnSkipDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            
resnet_act_fn=resnet_act_fn,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"DownEncoderBlock2D\":\n        return DownEncoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"AttnDownEncoderBlock2D\":\n        return AttnDownEncoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"KDownBlock2D\":\n        return KDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n        )\n    elif down_block_type == \"KCrossAttnDownBlock2D\":\n        return KCrossAttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            
resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            add_self_attention=True if not add_downsample else False,\n        )\n    raise ValueError(f\"{down_block_type} does not exist.\")\n\n\ndef get_up_block(\n    up_block_type: str,\n    num_layers: int,\n    in_channels: int,\n    out_channels: int,\n    prev_output_channel: int,\n    temb_channels: int,\n    add_upsample: bool,\n    resnet_eps: float,\n    resnet_act_fn: str,\n    resolution_idx: Optional[int] = None,\n    transformer_layers_per_block: int = 1,\n    num_attention_heads: Optional[int] = None,\n    resnet_groups: Optional[int] = None,\n    cross_attention_dim: Optional[int] = None,\n    dual_cross_attention: bool = False,\n    use_linear_projection: bool = False,\n    only_cross_attention: bool = False,\n    upcast_attention: bool = False,\n    resnet_time_scale_shift: str = \"default\",\n    attention_type: str = \"default\",\n    resnet_skip_time_act: bool = False,\n    resnet_out_scale_factor: float = 1.0,\n    cross_attention_norm: Optional[str] = None,\n    attention_head_dim: Optional[int] = None,\n    upsample_type: Optional[str] = None,\n    dropout: float = 0.0,\n) -> nn.Module:\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warn(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_up_block`. 
Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    up_block_type = up_block_type[7:] if up_block_type.startswith(\"UNetRes\") else up_block_type\n    if up_block_type == \"UpBlock2D\":\n        return UpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"ResnetUpsampleBlock2D\":\n        return ResnetUpsampleBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n        )\n    elif up_block_type == \"CrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnUpBlock2D\")\n        return CrossAttnUpBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            
temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            attention_type=attention_type,\n        )\n    elif up_block_type == \"SimpleCrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D\")\n        return SimpleCrossAttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n            only_cross_attention=only_cross_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n    elif up_block_type == \"AttnUpBlock2D\":\n        if add_upsample is False:\n            upsample_type = None\n        else:\n            upsample_type = upsample_type or \"conv\"  # default to 'conv'\n\n 
       return AttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            upsample_type=upsample_type,\n        )\n    elif up_block_type == \"SkipUpBlock2D\":\n        return SkipUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"AttnSkipUpBlock2D\":\n        return AttnSkipUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"UpDecoderBlock2D\":\n        return UpDecoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            
resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n    elif up_block_type == \"AttnUpDecoderBlock2D\":\n        return AttnUpDecoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n    elif up_block_type == \"KUpBlock2D\":\n        return KUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n        )\n    elif up_block_type == \"KCrossAttnUpBlock2D\":\n        return KCrossAttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n        )\n\n    raise ValueError(f\"{up_block_type} does not exist.\")\n\n\nclass 
AutoencoderTinyBlock(nn.Module):\n    \"\"\"\n    Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU\n    blocks.\n\n    Args:\n        in_channels (`int`): The number of input channels.\n        out_channels (`int`): The number of output channels.\n        act_fn (`str`):\n            ` The activation function to use. Supported values are `\"swish\"`, `\"mish\"`, `\"gelu\"`, and `\"relu\"`.\n\n    Returns:\n        `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to\n        `out_channels`.\n    \"\"\"\n\n    def __init__(self, in_channels: int, out_channels: int, act_fn: str):\n        super().__init__()\n        act_fn = get_activation(act_fn)\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n            act_fn,\n            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n            act_fn,\n            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n        )\n        self.skip = (\n            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)\n            if in_channels != out_channels\n            else nn.Identity()\n        )\n        self.fuse = nn.ReLU()\n\n    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:\n        return self.fuse(self.conv(x) + self.skip(x))\n\n\nclass UNetMidBlock2D(nn.Module):\n    \"\"\"\n    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.\n\n    Args:\n        in_channels (`int`): The number of input channels.\n        temb_channels (`int`): The number of temporal embedding channels.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.\n        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.\n        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.\n  
      resnet_time_scale_shift (`str`, *optional*, defaults to `default`):\n            The type of normalization to apply to the time embeddings. This can help to improve the performance of the\n            model on tasks with long-range temporal dependencies.\n        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.\n        resnet_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use in the group normalization layers of the resnet blocks.\n        attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.\n        resnet_pre_norm (`bool`, *optional*, defaults to `True`):\n            Whether to use pre-normalization for the resnet blocks.\n        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.\n        attention_head_dim (`int`, *optional*, defaults to 1):\n            Dimension of a single attention head. The number of attention heads is determined based on this value and\n            the number of input channels.\n        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.\n\n    Returns:\n        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,\n        in_channels, height, width)`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",  # default, spatial\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        attn_groups: Optional[int] = None,\n        resnet_pre_norm: bool = True,\n        add_attention: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n    ):\n        super().__init__()\n        resnet_groups = resnet_groups if 
resnet_groups is not None else min(in_channels // 4, 32)\n        self.add_attention = add_attention\n\n        if attn_groups is None:\n            attn_groups = resnet_groups if resnet_time_scale_shift == \"default\" else None\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}.\"\n            )\n            attention_head_dim = in_channels\n\n        for _ in range(num_layers):\n            if self.add_attention:\n                attentions.append(\n                    Attention(\n                        in_channels,\n                        heads=in_channels // attention_head_dim,\n                        dim_head=attention_head_dim,\n                        rescale_output_factor=output_scale_factor,\n                        eps=resnet_eps,\n                        norm_num_groups=attn_groups,\n                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == \"spatial\" else None,\n                        residual_connection=True,\n                        bias=True,\n                        upcast_softmax=True,\n                        _from_deprecated_attn_block=True,\n                    )\n                )\n            else:\n                attentions.append(None)\n\n            resnets.append(\n      
          ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if attn is not None:\n                hidden_states = attn(hidden_states, temb=temb)\n            hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass UNetMidBlock2DCrossAttn(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads: int = 1,\n        output_scale_factor: float = 1.0,\n        cross_attention_dim: int = 1280,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        upcast_attention: bool = False,\n        attention_type: str = \"default\",\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n        
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        # support for variable transformer layers per block\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * num_layers\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n\n        for i in range(num_layers):\n            if not dual_cross_attention:\n                attentions.append(\n                    Transformer2DModel(\n                        num_attention_heads,\n                        in_channels // num_attention_heads,\n                        in_channels=in_channels,\n                        num_layers=transformer_layers_per_block[i],\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        upcast_attention=upcast_attention,\n                        attention_type=attention_type,\n                    )\n                )\n            else:\n                attentions.append(\n                    DualTransformer2DModel(\n                        num_attention_heads,\n                        in_channels // num_attention_heads,\n                        in_channels=in_channels,\n                        num_layers=1,\n                        cross_attention_dim=cross_attention_dim,\n        
                norm_num_groups=resnet_groups,\n                    )\n                )\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = 
{\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n            else:\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n\n        return hidden_states\n\n\nclass UNetMidBlock2DSimpleCrossAttn(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n        cross_attention_dim: int = 1280,\n        skip_time_act: bool = False,\n        only_cross_attention: bool = False,\n        cross_attention_norm: Optional[str] = None,\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n\n        
self.attention_head_dim = attention_head_dim\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        self.num_heads = in_channels // self.attention_head_dim\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                skip_time_act=skip_time_act,\n            )\n        ]\n        attentions = []\n\n        for _ in range(num_layers):\n            processor = (\n                AttnAddedKVProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") else AttnAddedKVProcessor()\n            )\n\n            attentions.append(\n                Attention(\n                    query_dim=in_channels,\n                    cross_attention_dim=in_channels,\n                    heads=self.num_heads,\n                    dim_head=self.attention_head_dim,\n                    added_kv_proj_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    bias=True,\n                    upcast_softmax=True,\n                    only_cross_attention=only_cross_attention,\n                    cross_attention_norm=cross_attention_norm,\n                    processor=processor,\n                )\n            )\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n               
     dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0)\n\n        if attention_mask is None:\n            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.\n            mask = None if encoder_hidden_states is None else encoder_attention_mask\n        else:\n            # when attention_mask is defined: we don't even check for encoder_attention_mask.\n            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.\n            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.\n            #       then we can simplify this whole if/else block to:\n            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask\n            mask = attention_mask\n\n        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            # attn\n       
     hidden_states = attn(\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=mask,\n                **cross_attention_kwargs,\n            )\n\n            # resnet\n            hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n\n        return hidden_states\n\n\nclass AttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n        downsample_padding: int = 1,\n        downsample_type: str = \"conv\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n        self.downsample_type = downsample_type\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if downsample_type == \"conv\":\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        elif downsample_type == \"resnet\":\n            self.downsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        
in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        down=True,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0)\n\n        output_states = ()\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            cross_attention_kwargs.update({\"scale\": lora_scale})\n            hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n            hidden_states = attn(hidden_states, **cross_attention_kwargs)\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                if self.downsample_type == \"resnet\":\n                    hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)\n                else:\n                    hidden_states = downsampler(hidden_states, scale=lora_scale)\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass CrossAttnDownBlock2D(nn.Module):\n    
def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads: int = 1,\n        cross_attention_dim: int = 1280,\n        output_scale_factor: float = 1.0,\n        downsample_padding: int = 1,\n        add_downsample: bool = True,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n        attention_type: str = \"default\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * num_layers\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if not dual_cross_attention:\n                # Transformer2DModelWithSwitcher\n                
attentions.append(\n                    Transformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=transformer_layers_per_block[i],\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        only_cross_attention=only_cross_attention,\n                        upcast_attention=upcast_attention,\n                        attention_type=attention_type,\n                    )\n                )\n            else:\n                attentions.append(\n                    DualTransformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=1,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                    )\n                )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = 
None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        additional_residuals: Optional[torch.FloatTensor] = None,\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        output_states = ()\n\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n\n        blocks = list(zip(self.resnets, self.attentions))\n\n        for i, (resnet, attn) in enumerate(blocks):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    
attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n            # apply additional residuals to the output of the last pair of resnet and attention blocks\n            if i == len(blocks) - 1 and additional_residuals is not None:\n                hidden_states = hidden_states + additional_residuals\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, scale=lora_scale)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass DownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_downsample: bool = True,\n        downsample_padding: int = 1,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    
pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        output_states = ()\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=scale)\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, scale=scale)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass DownEncoderBlock2D(nn.Module):\n    def __init__(\n        
self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_downsample: bool = True,\n        downsample_padding: int = 1,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=None,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb=None, scale=scale)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, scale)\n\n        return 
hidden_states\n\n\nclass AttnDownEncoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n        add_downsample: bool = True,\n        downsample_padding: int = 1,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=None,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups,\n                    
residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb=None, scale=scale)\n            cross_attention_kwargs = {\"scale\": scale}\n            hidden_states = attn(hidden_states, **cross_attention_kwargs)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, scale)\n\n        return hidden_states\n\n\nclass AttnSkipDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = np.sqrt(2.0),\n        add_downsample: bool = True,\n    ):\n        super().__init__()\n        self.attentions = nn.ModuleList([])\n        self.resnets = nn.ModuleList([])\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass 
`attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min(in_channels // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            self.attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=32,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        if add_downsample:\n            self.resnet_down = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                
pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                down=True,\n                kernel=\"fir\",\n            )\n            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])\n            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))\n        else:\n            self.resnet_down = None\n            self.downsamplers = None\n            self.skip_conv = None\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        skip_sample: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:\n        output_states = ()\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb, scale=scale)\n            cross_attention_kwargs = {\"scale\": scale}\n            hidden_states = attn(hidden_states, **cross_attention_kwargs)\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            hidden_states = self.resnet_down(hidden_states, temb, scale=scale)\n            for downsampler in self.downsamplers:\n                skip_sample = downsampler(skip_sample)\n\n            hidden_states = self.skip_conv(skip_sample) + hidden_states\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states, skip_sample\n\n\nclass SkipDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = np.sqrt(2.0),\n        add_downsample: bool = 
True,\n        downsample_padding: int = 1,\n    ):\n        super().__init__()\n        self.resnets = nn.ModuleList([])\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min(in_channels // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        if add_downsample:\n            self.resnet_down = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                down=True,\n                kernel=\"fir\",\n            )\n            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])\n            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))\n        else:\n            self.resnet_down = None\n            self.downsamplers = None\n            self.skip_conv = None\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n      
  skip_sample: Optional[torch.FloatTensor] = None,\n        scale: float = 1.0,\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:\n        output_states = ()\n\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb, scale)\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            hidden_states = self.resnet_down(hidden_states, temb, scale)\n            for downsampler in self.downsamplers:\n                skip_sample = downsampler(skip_sample)\n\n            hidden_states = self.skip_conv(skip_sample) + hidden_states\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states, skip_sample\n\n\nclass ResnetDownsampleBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_downsample: bool = True,\n        skip_time_act: bool = False,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    
pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        down=True,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        output_states = ()\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), 
hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale)\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, temb, scale)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass SimpleCrossAttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        cross_attention_dim: int = 1280,\n        output_scale_factor: float = 1.0,\n        add_downsample: bool = True,\n        skip_time_act: bool = False,\n        only_cross_attention: bool = False,\n        cross_attention_norm: Optional[str] = None,\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n\n        resnets = []\n        attentions = []\n\n        self.attention_head_dim = attention_head_dim\n        self.num_heads = out_channels // self.attention_head_dim\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n  
                  output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n            processor = (\n                AttnAddedKVProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") else AttnAddedKVProcessor()\n            )\n\n            attentions.append(\n                Attention(\n                    query_dim=out_channels,\n                    cross_attention_dim=out_channels,\n                    heads=self.num_heads,\n                    dim_head=attention_head_dim,\n                    added_kv_proj_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    bias=True,\n                    upcast_softmax=True,\n                    only_cross_attention=only_cross_attention,\n                    cross_attention_norm=cross_attention_norm,\n                    processor=processor,\n                )\n            )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        down=True,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = 
False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        output_states = ()\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0)\n\n        if attention_mask is None:\n            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.\n            mask = None if encoder_hidden_states is None else encoder_attention_mask\n        else:\n            # when attention_mask is defined: we don't even check for encoder_attention_mask.\n            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.\n            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.\n            #       then we can simplify this whole if/else block to:\n            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask\n            mask = attention_mask\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, temb, scale=lora_scale)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass KDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 4,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n        resnet_group_size: int = 32,\n        add_downsample: bool = False,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    dropout=dropout,\n                    temb_channels=temb_channels,\n                    groups=groups,\n                 
   groups_out=groups_out,\n                    eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            # YiYi's comments- might be able to use FirDownsample2D, look into details later\n            self.downsamplers = nn.ModuleList([KDownsample2D()])\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        output_states = ()\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale)\n\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass KCrossAttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: 
int,\n        out_channels: int,\n        temb_channels: int,\n        cross_attention_dim: int,\n        dropout: float = 0.0,\n        num_layers: int = 4,\n        resnet_group_size: int = 32,\n        add_downsample: bool = True,\n        attention_head_dim: int = 64,\n        add_self_attention: bool = False,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    dropout=dropout,\n                    temb_channels=temb_channels,\n                    groups=groups,\n                    groups_out=groups_out,\n                    eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n            attentions.append(\n                KAttentionBlock(\n                    out_channels,\n                    out_channels // attention_head_dim,\n                    attention_head_dim,\n                    cross_attention_dim=cross_attention_dim,\n                    temb_channels=temb_channels,\n                    attention_bias=True,\n                    add_self_attention=add_self_attention,\n                    cross_attention_norm=\"layer_norm\",\n                    group_size=resnet_group_size,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.attentions = nn.ModuleList(attentions)\n\n        if add_downsample:\n            
self.downsamplers = nn.ModuleList([KDownsample2D()])\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:\n        output_states = ()\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n              
  )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n\n            if self.downsamplers is None:\n                output_states += (None,)\n            else:\n                output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass AttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: int = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n        upsample_type: str = \"conv\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.upsample_type = upsample_type\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if upsample_type == \"conv\":\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        elif upsample_type == \"resnet\":\n            self.upsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        
out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        up=True,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb, scale=scale)\n            cross_attention_kwargs = {\"scale\": scale}\n            hidden_states = attn(hidden_states, **cross_attention_kwargs)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                if self.upsample_type == \"resnet\":\n                    hidden_states = upsampler(hidden_states, temb=temb, scale=scale)\n                else:\n                    hidden_states = upsampler(hidden_states, scale=scale)\n\n        return hidden_states\n\n\nclass CrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        
temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads: int = 1,\n        cross_attention_dim: int = 1280,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n        attention_type: str = \"default\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * num_layers\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if not dual_cross_attention:\n                # 
Transformer2DModelWithSwitcher\n                attentions.append(\n                    Transformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=transformer_layers_per_block[i],\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        only_cross_attention=only_cross_attention,\n                        upcast_attention=upcast_attention,\n                        attention_type=attention_type,\n                    )\n                )\n            else:\n                attentions.append(\n                    DualTransformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=1,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                    )\n                )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        upsample_size: Optional[int] = None,\n        
attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n        is_freeu_enabled = (\n            getattr(self, \"s1\", None)\n            and getattr(self, \"s2\", None)\n            and getattr(self, \"b1\", None)\n            and getattr(self, \"b2\", None)\n        )\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n\n            # FreeU: Only operate on the first two stages\n            if is_freeu_enabled:\n                hidden_states, res_hidden_states = apply_freeu(\n                    self.resolution_idx,\n                    hidden_states,\n                    res_hidden_states,\n                    s1=self.s1,\n                    s2=self.s2,\n                    b1=self.b1,\n                    b2=self.b2,\n                )\n\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    
**ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)\n\n        return hidden_states\n\n\nclass UpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            
resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        is_freeu_enabled = (\n            getattr(self, \"s1\", None)\n            and getattr(self, \"s2\", None)\n            and getattr(self, \"b1\", None)\n            and getattr(self, \"b2\", None)\n        )\n\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n\n            # FreeU: Only operate on the first two stages\n            if is_freeu_enabled:\n                hidden_states, res_hidden_states = apply_freeu(\n                    self.resolution_idx,\n                    hidden_states,\n                    res_hidden_states,\n                    s1=self.s1,\n                    s2=self.s2,\n   
                 b1=self.b1,\n                    b2=self.b2,\n                )\n\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=scale)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size, scale=scale)\n\n        return hidden_states\n\n\nclass UpDecoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",  # default, spatial\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        temb_channels: Optional[int] = None,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            input_channels = in_channels if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n           
         in_channels=input_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0\n    ) -> torch.FloatTensor:\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb=temb, scale=scale)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass AttnUpDecoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        temb_channels: Optional[int] = None,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        if attention_head_dim is None:\n            
logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            input_channels = in_channels if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=input_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups if resnet_time_scale_shift != \"spatial\" else None,\n                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == \"spatial\" else None,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.resolution_idx = resolution_idx\n\n    def 
forward(\n        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0\n    ) -> torch.FloatTensor:\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb=temb, scale=scale)\n            cross_attention_kwargs = {\"scale\": scale}\n            hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, scale=scale)\n\n        return hidden_states\n\n\nclass AttnSkipUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        output_scale_factor: float = np.sqrt(2.0),\n        add_upsample: bool = True,\n    ):\n        super().__init__()\n        self.attentions = nn.ModuleList([])\n        self.resnets = nn.ModuleList([])\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min(resnet_in_channels + res_skip_channels // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    
dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        self.attentions.append(\n            Attention(\n                out_channels,\n                heads=out_channels // attention_head_dim,\n                dim_head=attention_head_dim,\n                rescale_output_factor=output_scale_factor,\n                eps=resnet_eps,\n                norm_num_groups=32,\n                residual_connection=True,\n                bias=True,\n                upcast_softmax=True,\n                _from_deprecated_attn_block=True,\n            )\n        )\n\n        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)\n        if add_upsample:\n            self.resnet_up = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                groups_out=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                up=True,\n                kernel=\"fir\",\n            )\n            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n            self.skip_norm = 
torch.nn.GroupNorm(\n                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True\n            )\n            self.act = nn.SiLU()\n        else:\n            self.resnet_up = None\n            self.skip_conv = None\n            self.skip_norm = None\n            self.act = None\n\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        skip_sample=None,\n        scale: float = 1.0,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb, scale=scale)\n\n        cross_attention_kwargs = {\"scale\": scale}\n        hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)\n\n        if skip_sample is not None:\n            skip_sample = self.upsampler(skip_sample)\n        else:\n            skip_sample = 0\n\n        if self.resnet_up is not None:\n            skip_sample_states = self.skip_norm(hidden_states)\n            skip_sample_states = self.act(skip_sample_states)\n            skip_sample_states = self.skip_conv(skip_sample_states)\n\n            skip_sample = skip_sample + skip_sample_states\n\n            hidden_states = self.resnet_up(hidden_states, temb, scale=scale)\n\n        return hidden_states, skip_sample\n\n\nclass SkipUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: 
float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = np.sqrt(2.0),\n        add_upsample: bool = True,\n        upsample_padding: int = 1,\n    ):\n        super().__init__()\n        self.resnets = nn.ModuleList([])\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)\n        if add_upsample:\n            self.resnet_up = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                groups_out=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                
pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                up=True,\n                kernel=\"fir\",\n            )\n            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n            self.skip_norm = torch.nn.GroupNorm(\n                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True\n            )\n            self.act = nn.SiLU()\n        else:\n            self.resnet_up = None\n            self.skip_conv = None\n            self.skip_norm = None\n            self.act = None\n\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        skip_sample=None,\n        scale: float = 1.0,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb, scale=scale)\n\n        if skip_sample is not None:\n            skip_sample = self.upsampler(skip_sample)\n        else:\n            skip_sample = 0\n\n        if self.resnet_up is not None:\n            skip_sample_states = self.skip_norm(hidden_states)\n            skip_sample_states = self.act(skip_sample_states)\n            skip_sample_states = self.skip_conv(skip_sample_states)\n\n            skip_sample = skip_sample + skip_sample_states\n\n            hidden_states = self.resnet_up(hidden_states, temb, scale=scale)\n\n        return hidden_states, skip_sample\n\n\nclass ResnetUpsampleBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        
prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        skip_time_act: bool = False,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        
non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        up=True,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=scale)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, temb, scale=scale)\n\n        
return hidden_states\n\n\nclass SimpleCrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim: int = 1,\n        cross_attention_dim: int = 1280,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        skip_time_act: bool = False,\n        only_cross_attention: bool = False,\n        cross_attention_norm: Optional[str] = None,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.attention_head_dim = attention_head_dim\n\n        self.num_heads = out_channels // self.attention_head_dim\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n            processor = (\n                
AttnAddedKVProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") else AttnAddedKVProcessor()\n            )\n\n            attentions.append(\n                Attention(\n                    query_dim=out_channels,\n                    cross_attention_dim=out_channels,\n                    heads=self.num_heads,\n                    dim_head=self.attention_head_dim,\n                    added_kv_proj_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    bias=True,\n                    upcast_softmax=True,\n                    only_cross_attention=only_cross_attention,\n                    cross_attention_norm=cross_attention_norm,\n                    processor=processor,\n                )\n            )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        up=True,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] 
= None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0)\n        if attention_mask is None:\n            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.\n            mask = None if encoder_hidden_states is None else encoder_attention_mask\n        else:\n            # when attention_mask is defined: we don't even check for encoder_attention_mask.\n            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.\n            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.\n            #       then we can simplify this whole if/else block to:\n            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask\n            mask = attention_mask\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # resnet\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n 
                           return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, temb, scale=lora_scale)\n\n        return hidden_states\n\n\nclass KUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: int,\n        dropout: float = 0.0,\n        num_layers: int = 5,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n        resnet_group_size: Optional[int] = 32,\n        add_upsample: bool = True,\n    ):\n        super().__init__()\n        resnets = []\n        k_in_channels = 2 * out_channels\n        k_out_channels = in_channels\n        num_layers = num_layers - 1\n\n        for i in range(num_layers):\n            in_channels = k_in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=k_out_channels if (i == num_layers - 
1) else out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=groups,\n                    groups_out=groups_out,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([KUpsample2D()])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        scale: float = 1.0,\n    ) -> torch.FloatTensor:\n        res_hidden_states_tuple = res_hidden_states_tuple[-1]\n        if res_hidden_states_tuple is not None:\n            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = 
resnet(hidden_states, temb, scale=scale)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass KCrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        resolution_idx: int,\n        dropout: float = 0.0,\n        num_layers: int = 4,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n        resnet_group_size: int = 32,\n        attention_head_dim: int = 1,  # attention dim_head\n        cross_attention_dim: int = 768,\n        add_upsample: bool = True,\n        upcast_attention: bool = False,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        is_first_block = in_channels == out_channels == temb_channels\n        is_middle_block = in_channels != out_channels\n        add_self_attention = True if is_first_block else False\n\n        self.has_cross_attention = True\n        self.attention_head_dim = attention_head_dim\n\n        # in_channels, and out_channels for the block (k-unet)\n        k_in_channels = out_channels if is_first_block else 2 * out_channels\n        k_out_channels = in_channels\n\n        num_layers = num_layers - 1\n\n        for i in range(num_layers):\n            in_channels = k_in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            if is_middle_block and (i == num_layers - 1):\n                conv_2d_out_channels = k_out_channels\n            else:\n                conv_2d_out_channels = None\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    conv_2d_out_channels=conv_2d_out_channels,\n                  
  temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=groups,\n                    groups_out=groups_out,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n            attentions.append(\n                KAttentionBlock(\n                    k_out_channels if (i == num_layers - 1) else out_channels,\n                    k_out_channels // attention_head_dim\n                    if (i == num_layers - 1)\n                    else out_channels // attention_head_dim,\n                    attention_head_dim,\n                    cross_attention_dim=cross_attention_dim,\n                    temb_channels=temb_channels,\n                    attention_bias=True,\n                    add_self_attention=add_self_attention,\n                    cross_attention_norm=\"layer_norm\",\n                    upcast_attention=upcast_attention,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.attentions = nn.ModuleList(attentions)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([KUpsample2D()])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        res_hidden_states_tuple = 
res_hidden_states_tuple[-1]\n        if res_hidden_states_tuple is not None:\n            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)\n\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n        for resnet, attn in zip(self.resnets, self.attentions):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb, scale=lora_scale)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n             
   )\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\n# can potentially later be renamed to `No-feed-forward` attention\nclass KAttentionBlock(nn.Module):\n    r\"\"\"\n    A basic Transformer block.\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Configure if the attention layers should contain a bias parameter.\n        upcast_attention (`bool`, *optional*, defaults to `False`):\n            Set to `True` to upcast the attention computation to `float32`.\n        temb_channels (`int`, *optional*, defaults to 768):\n            The number of channels in the token embedding.\n        add_self_attention (`bool`, *optional*, defaults to `False`):\n            Set to `True` to add self-attention to the block.\n        cross_attention_norm (`str`, *optional*, defaults to `None`):\n            The type of normalization to use for the cross attention. 
Can be `None`, `layer_norm`, or `group_norm`.\n        group_size (`int`, *optional*, defaults to 32):\n            The number of groups to separate the channels into for group normalization.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout: float = 0.0,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        upcast_attention: bool = False,\n        temb_channels: int = 768,  # for ada_group_norm\n        add_self_attention: bool = False,\n        cross_attention_norm: Optional[str] = None,\n        group_size: int = 32,\n    ):\n        super().__init__()\n        self.add_self_attention = add_self_attention\n\n        # 1. Self-Attn\n        if add_self_attention:\n            self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))\n            self.attn1 = Attention(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                cross_attention_dim=None,\n                cross_attention_norm=None,\n            )\n\n        # 2. 
Cross-Attn\n        self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))\n        self.attn2 = Attention(\n            query_dim=dim,\n            cross_attention_dim=cross_attention_dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            upcast_attention=upcast_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n\n    def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:\n        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)\n\n    def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:\n        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        # TODO: mark emb as non-optional (self.norm2 requires it).\n        #       requires assessing impact of change to positional param interface.\n        emb: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        # 1. 
Self-Attention\n        if self.add_self_attention:\n            norm_hidden_states = self.norm1(hidden_states, emb)\n\n            height, weight = norm_hidden_states.shape[2:]\n            norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)\n\n            attn_output = self.attn1(\n                norm_hidden_states,\n                encoder_hidden_states=None,\n                attention_mask=attention_mask,\n                **cross_attention_kwargs,\n            )\n            attn_output = self._to_4d(attn_output, height, weight)\n\n            hidden_states = attn_output + hidden_states\n\n        # 2. Cross-Attention/None\n        norm_hidden_states = self.norm2(hidden_states, emb)\n\n        height, weight = norm_hidden_states.shape[2:]\n        norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)\n        attn_output = self.attn2(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,\n            **cross_attention_kwargs,\n        )\n        attn_output = self._to_4d(attn_output, height, weight)\n\n        hidden_states = attn_output + hidden_states\n\n        return hidden_states\n"
  },
  {
    "path": "foleycrafter/models/auffusion_unet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.activations import get_activation\n\n# from diffusers import StableDiffusionGLIGENPipeline\nfrom diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    Attention,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n    XFormersAttnProcessor,\n)\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom foleycrafter.models.adapters.ip_adapter import TimeProjModel\nfrom foleycrafter.models.auffusion.attention_processor import AttnProcessor2_0\nfrom foleycrafter.models.auffusion.loaders.unet import UNet2DConditionLoadersMixin\nfrom 
foleycrafter.models.auffusion.unet_2d_blocks import (\n    UNetMidBlock2D,\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNet2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, 
*optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or\n            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers is skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):\n            The number of transformer blocks of type 
[`~models.attention.BasicTransformerBlock`]. Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n       reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling\n            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. 
Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,\n        *optional*, default to `3`): The kernel size of `conv_out` layer. 
projection_class_embeddings_input_dim (`int`,\n        *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlock2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        dropout: float = 0.0,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        
cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,\n        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        attention_type: str = \"default\",\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n        # param for joint\n        video_feature_dim: tuple = (320, 640, 1280, 1280),\n        video_cross_attn_dim: int = 1024,\n        video_frame_nums: int = 16,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via 
`num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:\n            for layer_number_per_block in transformer_layers_per_block:\n                if isinstance(layer_number_per_block, list):\n                    raise ValueError(\"Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.\")\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. 
Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image_proj\"` (Kadinsky 2.1)`\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. 
it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kadinsky 2.1)`\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = 
(num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n     
           dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                dropout=dropout,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                dropout=dropout,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                attention_type=attention_type,\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                dropout=dropout,\n                
resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type == \"UNetMidBlock2D\":\n            self.mid_block = UNetMidBlock2D(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                dropout=dropout,\n                num_layers=0,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                add_attention=False,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        reversed_transformer_layers_per_block = (\n            list(reversed(transformer_layers_per_block))\n            if reverse_transformer_layers_per_block is None\n            else reverse_transformer_layers_per_block\n        )\n        only_cross_attention = 
list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resolution_idx=i,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                dropout=dropout,\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n        if attention_type in [\"gated\", \"gated-text-image\"]:\n            positive_len = 768\n            if isinstance(cross_attention_dim, int):\n                positive_len = cross_attention_dim\n            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):\n                positive_len = cross_attention_dim[0]\n\n            feature_type = \"text-only\" if attention_type == \"gated\" else \"text-image\"\n            self.position_net = TimeProjModel(\n                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type\n            )\n\n        # additional settings\n        self.video_feature_dim = video_feature_dim\n        self.cross_attention_dim = cross_attention_dim\n        self.video_cross_attn_dim = video_cross_attn_dim\n        self.video_frame_nums = video_frame_nums\n\n        self.multi_frames_condition = False\n\n    def load_attention(self):\n        attn_dict = {}\n        for name in self.attn_processors.keys():\n            # if self-attention, save feature\n            if name.endswith(\"attn1.processor\"):\n                if is_xformers_available():\n                    
attn_dict[name] = XFormersAttnProcessor()\n                else:\n                    attn_dict[name] = AttnProcessor()\n            else:\n                attn_dict[name] = AttnProcessor2_0()\n        self.set_attn_processor(attn_dict)\n\n    def get_writer_feature(self):\n        return self.attn_feature_writer.get_cross_attention_feature()\n\n    def clear_writer_feature(self):\n        self.attn_feature_writer.clear_cross_attention_feature()\n\n    def disable_feature_adapters(self):\n        raise NotImplementedError\n\n    def set_reader_feature(self, features: list):\n        return self.attn_feature_reader.set_cross_attention_feature(features)\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(\n        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False\n    ):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class 
or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor, _remove_lora=_remove_lora)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"), _remove_lora=_remove_lora)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnAddedKVProcessor()\n        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call 
`set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor, _remove_lora=True)\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = 
num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    def enable_freeu(self, s1, s2, b1, b2):\n        r\"\"\"Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.\n\n        The suffixes after the scaling factors represent the stage blocks where they are being applied.\n\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that\n        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n\n        
Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to\n                mitigate the \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to\n                mitigate the \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        for i, upsample_block in enumerate(self.up_blocks):\n            setattr(upsample_block, \"s1\", s1)\n            setattr(upsample_block, \"s2\", s2)\n            setattr(upsample_block, \"b1\", b1)\n            setattr(upsample_block, \"b2\", b2)\n\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism.\"\"\"\n        freeu_keys = {\"s1\", \"s2\", \"b1\", \"b2\"}\n        for i, upsample_block in enumerate(self.up_blocks):\n            for k in freeu_keys:\n                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:\n                    setattr(upsample_block, k, None)\n\n    def fuse_qkv_projections(self):\n        \"\"\"\n        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,\n        key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused.\n\n        <Tip warning={true}>\n\n        This API is 🧪 experimental.\n\n        </Tip>\n        \"\"\"\n        self.original_attn_processors = None\n\n        for _, attn_processor in self.attn_processors.items():\n            if \"Added\" in str(attn_processor.__class__.__name__):\n                raise ValueError(\"`fuse_qkv_projections()` is not supported for models having added KV projections.\")\n\n        self.original_attn_processors = self.attn_processors\n\n        for module in self.modules():\n            if isinstance(module, Attention):\n                module.fuse_projections(fuse=True)\n\n    def unfuse_qkv_projections(self):\n        \"\"\"Disables the fused QKV projection if enabled.\n\n        <Tip warning={true}>\n\n        This API is 🧪 experimental.\n\n        </Tip>\n\n        \"\"\"\n        if self.original_attn_processors is not None:\n            self.set_attn_processor(self.original_attn_processors)\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet2DConditionOutput, Tuple]:\n        # import ipdb; ipdb.set_trace()\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n 
           sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):\n                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed\n                through the `self.time_embedding` layer to obtain the timestep embeddings.\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. 
Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):\n                A tuple of tensors that if specified are added to the residuals of down unet blocks.\n            mid_block_additional_residual: (`torch.Tensor`, *optional*):\n                A tensor that if specified is added to the residual of the middle unet block.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):\n                additional residuals to be added to UNet long skip connections from down blocks to up blocks for\n                example from ControlNet side model(s)\n            mid_block_additional_residual (`torch.Tensor`, *optional*):\n                additional residual to be added to UNet mid block output, for example from ControlNet side model\n            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):\n                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        
# on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        for dim in sample.shape[-2:]:\n            if dim % default_overall_up_factor != 0:\n                # Forward upsample size to force interpolation output size.\n                forward_upsample_size = True\n                break\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. 
time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the 
config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n 
           aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kadinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"ip_image_proj\":\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the 
keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            image_embeds = self.encoder_hid_proj(image_embeds)\n            if isinstance(image_embeds, list):\n                image_embeds = [image_embed.to(encoder_hidden_states.dtype) for image_embed in image_embeds]\n            else:\n                image_embeds = image_embeds.to(encoder_hidden_states.dtype)\n            encoder_hidden_states = (encoder_hidden_states, image_embeds)\n            # encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)\n        # import ipdb; ipdb.set_trace()\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 2.5 GLIGEN position net\n        if cross_attention_kwargs is not None and cross_attention_kwargs.get(\"gligen\", None) is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            gligen_args = cross_attention_kwargs.pop(\"gligen\")\n            cross_attention_kwargs[\"gligen\"] = {\"objs\": self.position_net(**gligen_args)}\n\n        # 3. 
down\n        lora_scale = cross_attention_kwargs.get(\"scale\", 1.0) if cross_attention_kwargs is not None else 1.0\n        if USE_PEFT_BACKEND:\n            # weight the lora layers by setting `lora_scale` for each PEFT layer\n            scale_lora_layers(self, lora_scale)\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets\n        is_adapter = down_intrablock_additional_residuals is not None\n        # maintain backward compatibility for legacy usage, where\n        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg\n        #       but can only use one or the other\n        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:\n            deprecate(\n                \"T2I should not use down_block_additional_residuals\",\n                \"1.3.0\",\n                \"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \\\n                       and will be removed in diffusers 1.3.0.  `down_block_additional_residuals` should only be used \\\n                       for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
\",\n                standard_warn=False,\n            )\n            down_intrablock_additional_residuals = down_block_additional_residuals\n            is_adapter = True\n        # import ipdb; ipdb.set_trace()\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_intrablock_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n                # import ipdb; ipdb.set_trace()\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    sample += down_intrablock_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n          
  down_block_res_samples = new_down_block_res_samples\n        # 4. mid\n        if self.mid_block is not None:\n            if hasattr(self.mid_block, \"has_cross_attention\") and self.mid_block.has_cross_attention:\n                sample = self.mid_block(\n                    sample,\n                    emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = self.mid_block(sample, emb)\n\n            # To support T2I-Adapter-XL\n            if (\n                is_adapter\n                and len(down_intrablock_additional_residuals) > 0\n                and sample.shape == down_intrablock_additional_residuals[0].shape\n            ):\n                sample += down_intrablock_additional_residuals.pop(0)\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n        # import ipdb; ipdb.set_trace()\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    upsample_size=upsample_size,\n                    scale=lora_scale,\n                )\n        # import ipdb; ipdb.set_trace()\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if USE_PEFT_BACKEND:\n            # remove `lora_scale` from each PEFT layer\n            unscale_lora_layers(self, lora_scale)\n\n        if not return_dict:\n            return (sample,)\n        # import ipdb; ipdb.set_trace()\n        return UNet2DConditionOutput(sample=sample)\n"
  },
  {
    "path": "foleycrafter/models/onset/__init__.py",
    "content": "from .r2plus1d_18 import r2plus1d18KeepTemp\nfrom .video_onset_net import VideoOnsetNet\n\n\n__all__ = [\"r2plus1d18KeepTemp\", \"VideoOnsetNet\"]\n"
  },
  {
    "path": "foleycrafter/models/onset/r2plus1d_18.py",
    "content": "# Copied from specvqgan/onset_baseline/models/r2plus1d_18.py\n\nimport torch\nimport torch.nn as nn\n\nfrom .resnet import r2plus1d_18\n\n\nclass r2plus1d18KeepTemp(nn.Module):\n    def __init__(self, pretrained=True):\n        super().__init__()\n\n        self.model = r2plus1d_18(pretrained=pretrained)\n\n        self.model.layer2[0].conv1[0][3] = nn.Conv3d(\n            230, 128, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False\n        )\n        self.model.layer2[0].downsample = nn.Sequential(\n            nn.Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),\n            nn.BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n        )\n        self.model.layer3[0].conv1[0][3] = nn.Conv3d(\n            460, 256, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False\n        )\n        self.model.layer3[0].downsample = nn.Sequential(\n            nn.Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),\n            nn.BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n        )\n        self.model.layer4[0].conv1[0][3] = nn.Conv3d(\n            921, 512, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False\n        )\n        self.model.layer4[0].downsample = nn.Sequential(\n            nn.Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),\n            nn.BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n        )\n        self.model.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))\n        self.model.fc = nn.Identity()\n\n    def forward(self, x):\n        # import pdb; pdb.set_trace()\n        x = self.model(x)\n        return x\n\n\nif __name__ == \"__main__\":\n    model = r2plus1d18KeepTemp(False).cuda()\n    rand_input = torch.randn((1, 3, 30, 112, 112)).cuda()\n    out = model(rand_input)\n"
  },
  {
    "path": "foleycrafter/models/onset/resnet.py",
    "content": "# Copied from specvqgan/onset_baseline/models/resnet.py\nimport torch.nn as nn\nfrom torch.hub import load_state_dict_from_url\n\n\n__all__ = [\"r3d_18\", \"mc3_18\", \"r2plus1d_18\"]\n\nmodel_urls = {\n    \"r3d_18\": \"https://download.pytorch.org/models/r3d_18-b3b3357e.pth\",\n    \"mc3_18\": \"https://download.pytorch.org/models/mc3_18-a90a0ba3.pth\",\n    \"r2plus1d_18\": \"https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth\",\n}\n\n\nclass Conv3DSimple(nn.Conv3d):\n    def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1):\n        super(Conv3DSimple, self).__init__(\n            in_channels=in_planes,\n            out_channels=out_planes,\n            kernel_size=(3, 3, 3),\n            stride=stride,\n            padding=padding,\n            bias=False,\n        )\n\n    @staticmethod\n    def get_downsample_stride(stride):\n        return stride, stride, stride\n\n\nclass Conv2Plus1D(nn.Sequential):\n    def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1):\n        super(Conv2Plus1D, self).__init__(\n            nn.Conv3d(\n                in_planes,\n                midplanes,\n                kernel_size=(1, 3, 3),\n                stride=(1, stride, stride),\n                padding=(0, padding, padding),\n                bias=False,\n            ),\n            nn.BatchNorm3d(midplanes),\n            nn.ReLU(inplace=True),\n            nn.Conv3d(\n                midplanes,\n                out_planes,\n                kernel_size=(3, 1, 1),\n                stride=(stride, 1, 1),\n                padding=(padding, 0, 0),\n                bias=False,\n            ),\n        )\n\n    @staticmethod\n    def get_downsample_stride(stride):\n        return stride, stride, stride\n\n\nclass Conv3DNoTemporal(nn.Conv3d):\n    def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1):\n        super(Conv3DNoTemporal, self).__init__(\n            
in_channels=in_planes,\n            out_channels=out_planes,\n            kernel_size=(1, 3, 3),\n            stride=(1, stride, stride),\n            padding=(0, padding, padding),\n            bias=False,\n        )\n\n    @staticmethod\n    def get_downsample_stride(stride):\n        return 1, stride, stride\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):\n        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)\n\n        super(BasicBlock, self).__init__()\n        self.conv1 = nn.Sequential(\n            conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)\n        )\n        self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.conv2(out)\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)\n\n        # 1x1x1\n        self.conv1 = nn.Sequential(\n            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)\n        )\n        # Second kernel\n        self.conv2 = nn.Sequential(\n            conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)\n        )\n\n        # 1x1x1\n        self.conv3 = nn.Sequential(\n            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, 
bias=False),\n            nn.BatchNorm3d(planes * self.expansion),\n        )\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.conv2(out)\n        out = self.conv3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass BasicStem(nn.Sequential):\n    \"\"\"The default conv-batchnorm-relu stem\"\"\"\n\n    def __init__(self):\n        super(BasicStem, self).__init__(\n            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),\n            nn.BatchNorm3d(64),\n            nn.ReLU(inplace=True),\n        )\n\n\nclass R2Plus1dStem(nn.Sequential):\n    \"\"\"R(2+1)D stem is different than the default one as it uses separated 3D convolution\"\"\"\n\n    def __init__(self):\n        super(R2Plus1dStem, self).__init__(\n            nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),\n            nn.BatchNorm3d(45),\n            nn.ReLU(inplace=True),\n            nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),\n            nn.BatchNorm3d(64),\n            nn.ReLU(inplace=True),\n        )\n\n\nclass VideoResNet(nn.Module):\n    def __init__(self, block, conv_makers, layers, stem, num_classes=400, zero_init_residual=False):\n        \"\"\"Generic resnet video generator.\n        Args:\n            block (nn.Module): resnet building block\n            conv_makers (list(functions)): generator function for each layer\n            layers (List[int]): number of blocks per layer\n            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.\n            num_classes (int, optional): Dimension of the final FC layer. 
Defaults to 400.\n            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.\n        \"\"\"\n        super(VideoResNet, self).__init__()\n        self.inplanes = 64\n\n        self.stem = stem()\n\n        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)\n        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)\n\n        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        # init weights\n        self._initialize_weights()\n\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n\n    def forward(self, x):\n        x = self.stem(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        # Flatten the layer to fc\n        # x = x.flatten(1)\n        # x = self.fc(x)\n        N = x.shape[0]\n        x = x.squeeze()\n        if N == 1:\n            x = x[None]\n\n        return x\n\n    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):\n        downsample = None\n\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            ds_stride = conv_builder.get_downsample_stride(stride)\n            downsample = nn.Sequential(\n                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),\n                nn.BatchNorm3d(planes * block.expansion),\n            )\n        layers = []\n        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))\n\n        self.inplanes = planes * block.expansion\n       
 for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, conv_builder))\n\n        return nn.Sequential(*layers)\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv3d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm3d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                nn.init.constant_(m.bias, 0)\n\n\ndef _video_resnet(arch, pretrained=False, progress=True, **kwargs):\n    model = VideoResNet(**kwargs)\n\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n\n\ndef r3d_18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Construct 18 layer Resnet3D model as in\n    https://arxiv.org/abs/1711.11248\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Kinetics-400\n        progress (bool): If True, displays a progress bar of the download to stderr\n    Returns:\n        nn.Module: R3D-18 network\n    \"\"\"\n\n    return _video_resnet(\n        \"r3d_18\",\n        pretrained,\n        progress,\n        block=BasicBlock,\n        conv_makers=[Conv3DSimple] * 4,\n        layers=[2, 2, 2, 2],\n        stem=BasicStem,\n        **kwargs,\n    )\n\n\ndef mc3_18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructor for 18 layer Mixed Convolution network as in\n    https://arxiv.org/abs/1711.11248\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Kinetics-400\n        progress (bool): If True, displays a progress bar of the download to stderr\n    Returns:\n        nn.Module: 
MC3 Network definition\n    \"\"\"\n    return _video_resnet(\n        \"mc3_18\",\n        pretrained,\n        progress,\n        block=BasicBlock,\n        conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,\n        layers=[2, 2, 2, 2],\n        stem=BasicStem,\n        **kwargs,\n    )\n\n\ndef r2plus1d_18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructor for the 18 layer deep R(2+1)D network as in\n    https://arxiv.org/abs/1711.11248\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Kinetics-400\n        progress (bool): If True, displays a progress bar of the download to stderr\n    Returns:\n        nn.Module: R(2+1)D-18 network\n    \"\"\"\n    return _video_resnet(\n        \"r2plus1d_18\",\n        pretrained,\n        progress,\n        block=BasicBlock,\n        conv_makers=[Conv2Plus1D] * 4,\n        layers=[2, 2, 2, 2],\n        stem=R2Plus1dStem,\n        **kwargs,\n    )\n"
  },
  {
    "path": "foleycrafter/models/onset/torch_utils.py",
    "content": "# Copied from https://github.com/XYPB/CondFoleyGen/blob/main/specvqgan/onset_baseline/utils/torch_utils.py\nimport os\nimport sys\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom ... import data\n\n\n# ---------------------------------------------------- #\ndef load_model(cp_path, net, device=None, strict=True):\n    if not device:\n        device = torch.device(\"cpu\")\n    if os.path.isfile(cp_path):\n        print(\"=> loading checkpoint '{}'\".format(cp_path))\n        checkpoint = torch.load(cp_path, map_location=device)\n\n        # check if there is module\n        if list(checkpoint[\"state_dict\"].keys())[0][:7] == \"module.\":\n            state_dict = OrderedDict()\n            for k, v in checkpoint[\"state_dict\"].items():\n                name = k[7:]\n                state_dict[name] = v\n        else:\n            state_dict = checkpoint[\"state_dict\"]\n        net.load_state_dict(state_dict, strict=strict)\n\n        print(\"=> loaded checkpoint '{}' (epoch {})\".format(cp_path, checkpoint[\"epoch\"]))\n        start_epoch = checkpoint[\"epoch\"]\n    else:\n        print(\"=> no checkpoint found at '{}'\".format(cp_path))\n        start_epoch = 0\n        sys.exit()\n\n    return net, start_epoch\n\n\n# ---------------------------------------------------- #\ndef binary_acc(pred, target, threshold):\n    pred = pred > threshold\n    acc = np.sum(pred == target) / target.shape[0]\n    return acc\n\n\ndef calc_acc(prob, labels, k):\n    pred = torch.argsort(prob, dim=-1, descending=True)[..., :k]\n    top_k_acc = torch.sum(pred == labels.view(-1, 1)).float() / labels.size(0)\n    return top_k_acc\n\n\n# ---------------------------------------------------- #\n\n\ndef get_dataloader(args, pr, split=\"train\", shuffle=False, drop_last=False, batch_size=None):\n    data_loader = getattr(data, pr.dataloader)\n    if split == \"train\":\n        read_list = 
pr.list_train\n    elif split == \"val\":\n        read_list = pr.list_val\n    elif split == \"test\":\n        read_list = pr.list_test\n    dataset = data_loader(args, pr, read_list, split=split)\n    batch_size = batch_size if batch_size else args.batch_size\n    dataset.getitem_test(1)\n    loader = DataLoader(\n        dataset,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        num_workers=args.num_workers,\n        pin_memory=True,\n        drop_last=drop_last,\n    )\n\n    return dataset, loader\n\n\n# ---------------------------------------------------- #\ndef make_optimizer(model, args):\n    \"\"\"\n    Args:\n        model: NN to train\n    Returns:\n        optimizer: pytorch optmizer for updating the given model parameters.\n    \"\"\"\n    if args.optim == \"SGD\":\n        optimizer = torch.optim.SGD(\n            filter(lambda p: p.requires_grad, model.parameters()),\n            lr=args.lr,\n            momentum=args.momentum,\n            weight_decay=args.weight_decay,\n            nesterov=False,\n        )\n    elif args.optim == \"Adam\":\n        optimizer = torch.optim.Adam(\n            filter(lambda p: p.requires_grad, model.parameters()),\n            lr=args.lr,\n            weight_decay=args.weight_decay,\n        )\n    return optimizer\n\n\ndef adjust_learning_rate(optimizer, epoch, args):\n    \"\"\"Decay the learning rate based on schedule\"\"\"\n    lr = args.lr\n    if args.schedule == \"cos\":  # cosine lr schedule\n        lr *= 0.5 * (1.0 + np.cos(np.pi * epoch / args.epochs))\n    elif args.schedule == \"none\":  # no lr schedule\n        lr = args.lr\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n"
  },
  {
    "path": "foleycrafter/models/onset/video_onset_net.py",
    "content": "# Copied from specvqgan/onset_baseline/models/video_onset_net.py\n\nimport torch\nimport torch.nn as nn\n\nfrom .r2plus1d_18 import r2plus1d18KeepTemp\n\n\nclass VideoOnsetNet(nn.Module):\n    # Video Onset detection network\n    def __init__(self, pretrained):\n        super(VideoOnsetNet, self).__init__()\n        self.net = r2plus1d18KeepTemp(pretrained=pretrained)\n        self.fc = nn.Sequential(nn.Linear(512, 128), nn.ReLU(True), nn.Linear(128, 1))\n\n    def forward(self, inputs, loss=False, evaluate=False):\n        # import pdb; pdb.set_trace()\n        x = inputs[\"frames\"]\n        x = self.net(x)\n        x = x.transpose(-1, -2)\n        x = self.fc(x)\n        x = x.squeeze(-1)\n\n        return x\n\n\nif __name__ == \"__main__\":\n    model = VideoOnsetNet(False).cuda()\n    rand_input = torch.randn((1, 3, 30, 112, 112)).cuda()\n    inputs = {\"frames\": rand_input}\n    out = model(inputs)\n"
  },
  {
    "path": "foleycrafter/models/time_detector/model.py",
    "content": "import torch.nn as nn\n\nfrom ..onset import VideoOnsetNet\n\n\nclass TimeDetector(nn.Module):\n    def __init__(self, video_length=150, audio_length=1024):\n        super(TimeDetector, self).__init__()\n        self.pred_net = VideoOnsetNet(pretrained=False)\n        self.soft_fn = nn.Tanh()\n        self.up_sampler = nn.Linear(video_length, audio_length)\n\n    def forward(self, inputs):\n        x = self.pred_net(inputs)\n        x = self.up_sampler(x)\n        x = self.soft_fn(x)\n        return x\n"
  },
  {
    "path": "foleycrafter/models/time_detector/resnet.py",
    "content": "import torch.nn as nn\nfrom torch.hub import load_state_dict_from_url\n\n\n__all__ = [\"r3d_18\", \"mc3_18\", \"r2plus1d_18\"]\n\nmodel_urls = {\n    \"r3d_18\": \"https://download.pytorch.org/models/r3d_18-b3b3357e.pth\",\n    \"mc3_18\": \"https://download.pytorch.org/models/mc3_18-a90a0ba3.pth\",\n    \"r2plus1d_18\": \"https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth\",\n}\n\n\nclass Conv3DSimple(nn.Conv3d):\n    def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1):\n        super(Conv3DSimple, self).__init__(\n            in_channels=in_planes,\n            out_channels=out_planes,\n            kernel_size=(3, 3, 3),\n            stride=stride,\n            padding=padding,\n            bias=False,\n        )\n\n    @staticmethod\n    def get_downsample_stride(stride):\n        return stride, stride, stride\n\n\nclass Conv2Plus1D(nn.Sequential):\n    def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1):\n        super(Conv2Plus1D, self).__init__(\n            nn.Conv3d(\n                in_planes,\n                midplanes,\n                kernel_size=(1, 3, 3),\n                stride=(1, stride, stride),\n                padding=(0, padding, padding),\n                bias=False,\n            ),\n            nn.BatchNorm3d(midplanes),\n            nn.ReLU(inplace=True),\n            nn.Conv3d(\n                midplanes,\n                out_planes,\n                kernel_size=(3, 1, 1),\n                stride=(stride, 1, 1),\n                padding=(padding, 0, 0),\n                bias=False,\n            ),\n        )\n\n    @staticmethod\n    def get_downsample_stride(stride):\n        return stride, stride, stride\n\n\nclass Conv3DNoTemporal(nn.Conv3d):\n    def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1):\n        super(Conv3DNoTemporal, self).__init__(\n            in_channels=in_planes,\n            out_channels=out_planes,\n            
kernel_size=(1, 3, 3),\n            stride=(1, stride, stride),\n            padding=(0, padding, padding),\n            bias=False,\n        )\n\n    @staticmethod\n    def get_downsample_stride(stride):\n        return 1, stride, stride\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):\n        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)\n\n        super(BasicBlock, self).__init__()\n        self.conv1 = nn.Sequential(\n            conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)\n        )\n        self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.conv2(out)\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)\n\n        # 1x1x1\n        self.conv1 = nn.Sequential(\n            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)\n        )\n        # Second kernel\n        self.conv2 = nn.Sequential(\n            conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)\n        )\n\n        # 1x1x1\n        self.conv3 = nn.Sequential(\n            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),\n            nn.BatchNorm3d(planes * self.expansion),\n        )\n  
      self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.conv2(out)\n        out = self.conv3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass BasicStem(nn.Sequential):\n    \"\"\"The default conv-batchnorm-relu stem\"\"\"\n\n    def __init__(self):\n        super(BasicStem, self).__init__(\n            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),\n            nn.BatchNorm3d(64),\n            nn.ReLU(inplace=True),\n        )\n\n\nclass R2Plus1dStem(nn.Sequential):\n    \"\"\"R(2+1)D stem is different than the default one as it uses separated 3D convolution\"\"\"\n\n    def __init__(self):\n        super(R2Plus1dStem, self).__init__(\n            nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),\n            nn.BatchNorm3d(45),\n            nn.ReLU(inplace=True),\n            nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),\n            nn.BatchNorm3d(64),\n            nn.ReLU(inplace=True),\n        )\n\n\nclass VideoResNet(nn.Module):\n    def __init__(self, block, conv_makers, layers, stem, num_classes=400, zero_init_residual=False):\n        \"\"\"Generic resnet video generator.\n        Args:\n            block (nn.Module): resnet building block\n            conv_makers (list(functions)): generator function for each layer\n            layers (List[int]): number of blocks per layer\n            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.\n            num_classes (int, optional): Dimension of the final FC layer. 
Defaults to 400.\n            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.\n        \"\"\"\n        super(VideoResNet, self).__init__()\n        self.inplanes = 64\n\n        self.stem = stem()\n\n        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)\n        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)\n\n        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        # init weights\n        self._initialize_weights()\n\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)\n\n    def forward(self, x):\n        x = self.stem(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        # Flatten the layer to fc\n        # x = x.flatten(1)\n        # x = self.fc(x)\n        N = x.shape[0]\n        x = x.squeeze()\n        if N == 1:\n            x = x[None]\n\n        return x\n\n    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):\n        downsample = None\n\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            ds_stride = conv_builder.get_downsample_stride(stride)\n            downsample = nn.Sequential(\n                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),\n                nn.BatchNorm3d(planes * block.expansion),\n            )\n        layers = []\n        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))\n\n        self.inplanes = planes * block.expansion\n       
 for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, conv_builder))\n\n        return nn.Sequential(*layers)\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv3d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm3d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                nn.init.constant_(m.bias, 0)\n\n\ndef _video_resnet(arch, pretrained=False, progress=True, **kwargs):\n    model = VideoResNet(**kwargs)\n\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)\n        model.load_state_dict(state_dict)\n    return model\n\n\ndef r3d_18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Construct 18 layer Resnet3D model as in\n    https://arxiv.org/abs/1711.11248\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Kinetics-400\n        progress (bool): If True, displays a progress bar of the download to stderr\n    Returns:\n        nn.Module: R3D-18 network\n    \"\"\"\n\n    return _video_resnet(\n        \"r3d_18\",\n        pretrained,\n        progress,\n        block=BasicBlock,\n        conv_makers=[Conv3DSimple] * 4,\n        layers=[2, 2, 2, 2],\n        stem=BasicStem,\n        **kwargs,\n    )\n\n\ndef mc3_18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructor for 18 layer Mixed Convolution network as in\n    https://arxiv.org/abs/1711.11248\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Kinetics-400\n        progress (bool): If True, displays a progress bar of the download to stderr\n    Returns:\n        nn.Module: 
MC3 Network definition\n    \"\"\"\n    return _video_resnet(\n        \"mc3_18\",\n        pretrained,\n        progress,\n        block=BasicBlock,\n        conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,\n        layers=[2, 2, 2, 2],\n        stem=BasicStem,\n        **kwargs,\n    )\n\n\ndef r2plus1d_18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructor for the 18 layer deep R(2+1)D network as in\n    https://arxiv.org/abs/1711.11248\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on Kinetics-400\n        progress (bool): If True, displays a progress bar of the download to stderr\n    Returns:\n        nn.Module: R(2+1)D-18 network\n    \"\"\"\n    return _video_resnet(\n        \"r2plus1d_18\",\n        pretrained,\n        progress,\n        block=BasicBlock,\n        conv_makers=[Conv2Plus1D] * 4,\n        layers=[2, 2, 2, 2],\n        stem=R2Plus1dStem,\n        **kwargs,\n    )\n"
  },
  {
    "path": "foleycrafter/pipelines/auffusion_pipeline.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport json\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom huggingface_hub import snapshot_download\nfrom packaging import version\nfrom torch.nn import Conv1d, ConvTranspose1d\nfrom torch.nn.utils import remove_weight_norm, weight_norm\nfrom transformers import (\n    AutoTokenizer,\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n    PretrainedConfig,\n)\n\nfrom diffusers import PNDMScheduler\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ImageProjection\nfrom diffusers.models.attention_processor import FusedAttnProcessor2_0\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom 
diffusers.utils import (\n    deprecate,\n    is_accelerate_available,\n    is_accelerate_version,\n    logging,\n)\nfrom diffusers.utils.outputs import BaseOutput\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom foleycrafter.models.auffusion.loaders.ip_adapter import IPAdapterMixin\nfrom foleycrafter.models.auffusion_unet import UNet2DConditionModel\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef json_dump(data_json, json_save_path):\n    with open(json_save_path, \"w\") as f:\n        json.dump(data_json, f, indent=4)\n        f.close()\n\n\ndef json_load(json_path):\n    with open(json_path, \"r\") as f:\n        data = json.load(f)\n        f.close()\n    return data\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    if \"t5\" in model_class.lower():\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    if \"clap\" in model_class.lower():\n        from transformers import ClapTextModelWithProjection\n\n        return ClapTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\nclass ConditionAdapter(nn.Module):\n    def __init__(self, config):\n        super(ConditionAdapter, self).__init__()\n        self.config = config\n        self.proj = nn.Linear(self.config[\"condition_dim\"], self.config[\"cross_attention_dim\"])\n        self.norm = torch.nn.LayerNorm(self.config[\"cross_attention_dim\"])\n        print(f\"INITIATED: ConditionAdapter: {self.config}\")\n\n    def forward(self, x):\n        x = self.proj(x)\n        x = self.norm(x)\n        return x\n\n    @classmethod\n    def from_pretrained(cls, 
pretrained_model_name_or_path):\n        config_path = os.path.join(pretrained_model_name_or_path, \"config.json\")\n        ckpt_path = os.path.join(pretrained_model_name_or_path, \"condition_adapter.pt\")\n        config = json.loads(open(config_path).read())\n        instance = cls(config)\n        instance.load_state_dict(torch.load(ckpt_path))\n        print(f\"LOADED: ConditionAdapter from {pretrained_model_name_or_path}\")\n        return instance\n\n    def save_pretrained(self, pretrained_model_name_or_path):\n        os.makedirs(pretrained_model_name_or_path, exist_ok=True)\n        config_path = os.path.join(pretrained_model_name_or_path, \"config.json\")\n        ckpt_path = os.path.join(pretrained_model_name_or_path, \"condition_adapter.pt\")\n        json_dump(self.config, config_path)\n        torch.save(self.state_dict(), ckpt_path)\n        print(f\"SAVED: ConditionAdapter {self.config['model_name']} to {pretrained_model_name_or_path}\")\n\n\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\nLRELU_SLOPE = 0.1\nMAX_WAV_VALUE = 32768.0\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n\n\ndef get_config(config_path):\n    config = json.loads(open(config_path).read())\n    config = AttrDict(config)\n    return config\n\n\ndef init_weights(m, mean=0.0, std=0.01):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(mean, std)\n\n\ndef apply_weight_norm(m):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        weight_norm(m)\n\n\ndef get_padding(kernel_size, dilation=1):\n    return int((kernel_size * dilation - dilation) / 2)\n\n\nclass ResBlock1(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):\n        super(ResBlock1, self).__init__()\n        self.h = h\n        self.convs1 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        
channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[2],\n                        padding=get_padding(kernel_size, dilation[2]),\n                    )\n                ),\n            ]\n        )\n        self.convs1.apply(init_weights)\n\n        self.convs2 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))\n                ),\n                weight_norm(\n                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))\n                ),\n                weight_norm(\n                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))\n                ),\n            ]\n        )\n        self.convs2.apply(init_weights)\n\n    def forward(self, x):\n        for c1, c2 in zip(self.convs1, self.convs2):\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c1(xt)\n            xt = F.leaky_relu(xt, LRELU_SLOPE)\n            xt = c2(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)\n\n\nclass ResBlock2(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):\n        super(ResBlock2, self).__init__()\n        self.h = h\n        self.convs = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n    
                    channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n            ]\n        )\n        self.convs.apply(init_weights)\n\n    def forward(self, x):\n        for c in self.convs:\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs:\n            remove_weight_norm(l)\n\n\nclass Generator(torch.nn.Module):\n    def __init__(self, h):\n        super(Generator, self).__init__()\n        self.h = h\n        self.num_kernels = len(h.resblock_kernel_sizes)\n        self.num_upsamples = len(h.upsample_rates)\n        # self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))\n        self.conv_pre = weight_norm(\n            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)\n        )  # change: 80 --> 512\n        resblock = ResBlock1 if h.resblock == \"1\" else ResBlock2\n\n        self._device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        self.ups = nn.ModuleList()\n        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):\n            if (k - u) % 2 == 0:\n                self.ups.append(\n                    weight_norm(\n                        ConvTranspose1d(\n                            h.upsample_initial_channel // (2**i),\n                            
h.upsample_initial_channel // (2 ** (i + 1)),\n                            k,\n                            u,\n                            padding=(k - u) // 2,\n                        )\n                    )\n                )\n            else:\n                self.ups.append(\n                    weight_norm(\n                        ConvTranspose1d(\n                            h.upsample_initial_channel // (2**i),\n                            h.upsample_initial_channel // (2 ** (i + 1)),\n                            k,\n                            u,\n                            padding=(k - u) // 2 + 1,\n                            output_padding=1,\n                        )\n                    )\n                )\n\n            # self.ups.append(weight_norm(\n            #     ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),\n            #                     k, u, padding=(k-u)//2)))\n\n        self.resblocks = nn.ModuleList()\n        for i in range(len(self.ups)):\n            ch = h.upsample_initial_channel // (2 ** (i + 1))\n            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):\n                self.resblocks.append(resblock(h, ch, k, d))\n\n        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))\n        self.ups.apply(init_weights)\n        self.conv_post.apply(init_weights)\n\n    @property\n    def device(self) -> torch.device:\n        return torch.device(self._device)\n\n    @property\n    def dtype(self):\n        return self.type\n\n    def forward(self, x):\n        x = self.conv_pre(x)\n        for i in range(self.num_upsamples):\n            x = F.leaky_relu(x, LRELU_SLOPE)\n            x = self.ups[i](x)\n            xs = None\n            for j in range(self.num_kernels):\n                if xs is None:\n                    xs = self.resblocks[i * self.num_kernels + j](x)\n                else:\n                    xs += self.resblocks[i 
* self.num_kernels + j](x)\n            x = xs / self.num_kernels\n        x = F.leaky_relu(x)\n        x = self.conv_post(x)\n        x = torch.tanh(x)\n\n        return x\n\n    def remove_weight_norm(self):\n        print(\"Removing weight norm...\")\n        for l in self.ups:\n            remove_weight_norm(l)\n        for l in self.resblocks:\n            l.remove_weight_norm()\n        remove_weight_norm(self.conv_pre)\n        remove_weight_norm(self.conv_post)\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):\n        if subfolder is not None:\n            pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)\n        config_path = os.path.join(pretrained_model_name_or_path, \"config.json\")\n        ckpt_path = os.path.join(pretrained_model_name_or_path, \"vocoder.pt\")\n\n        config = get_config(config_path)\n        vocoder = cls(config)\n\n        state_dict_g = torch.load(ckpt_path)\n        vocoder.load_state_dict(state_dict_g[\"generator\"])\n        vocoder.eval()\n        vocoder.remove_weight_norm()\n        return vocoder\n\n    @torch.no_grad()\n    def inference(self, mels, lengths=None):\n        self.eval()\n        with torch.no_grad():\n            wavs = self(mels).squeeze(1)\n\n        wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype(\"int16\")\n\n        if lengths is not None:\n            wavs = wavs[:, :lengths]\n\n        return wavs\n\n\ndef normalize_spectrogram(\n    spectrogram: torch.Tensor,\n    max_value: float = 200,\n    min_value: float = 1e-5,\n    power: float = 1.0,\n) -> torch.Tensor:\n    # Rescale to 0-1\n    max_value = np.log(max_value)  # 5.298317366548036\n    min_value = np.log(min_value)  # -11.512925464970229\n    spectrogram = torch.clamp(spectrogram, min=min_value, max=max_value)\n    data = (spectrogram - min_value) / (max_value - min_value)\n    # Apply the power curve\n    data = torch.pow(data, power)\n    # 1D -> 3D\n 
   data = data.repeat(3, 1, 1)\n    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner\n    data = torch.flip(data, [1])\n\n    return data\n\n\ndef denormalize_spectrogram(\n    data: torch.Tensor,\n    max_value: float = 200,\n    min_value: float = 1e-5,\n    power: float = 1,\n) -> torch.Tensor:\n    assert len(data.shape) == 3, \"Expected 3 dimensions, got {}\".format(len(data.shape))\n\n    max_value = np.log(max_value)\n    min_value = np.log(min_value)\n    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner\n    data = torch.flip(data, [1])\n    if data.shape[0] == 1:\n        data = data.repeat(3, 1, 1)\n    assert data.shape[0] == 3, \"Expected 3 channels, got {}\".format(data.shape[0])\n    data = data[0]\n    # Reverse the power curve\n    data = torch.pow(data, 1 / power)\n    # Rescale to max value\n    spectrogram = data * (max_value - min_value) + min_value\n\n    return spectrogram\n\n\n@staticmethod\ndef pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:\n    \"\"\"\n    Convert a PyTorch tensor to a NumPy image.\n    \"\"\"\n    images = images.cpu().permute(0, 2, 3, 1).float().numpy()\n    return images\n\n\n@staticmethod\ndef numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:\n    \"\"\"\n    Convert a numpy image or a batch of images to a PIL image.\n    \"\"\"\n    if images.ndim == 3:\n        images = images[None, ...]\n    images = (images * 255).round().astype(\"uint8\")\n    if images.shape[-1] == 1:\n        # special case for grayscale (single channel) images\n        pil_images = [PIL.Image.fromarray(image.squeeze(), mode=\"L\") for image in images]\n    else:\n        pil_images = [PIL.Image.fromarray(image) for image in images]\n\n    return pil_images\n\n\ndef image_add_color(spec_img):\n    cmap = plt.get_cmap(\"viridis\")\n    # cmap_r = cmap.reversed()\n    image = cmap(np.array(spec_img)[:, :, 0])[:, :, :3]  # drop the alpha channel\n    image = 
(image - image.min()) / (image.max() - image.min())\n    image = PIL.Image.fromarray(np.uint8(image * 255))\n    return image\n\n\n@dataclass\nclass PipelineOutput(BaseOutput):\n    \"\"\"\n    Output class for audio pipelines.\n\n    Args:\n        audios (`np.ndarray`)\n            List of denoised audio samples of a NumPy array of shape `(batch_size, num_channels, sample_rate)`.\n    \"\"\"\n\n    images: Union[List[PIL.Image.Image], np.ndarray]\n    spectrograms: Union[List[np.ndarray], np.ndarray]\n    audios: Union[List[np.ndarray], np.ndarray]\n\n\nclass AuffusionPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    In addition the pipeline inherits the following loading methods:\n        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]\n        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]\n        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]\n\n    as well as the following saving methods:\n        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\n        \"safety_checker\",\n        \"feature_extractor\",\n        \"text_encoder_list\",\n        \"tokenizer_list\",\n        \"adapter_list\",\n        \"vocoder\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        text_encoder_list: Optional[List[Callable]] = None,\n        tokenizer_list: Optional[List[Callable]] = None,\n        vocoder: Generator = None,\n        requires_safety_checker: bool = False,\n        adapter_list: Optional[List[Callable]] = None,\n   
     tokenizer_model_max_length: Optional[\n            int\n        ] = 77,  # 77 is the default value for the CLIPTokenizer(and set for other models)\n    ):\n        super().__init__()\n\n        self.text_encoder_list = text_encoder_list\n        self.tokenizer_list = tokenizer_list\n        self.vocoder = vocoder\n        self.adapter_list = adapter_list\n        self.tokenizer_model_max_length = tokenizer_model_max_length\n\n        self.register_modules(\n            vae=vae,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        pretrained_model_name_or_path: str = \"auffusion/auffusion-full-no-adapter\",\n        dtype: torch.dtype = torch.float16,\n        device: str = \"cuda\",\n    ):\n        if not os.path.isdir(pretrained_model_name_or_path):\n            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)\n\n        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n        unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"unet\")\n        feature_extractor = CLIPImageProcessor.from_pretrained(\n            pretrained_model_name_or_path, subfolder=\"feature_extractor\"\n        )\n        scheduler = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n        vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder=\"vocoder\").to(device, dtype)\n\n        text_encoder_list, tokenizer_list, adapter_list = [], [], []\n\n        condition_json_path = 
os.path.join(pretrained_model_name_or_path, \"condition_config.json\")\n        condition_json_list = json.loads(open(condition_json_path).read())\n\n        for i, condition_item in enumerate(condition_json_list):\n            # Load Condition Adapter\n            text_encoder_path = os.path.join(pretrained_model_name_or_path, condition_item[\"text_encoder_name\"])\n            tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)\n            tokenizer_list.append(tokenizer)\n            text_encoder_cls = import_model_class_from_model_name_or_path(text_encoder_path)\n            text_encoder = text_encoder_cls.from_pretrained(text_encoder_path).to(device, dtype)\n            text_encoder_list.append(text_encoder)\n            print(f\"LOADING CONDITION ENCODER {i}\")\n\n            # Load Condition Adapter\n            adapter_path = os.path.join(pretrained_model_name_or_path, condition_item[\"condition_adapter_name\"])\n            adapter = ConditionAdapter.from_pretrained(adapter_path).to(device, dtype)\n            adapter_list.append(adapter)\n            print(f\"LOADING CONDITION ADAPTER {i}\")\n\n        pipeline = cls(\n            vae=vae,\n            unet=unet,\n            text_encoder_list=text_encoder_list,\n            tokenizer_list=tokenizer_list,\n            vocoder=vocoder,\n            adapter_list=adapter_list,\n            scheduler=scheduler,\n            safety_checker=None,\n            feature_extractor=feature_extractor,\n        )\n        pipeline = pipeline.to(device, dtype)\n\n        return pipeline\n\n    def to(self, device, dtype=None):\n        super().to(device, dtype)\n\n        self.vocoder.to(device, dtype)\n\n        for text_encoder in self.text_encoder_list:\n            text_encoder.to(device, dtype)\n\n        if self.adapter_list is not None:\n            for adapter in self.adapter_list:\n                adapter.to(device, dtype)\n\n        return self\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n 
       Enable sliced VAE decoding.\n\n        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several\n        steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in\n        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.\n        \"\"\"\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_tiling()\n\n    def enable_sequential_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,\n        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a\n        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.\n        Note that offloading happens on a submodule basis. 
Memory savings are higher than with\n        `enable_model_cpu_offload`, but performance is lower.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.14.0\"):\n            from accelerate import cpu_offload\n        else:\n            raise ImportError(\"`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\", silence_dtype_warnings=True)\n            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)\n\n        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:\n            cpu_offload(cpu_offloaded_model, device)\n\n        if self.safety_checker is not None:\n            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)\n\n    def enable_model_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared\n        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`\n        method is called, and the model remains in GPU until the next model runs. 
Memory savings are lower than with\n        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate import cpu_offload_with_hook\n        else:\n            raise ImportError(\"`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\", silence_dtype_warnings=True)\n            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)\n\n        hook = None\n        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:\n            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)\n\n        if self.safety_checker is not None:\n            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)\n\n        # We'll offload the last model manually.\n        self.final_offload_hook = hook\n\n    @property\n    def _execution_device(self):\n        r\"\"\"\n        Returns the device on which the pipeline's models will be executed. 
After calling\n        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module\n        hooks.\n        \"\"\"\n        if not hasattr(self.unet, \"_hf_hook\"):\n            return self.device\n        for module in self.unet.modules():\n            if (\n                hasattr(module, \"_hf_hook\")\n                and hasattr(module._hf_hook, \"execution_device\")\n                and module._hf_hook.execution_device is not None\n            ):\n                return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        assert len(self.text_encoder_list) == len(\n            self.tokenizer_list\n        ), \"Number of text_encoders must match number of tokenizers\"\n        if self.adapter_list is not None:\n            assert len(self.text_encoder_list) == len(\n                self.adapter_list\n            ), \"Number of text_encoders must match number of adapters\"\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        def get_prompt_embeds(prompt_list, device):\n            if isinstance(prompt_list, str):\n                prompt_list = [prompt_list]\n\n            prompt_embeds_list = []\n            for prompt in prompt_list:\n                encoder_hidden_states_list = []\n\n                # Generate condition embedding\n                for j in range(len(self.text_encoder_list)):\n                    # get condition embedding using condition encoder\n   
                 input_ids = self.tokenizer_list[j](prompt, return_tensors=\"pt\").input_ids.to(device)\n                    cond_embs = self.text_encoder_list[j](input_ids).last_hidden_state  # [bz, text_len, text_dim]\n                    # padding to max_length\n                    if cond_embs.shape[1] < self.tokenizer_model_max_length:\n                        cond_embs = torch.functional.F.pad(\n                            cond_embs, (0, 0, 0, self.tokenizer_model_max_length - cond_embs.shape[1]), value=0\n                        )\n                    else:\n                        cond_embs = cond_embs[:, : self.tokenizer_model_max_length, :]\n\n                    # use condition adapter\n                    if self.adapter_list is not None:\n                        cond_embs = self.adapter_list[j](cond_embs)\n                        encoder_hidden_states_list.append(cond_embs)\n\n                prompt_embeds = torch.cat(encoder_hidden_states_list, dim=1)\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.cat(prompt_embeds_list, dim=0)\n            return prompt_embeds\n\n        if prompt_embeds is None:\n            prompt_embeds = get_prompt_embeds(prompt, device)\n\n        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            if negative_prompt is None:\n                negative_prompt_embeds = torch.zeros_like(prompt_embeds).to(dtype=prompt_embeds.dtype, device=device)\n\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n    
                f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                negative_prompt = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                negative_prompt_embeds = get_prompt_embeds(negative_prompt, device)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = 
self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        warnings.warn(\n            \"The decode_latents method is deprecated and will be removed in a future version. Please\"\n            \" use VaeImageProcessor instead\",\n            FutureWarning,\n        )\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        
negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = 256,\n        width: Optional[int] = 1024,\n        num_inference_steps: int = 100,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pt\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        duration: Optional[float] = 10,\n    ):\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        audio_length = int(duration * 16000)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        # Generate audio\n        spectrograms, audios = [], []\n        for img in image:\n            spectrogram = denormalize_spectrogram(img)\n            audio = self.vocoder.inference(spectrogram, lengths=audio_length)[0]\n            audios.append(audio)\n            spectrograms.append(spectrogram)\n\n        # Convert to PIL\n        images = pt_to_numpy(image)\n        images = numpy_to_pil(images)\n        images = [image_add_color(image) for image in images]\n\n        if not return_dict:\n           
 return (images, audios, spectrograms)\n\n        return PipelineOutput(images=images, audios=audios, spectrograms=spectrograms)\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass AuffusionNoAdapterPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in 
combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if hasattr(scheduler.config, \"steps_offset\") and scheduler.config.steps_offset != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might led to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if hasattr(scheduler.config, \"clip_sample\") and scheduler.config.clip_sample is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                \"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = hasattr(unet.config, \"_diffusers_version\") and version.parse(\n            version.parse(unet.config._diffusers_version).base_version\n        ) < version.parse(\"0.9.0.dev0\")\n        is_unet_sample_size_less_64 = hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- runwayml/stable-diffusion-v1-5\"\n                \" \\n- runwayml/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. 
If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. 
This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_tiling()\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: 
Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, LoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n 
               attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and 
negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            
negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds, negative_prompt_embeds\n\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                
else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        
prompt_embeds=None,\n        negative_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):\n        r\"\"\"Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.\n\n        The suffixes after the scaling factors represent the stages where they are being applied.\n\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values\n        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n\n        Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. 
This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        if not hasattr(self, \"unet\"):\n            raise ValueError(\"The pipeline must have `unet` for using FreeU.\")\n        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)\n\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism if enabled.\"\"\"\n        self.unet.disable_freeu()\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections\n    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):\n        \"\"\"\n        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,\n        key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused.\n\n        <Tip warning={true}>\n\n        This API is 🧪 experimental.\n\n        </Tip>\n\n        Args:\n            unet (`bool`, defaults to `True`): To apply fusion on the UNet.\n            vae (`bool`, defaults to `True`): To apply fusion on the VAE.\n        \"\"\"\n        self.fusing_unet = False\n        self.fusing_vae = False\n\n        if unet:\n            self.fusing_unet = True\n            self.unet.fuse_qkv_projections()\n            self.unet.set_attn_processor(FusedAttnProcessor2_0())\n\n        if vae:\n            if not isinstance(self.vae, AutoencoderKL):\n                raise ValueError(\"`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.\")\n\n            self.fusing_vae = True\n            self.vae.fuse_qkv_projections()\n            self.vae.set_attn_processor(FusedAttnProcessor2_0())\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections\n    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):\n        \"\"\"Disable QKV projection fusion if enabled.\n\n        <Tip warning={true}>\n\n        This API is 🧪 experimental.\n\n        </Tip>\n\n        Args:\n            unet (`bool`, defaults to `True`): To apply fusion on the UNet.\n            vae (`bool`, defaults to `True`): To apply fusion on the VAE.\n\n        \"\"\"\n        if unet:\n            if not self.fusing_unet:\n                logger.warning(\"The UNet was not initially fused for QKV projections. Doing nothing.\")\n            else:\n                self.unet.unfuse_qkv_projections()\n                self.fusing_unet = False\n\n        if vae:\n            if not self.fusing_vae:\n                logger.warning(\"The VAE was not initially fused for QKV projections. 
Doing nothing.\")\n            else:\n                self.vae.unfuse_qkv_projections()\n                self.fusing_vae = False\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that calls at the end of each denoising steps during the inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        # to deal with lora scaling and other possible forward hooks\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                tmp_embeds = negative_prompt_embeds.clone()\n                tmp_embeds[:, 0:1, :] = prompt_embeds\n                prompt_embeds = tmp_embeds\n            prompt_embeds = 
torch.cat([negative_prompt_embeds, prompt_embeds])\n        # TODO\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if ip_adapter_image is not None else None\n\n        # 6.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, 
output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "foleycrafter/pipelines/pipeline_controlnet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor\nfrom foleycrafter.models.auffusion.loaders.ip_adapter import IPAdapterMixin\nfrom 
foleycrafter.models.auffusion_unet import UNet2DConditionModel\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate\n        >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import torch\n\n        >>> import cv2\n        >>> from PIL import Image\n\n        >>> # download an image\n        >>> image = load_image(\n        ...     \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n        ... )\n        >>> image = np.array(image)\n\n        >>> # get canny image\n        >>> image = cv2.Canny(image, 100, 200)\n        >>> image = image[:, :, None]\n        >>> image = np.concatenate([image, image, image], axis=2)\n        >>> canny_image = Image.fromarray(image)\n\n        >>> # load control net and stable diffusion v1-5\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(\n        ...     \"runwayml/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n\n        >>> # speed up diffusion process with faster scheduler and memory optimization\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> # remove following line if xformers is not installed\n        >>> pipe.enable_xformers_memory_efficient_attention()\n\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # generate image\n        >>> generator = torch.manual_seed(0)\n        >>> image = pipe(\n        ...     \"futuristic-looking woman\", num_inference_steps=20, generator=generator, image=canny_image\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionControlNetPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or 
`List[ControlNetModel]`):\n            Provides additional conditioning to the `unet` during the denoising process. If you set multiple\n            ControlNets as a list, the outputs from each ControlNet are added together to create one combined\n            additional conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and 
requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                \"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        self.vae.enable_tiling()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_tiling()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, LoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only 
handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return prompt_embeds, 
negative_prompt_embeds\n\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, 
single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, 
uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) 
instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n              
  f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif len(image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n      
      assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n     
   image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    def prepare_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu\n    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):\n        r\"\"\"Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.\n\n        The suffixes after the scaling factors represent the stages where they are being applied.\n\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values\n        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n\n        Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. 
This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        if not hasattr(self, \"unet\"):\n            raise ValueError(\"The pipeline must have `unet` for using FreeU.\")\n        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism if enabled.\"\"\"\n        self.unet.disable_freeu()\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n      
      emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, 
List[float]] = 1.0,\n        controlnet_prompt_embeds: Optional[List[torch.FloatTensor]] = None,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. 
If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale < 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that calls every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. 
A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that calls at the end of each denoising steps during the inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n       
     mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. 
Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        self._num_timesteps = len(timesteps)\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if image_embeds is not None else None\n\n        # 7.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = 
controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in 
has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "foleycrafter/utils/converter.py",
    "content": "# Copy from https://github.com/happylittlecat2333/Auffusion/blob/main/converter.py\nimport json\nimport os\nimport random\n\nimport librosa\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.data\nfrom librosa.filters import mel as librosa_mel_fn\n\n# from librosa.util import normalize\nfrom scipy.io.wavfile import read\nfrom torch.nn import Conv1d, ConvTranspose1d\nfrom torch.nn.utils import remove_weight_norm, weight_norm\n\n\nMAX_WAV_VALUE = 32768.0\n\n\ndef load_wav(full_path):\n    sampling_rate, data = read(full_path)\n    return data, sampling_rate\n\n\ndef dynamic_range_compression(x, C=1, clip_val=1e-5):\n    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)\n\n\ndef dynamic_range_decompression(x, C=1):\n    return np.exp(x) / C\n\n\ndef dynamic_range_compression_torch(x, C=1, clip_val=1e-5):\n    return torch.log(torch.clamp(x, min=clip_val) * C)\n\n\ndef dynamic_range_decompression_torch(x, C=1):\n    return torch.exp(x) / C\n\n\ndef spectral_normalize_torch(magnitudes):\n    output = dynamic_range_compression_torch(magnitudes)\n    return output\n\n\ndef spectral_de_normalize_torch(magnitudes):\n    output = dynamic_range_decompression_torch(magnitudes)\n    return output\n\n\nmel_basis = {}\nhann_window = {}\n\n\ndef mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n    if torch.min(y) < -1.0:\n        print(\"min value is \", torch.min(y))\n    if torch.max(y) > 1.0:\n        print(\"max value is \", torch.max(y))\n\n    global mel_basis, hann_window\n    if fmax not in mel_basis:\n        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n        mel_basis[str(fmax) + \"_\" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)\n        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)\n\n    y = torch.nn.functional.pad(\n        y.unsqueeze(1), 
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode=\"reflect\"\n    )\n    y = y.squeeze(1)\n\n    # complex tensor as default, then use view_as_real for future pytorch compatibility\n    spec = torch.stft(\n        y,\n        n_fft,\n        hop_length=hop_size,\n        win_length=win_size,\n        window=hann_window[str(y.device)],\n        center=center,\n        pad_mode=\"reflect\",\n        normalized=False,\n        onesided=True,\n        return_complex=True,\n    )\n    spec = torch.view_as_real(spec)\n    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))\n\n    spec = torch.matmul(mel_basis[str(fmax) + \"_\" + str(y.device)], spec)\n    spec = spectral_normalize_torch(spec)\n\n    return spec\n\n\ndef spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n    if torch.min(y) < -1.0:\n        print(\"min value is \", torch.min(y))\n    if torch.max(y) > 1.0:\n        print(\"max value is \", torch.max(y))\n\n    global hann_window\n    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)\n\n    y = torch.nn.functional.pad(\n        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode=\"reflect\"\n    )\n    y = y.squeeze(1)\n\n    # complex tensor as default, then use view_as_real for future pytorch compatibility\n    spec = torch.stft(\n        y,\n        n_fft,\n        hop_length=hop_size,\n        win_length=win_size,\n        window=hann_window[str(y.device)],\n        center=center,\n        pad_mode=\"reflect\",\n        normalized=False,\n        onesided=True,\n        return_complex=True,\n    )\n    spec = torch.view_as_real(spec)\n    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))\n\n    return spec\n\n\ndef normalize_spectrogram(\n    spectrogram: torch.Tensor,\n    max_value: float = 200,\n    min_value: float = 1e-5,\n    power: float = 1.0,\n    inverse: bool = False,\n) -> torch.Tensor:\n    # Rescale to 0-1\n    max_value = 
np.log(max_value)  # 5.298317366548036\n    min_value = np.log(min_value)  # -11.512925464970229\n\n    assert spectrogram.max() <= max_value and spectrogram.min() >= min_value\n\n    data = (spectrogram - min_value) / (max_value - min_value)\n\n    # Invert\n    if inverse:\n        data = 1 - data\n\n    # Apply the power curve\n    data = torch.pow(data, power)\n\n    # 1D -> 3D\n    data = data.repeat(3, 1, 1)\n\n    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner\n    data = torch.flip(data, [1])\n\n    return data\n\n\ndef denormalize_spectrogram(\n    data: torch.Tensor,\n    max_value: float = 200,\n    min_value: float = 1e-5,\n    power: float = 1,\n    inverse: bool = False,\n) -> torch.Tensor:\n    max_value = np.log(max_value)\n    min_value = np.log(min_value)\n\n    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner\n    data = torch.flip(data, [1])\n\n    assert len(data.shape) == 3, \"Expected 3 dimensions, got {}\".format(len(data.shape))\n\n    if data.shape[0] == 1:\n        data = data.repeat(3, 1, 1)\n\n    assert data.shape[0] == 3, \"Expected 3 channels, got {}\".format(data.shape[0])\n    data = data[0]\n\n    # Reverse the power curve\n    data = torch.pow(data, 1 / power)\n\n    # Invert\n    if inverse:\n        data = 1 - data\n\n    # Rescale to max value\n    spectrogram = data * (max_value - min_value) + min_value\n\n    return spectrogram\n\n\ndef get_mel_spectrogram_from_audio(audio, device=\"cpu\"):\n    audio = audio / MAX_WAV_VALUE\n    audio = librosa.util.normalize(audio) * 0.95\n    # print(' >>> normalize done <<< ')\n\n    audio = torch.FloatTensor(audio)\n    audio = audio.unsqueeze(0)\n\n    waveform = audio.to(device)\n    spec = mel_spectrogram(\n        waveform,\n        n_fft=2048,\n        num_mels=256,\n        sampling_rate=16000,\n        hop_size=160,\n        win_size=1024,\n        fmin=0,\n        fmax=8000,\n       
 center=False,\n    )\n    return audio, spec\n\n\nLRELU_SLOPE = 0.1\nMAX_WAV_VALUE = 32768.0\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n\n\ndef get_config(config_path):\n    config = json.loads(open(config_path).read())\n    config = AttrDict(config)\n    return config\n\n\ndef init_weights(m, mean=0.0, std=0.01):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(mean, std)\n\n\ndef apply_weight_norm(m):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        weight_norm(m)\n\n\ndef get_padding(kernel_size, dilation=1):\n    return int((kernel_size * dilation - dilation) / 2)\n\n\nclass ResBlock1(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):\n        super(ResBlock1, self).__init__()\n        self.h = h\n        self.convs1 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[2],\n                        
padding=get_padding(kernel_size, dilation[2]),\n                    )\n                ),\n            ]\n        )\n        self.convs1.apply(init_weights)\n\n        self.convs2 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))\n                ),\n                weight_norm(\n                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))\n                ),\n                weight_norm(\n                    Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))\n                ),\n            ]\n        )\n        self.convs2.apply(init_weights)\n\n    def forward(self, x):\n        for c1, c2 in zip(self.convs1, self.convs2):\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c1(xt)\n            xt = F.leaky_relu(xt, LRELU_SLOPE)\n            xt = c2(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)\n\n\nclass ResBlock2(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):\n        super(ResBlock2, self).__init__()\n        self.h = h\n        self.convs = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                
        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n            ]\n        )\n        self.convs.apply(init_weights)\n\n    def forward(self, x):\n        for c in self.convs:\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs:\n            remove_weight_norm(l)\n\n\nclass Generator(torch.nn.Module):\n    def __init__(self, h):\n        super(Generator, self).__init__()\n        self.h = h\n        self.num_kernels = len(h.resblock_kernel_sizes)\n        self.num_upsamples = len(h.upsample_rates)\n        self.conv_pre = weight_norm(\n            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)\n        )  # change: 80 --> 512\n        resblock = ResBlock1 if h.resblock == \"1\" else ResBlock2\n\n        self.ups = nn.ModuleList()\n        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):\n            if (k - u) % 2 == 0:\n                self.ups.append(\n                    weight_norm(\n                        ConvTranspose1d(\n                            h.upsample_initial_channel // (2**i),\n                            h.upsample_initial_channel // (2 ** (i + 1)),\n                            k,\n                            u,\n                            padding=(k - u) // 2,\n                        )\n                    )\n                )\n            else:\n                self.ups.append(\n                    weight_norm(\n                        ConvTranspose1d(\n                            h.upsample_initial_channel // (2**i),\n                            h.upsample_initial_channel // (2 ** (i + 1)),\n                            k,\n                            u,\n                            padding=(k - u) // 2 + 1,\n                            output_padding=1,\n                        )\n  
                  )\n                )\n\n            # self.ups.append(weight_norm(\n            #     ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),\n            #                     k, u, padding=(k-u)//2)))\n\n        self.resblocks = nn.ModuleList()\n        for i in range(len(self.ups)):\n            ch = h.upsample_initial_channel // (2 ** (i + 1))\n            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):\n                self.resblocks.append(resblock(h, ch, k, d))\n\n        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))\n        self.ups.apply(init_weights)\n        self.conv_post.apply(init_weights)\n\n    def forward(self, x):\n        x = self.conv_pre(x)\n        for i in range(self.num_upsamples):\n            x = F.leaky_relu(x, LRELU_SLOPE)\n            x = self.ups[i](x)\n            xs = None\n            for j in range(self.num_kernels):\n                if xs is None:\n                    xs = self.resblocks[i * self.num_kernels + j](x)\n                else:\n                    xs += self.resblocks[i * self.num_kernels + j](x)\n            x = xs / self.num_kernels\n        x = F.leaky_relu(x)\n        x = self.conv_post(x)\n        x = torch.tanh(x)\n\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.ups:\n            remove_weight_norm(l)\n        for l in self.resblocks:\n            l.remove_weight_norm()\n        remove_weight_norm(self.conv_pre)\n        remove_weight_norm(self.conv_post)\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):\n        if subfolder is not None:\n            pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)\n        config_path = os.path.join(pretrained_model_name_or_path, \"config.json\")\n        ckpt_path = os.path.join(pretrained_model_name_or_path, \"vocoder.pt\")\n\n        config = 
get_config(config_path)\n        vocoder = cls(config)\n\n        state_dict_g = torch.load(ckpt_path)\n        vocoder.load_state_dict(state_dict_g[\"generator\"])\n        vocoder.eval()\n        vocoder.remove_weight_norm()\n        return vocoder\n\n    @torch.no_grad()\n    def inference(self, mels, lengths=None):\n        self.eval()\n        with torch.no_grad():\n            wavs = self(mels).squeeze(1)\n\n        wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype(\"int16\")\n\n        if lengths is not None:\n            wavs = wavs[:, :lengths]\n\n        return wavs\n\n\ndef normalize(images):\n    \"\"\"\n    Normalize an image array to [-1,1].\n    \"\"\"\n    if images.min() >= 0:\n        return 2.0 * images - 1.0\n    else:\n        return images\n\n\ndef pad_spec(spec, spec_length, pad_value=0, random_crop=True):  # spec: [3, mel_dim, spec_len]\n    assert spec_length % 8 == 0, \"spec_length must be divisible by 8\"\n    if spec.shape[-1] < spec_length:\n        # pad spec to spec_length\n        spec = F.pad(spec, (0, spec_length - spec.shape[-1]), value=pad_value)\n    else:\n        # random crop\n        if random_crop:\n            start = random.randint(0, spec.shape[-1] - spec_length)\n            spec = spec[:, :, start : start + spec_length]\n        else:\n            spec = spec[:, :, :spec_length]\n    return spec\n"
  },
  {
    "path": "foleycrafter/utils/spec_to_mel.py",
    "content": "import librosa.util as librosa_util\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchaudio\nfrom librosa.filters import mel as librosa_mel_fn\nfrom librosa.util import pad_center, tiny\nfrom scipy.signal import get_window\n\n\n# spectrogram to mel\n\n\nclass STFT(torch.nn.Module):\n    \"\"\"adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft\"\"\"\n\n    def __init__(self, filter_length, hop_length, win_length, window=\"hann\"):\n        super(STFT, self).__init__()\n        self.filter_length = filter_length\n        self.hop_length = hop_length\n        self.win_length = win_length\n        self.window = window\n        self.forward_transform = None\n        scale = self.filter_length / self.hop_length\n        fourier_basis = np.fft.fft(np.eye(self.filter_length))\n\n        cutoff = int((self.filter_length / 2 + 1))\n        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])\n\n        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])\n        inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :])\n\n        if window is not None:\n            assert filter_length >= win_length\n            # get window and zero center pad it to filter_length\n            fft_window = get_window(window, win_length, fftbins=True)\n            fft_window = pad_center(fft_window, filter_length)\n            fft_window = torch.from_numpy(fft_window).float()\n\n            # window the bases\n            forward_basis *= fft_window\n            inverse_basis *= fft_window\n\n        self.register_buffer(\"forward_basis\", forward_basis.float())\n        self.register_buffer(\"inverse_basis\", inverse_basis.float())\n\n    def transform(self, input_data):\n        num_batches = input_data.size(0)\n        num_samples = input_data.size(1)\n\n        self.num_samples = num_samples\n\n        # similar to librosa, reflect-pad the 
input\n        input_data = input_data.view(num_batches, 1, num_samples)\n        input_data = F.pad(\n            input_data.unsqueeze(1),\n            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),\n            mode=\"reflect\",\n        )\n        input_data = input_data.squeeze(1)\n\n        forward_transform = F.conv1d(\n            input_data,\n            torch.autograd.Variable(self.forward_basis, requires_grad=False),\n            stride=self.hop_length,\n            padding=0,\n        ).cpu()\n\n        cutoff = int((self.filter_length / 2) + 1)\n        real_part = forward_transform[:, :cutoff, :]\n        imag_part = forward_transform[:, cutoff:, :]\n\n        magnitude = torch.sqrt(real_part**2 + imag_part**2)\n        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))\n\n        return magnitude, phase\n\n    def inverse(self, magnitude, phase):\n        recombine_magnitude_phase = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)\n\n        inverse_transform = F.conv_transpose1d(\n            recombine_magnitude_phase,\n            torch.autograd.Variable(self.inverse_basis, requires_grad=False),\n            stride=self.hop_length,\n            padding=0,\n        )\n\n        if self.window is not None:\n            window_sum = window_sumsquare(\n                self.window,\n                magnitude.size(-1),\n                hop_length=self.hop_length,\n                win_length=self.win_length,\n                n_fft=self.filter_length,\n                dtype=np.float32,\n            )\n            # remove modulation effects\n            approx_nonzero_indices = torch.from_numpy(np.where(window_sum > tiny(window_sum))[0])\n            window_sum = torch.autograd.Variable(torch.from_numpy(window_sum), requires_grad=False)\n            window_sum = window_sum\n            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]\n\n           
 # scale by hop ratio\n            inverse_transform *= float(self.filter_length) / self.hop_length\n\n        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]\n        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]\n\n        return inverse_transform\n\n    def forward(self, input_data):\n        self.magnitude, self.phase = self.transform(input_data)\n        reconstruction = self.inverse(self.magnitude, self.phase)\n        return reconstruction\n\n\ndef window_sumsquare(\n    window,\n    n_frames,\n    hop_length,\n    win_length,\n    n_fft,\n    dtype=np.float32,\n    norm=None,\n):\n    \"\"\"\n    # from librosa 0.6\n    Compute the sum-square envelope of a window function at a given hop length.\n\n    This is used to estimate modulation effects induced by windowing\n    observations in short-time fourier transforms.\n\n    Parameters\n    ----------\n    window : string, tuple, number, callable, or list-like\n        Window specification, as in `get_window`\n\n    n_frames : int > 0\n        The number of analysis frames\n\n    hop_length : int > 0\n        The number of samples to advance between frames\n\n    win_length : [optional]\n        The length of the window function.  
By default, this matches `n_fft`.\n\n    n_fft : int > 0\n        The length of each analysis frame.\n\n    dtype : np.dtype\n        The data type of the output\n\n    Returns\n    -------\n    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`\n        The sum-squared envelope of the window function\n    \"\"\"\n    if win_length is None:\n        win_length = n_fft\n\n    n = n_fft + hop_length * (n_frames - 1)\n    x = np.zeros(n, dtype=dtype)\n\n    # Compute the squared window at the desired length\n    win_sq = get_window(window, win_length, fftbins=True)\n    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2\n    win_sq = librosa_util.pad_center(win_sq, n_fft)\n\n    # Fill the envelope\n    for i in range(n_frames):\n        sample = i * hop_length\n        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]\n    return x\n\n\ndef griffin_lim(magnitudes, stft_fn, n_iters=30):\n    \"\"\"\n    PARAMS\n    ------\n    magnitudes: spectrogram magnitudes\n    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods\n    \"\"\"\n\n    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))\n    angles = angles.astype(np.float32)\n    angles = torch.autograd.Variable(torch.from_numpy(angles))\n    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)\n\n    for i in range(n_iters):\n        _, angles = stft_fn.transform(signal)\n        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)\n    return signal\n\n\ndef dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):\n    \"\"\"\n    PARAMS\n    ------\n    C: compression factor\n    \"\"\"\n    return normalize_fun(torch.clamp(x, min=clip_val) * C)\n\n\ndef dynamic_range_decompression(x, C=1):\n    \"\"\"\n    PARAMS\n    ------\n    C: compression factor used to compress\n    \"\"\"\n    return torch.exp(x) / C\n\n\nclass TacotronSTFT(torch.nn.Module):\n    def __init__(\n        self,\n        
filter_length,\n        hop_length,\n        win_length,\n        n_mel_channels,\n        sampling_rate,\n        mel_fmin,\n        mel_fmax,\n    ):\n        super(TacotronSTFT, self).__init__()\n        self.n_mel_channels = n_mel_channels\n        self.sampling_rate = sampling_rate\n        self.stft_fn = STFT(filter_length, hop_length, win_length)\n        mel_basis = librosa_mel_fn(sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)\n        mel_basis = torch.from_numpy(mel_basis).float()\n        self.register_buffer(\"mel_basis\", mel_basis)\n\n    def spectral_normalize(self, magnitudes, normalize_fun):\n        output = dynamic_range_compression(magnitudes, normalize_fun)\n        return output\n\n    def spectral_de_normalize(self, magnitudes):\n        output = dynamic_range_decompression(magnitudes)\n        return output\n\n    def mel_spectrogram(self, y, normalize_fun=torch.log):\n        \"\"\"Computes mel-spectrograms from a batch of waves\n        PARAMS\n        ------\n        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]\n\n        RETURNS\n        -------\n        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)\n        \"\"\"\n        assert torch.min(y.data) >= -1, torch.min(y.data)\n        assert torch.max(y.data) <= 1, torch.max(y.data)\n\n        magnitudes, phases = self.stft_fn.transform(y)\n        magnitudes = magnitudes.data\n        mel_output = torch.matmul(self.mel_basis, magnitudes)\n        mel_output = self.spectral_normalize(mel_output, normalize_fun)\n        energy = torch.norm(magnitudes, dim=1)\n\n        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)\n\n        return mel_output, log_magnitudes, energy\n\n\ndef pad_wav(waveform, segment_length):\n    waveform_length = waveform.shape[-1]\n    assert waveform_length > 100, \"Waveform is too short, %s\" % waveform_length\n    if segment_length is None or waveform_length == segment_length:\n        return 
waveform\n    elif waveform_length > segment_length:\n        return waveform[:, :segment_length]\n    elif waveform_length < segment_length:\n        temp_wav = np.zeros((1, segment_length))\n        temp_wav[:, :waveform_length] = waveform\n    return temp_wav\n\n\ndef normalize_wav(waveform):\n    waveform = waveform - np.mean(waveform)\n    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)\n    return waveform * 0.5\n\n\ndef _pad_spec(fbank, target_length=1024):\n    n_frames = fbank.shape[0]\n    p = target_length - n_frames\n    # cut and pad\n    if p > 0:\n        m = torch.nn.ZeroPad2d((0, 0, 0, p))\n        fbank = m(fbank)\n    elif p < 0:\n        fbank = fbank[0:target_length, :]\n\n    if fbank.size(-1) % 2 != 0:\n        fbank = fbank[..., :-1]\n\n    return fbank\n\n\ndef get_mel_from_wav(audio, _stft):\n    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)\n    audio = torch.autograd.Variable(audio, requires_grad=False)\n    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)\n    melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)\n    log_magnitudes_stft = torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)\n    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)\n    return melspec, log_magnitudes_stft, energy\n\n\ndef read_wav_file_io(bytes):\n    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower\n    waveform, sr = torchaudio.load(bytes, format=\"mp4\")  # Faster!!!\n    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)\n    # waveform = waveform.numpy()[0, ...]\n    # waveform = normalize_wav(waveform)\n    # waveform = waveform[None, ...]\n\n    # waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)\n    # waveform = 0.5 * waveform\n\n    return waveform\n\n\ndef load_audio(bytes, sample_rate=16000):\n    waveform, sr = torchaudio.load(bytes, format=\"mp4\")\n    waveform = 
torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)\n    return waveform\n\n\ndef read_wav_file(filename):\n    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower\n    waveform, sr = torchaudio.load(filename)  # Faster!!!\n    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)\n    waveform = waveform.numpy()[0, ...]\n    waveform = normalize_wav(waveform)\n    waveform = waveform[None, ...]\n\n    waveform = waveform / np.max(np.abs(waveform))\n    waveform = 0.5 * waveform\n\n    return waveform\n\n\ndef norm_wav_tensor(waveform: torch.FloatTensor):\n    waveform = waveform.numpy()[0, ...]\n    waveform = normalize_wav(waveform)\n    waveform = waveform[None, ...]\n    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)\n    waveform = 0.5 * waveform\n    return waveform\n\n\ndef wav_to_fbank(filename, target_length=1024, fn_STFT=None):\n    if fn_STFT is None:\n        fn_STFT = TacotronSTFT(\n            1024,  # filter_length\n            160,  # hop_length\n            1024,  # win_length\n            64,  # n_mel\n            16000,  # sample_rate\n            0,  # fmin\n            8000,  # fmax\n        )\n\n    # mixup\n    waveform = read_wav_file(filename, target_length * 160)  # hop size is 160\n\n    waveform = waveform[0, ...]\n    waveform = torch.FloatTensor(waveform)\n\n    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)\n\n    fbank = torch.FloatTensor(fbank.T)\n    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)\n\n    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(log_magnitudes_stft, target_length)\n\n    return fbank, log_magnitudes_stft, waveform\n\n\ndef wav_tensor_to_fbank(waveform, target_length=512, fn_STFT=None):\n    if fn_STFT is None:\n        fn_STFT = TacotronSTFT(\n            1024,  # filter_length\n            160,  # hop_length\n            1024,  # win_length\n            256,  
# n_mel\n            16000,  # sample_rate\n            0,  # fmin\n            8000,  # fmax\n        )  # In practice used\n\n    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)\n\n    fbank = torch.FloatTensor(fbank.T)\n    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)\n\n    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(log_magnitudes_stft, target_length)\n\n    return fbank\n"
  },
  {
    "path": "foleycrafter/utils/util.py",
    "content": "import glob\nimport io\nimport os\nimport os.path as osp\nimport random\nimport typing as T\nimport warnings\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom typing import Union\n\nimport decord\nimport imageio\nimport numpy as np\nimport pydub\nimport soundfile as sf\nimport torch\nimport torch.distributed as dist\nimport torchaudio\nimport torchvision\nimport torchvision.transforms as transforms\nfrom einops import rearrange\nfrom moviepy.editor import AudioFileClip, ImageSequenceClip, VideoFileClip\nfrom PIL import Image, ImageOps\nfrom scipy.io import wavfile\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import ControlNetModel\nfrom diffusers.models import AutoencoderKL\nfrom diffusers.schedulers import DDIMScheduler, PNDMScheduler\nfrom foleycrafter.models.auffusion_unet import UNet2DConditionModel as af_UNet2DConditionModel\nfrom foleycrafter.pipelines.pipeline_controlnet import StableDiffusionControlNetPipeline\n\n\ndef zero_rank_print(s):\n    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):\n        print(\"### \" + s, flush=True)\n\n\ndef build_foleycrafter(\n    pretrained_model_name_or_path: str = \"auffusion/auffusion-full-no-adapter\",\n) -> StableDiffusionControlNetPipeline:\n    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = af_UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"unet\")\n    scheduler = PNDMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n\n    controlnet = ControlNetModel.from_unet(unet, conditioning_channels=1)\n\n    pipe = StableDiffusionControlNetPipeline(\n        vae=vae,\n        controlnet=controlnet,\n        
unet=unet,\n        scheduler=scheduler,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        feature_extractor=None,\n        safety_checker=None,\n        requires_safety_checker=False,\n    )\n\n    return pipe\n\n\ndef save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):\n    if len(videos.shape) == 4:\n        videos = videos.unsqueeze(0)\n    videos = rearrange(videos, \"b c t h w -> t b c h w\")\n    outputs = []\n    for x in videos:\n        x = torchvision.utils.make_grid(x, nrow=n_rows)\n        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)\n        if rescale:\n            x = (x + 1.0) / 2.0  # -1,1 -> 0,1\n        x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8)\n        outputs.append(x)\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    imageio.mimsave(path, outputs, fps=fps)\n\n\ndef save_videos_from_pil_list(videos: list, path: str, fps=7):\n    for i in range(len(videos)):\n        videos[i] = ImageOps.scale(videos[i], 255)\n\n    imageio.mimwrite(path, videos, fps=fps)\n\n\ndef seed_everything(seed: int) -> None:\n    r\"\"\"Sets the seed for generating random numbers in :pytorch:`PyTorch`,\n    :obj:`numpy` and :python:`Python`.\n\n    Args:\n        seed (int): The desired seed.\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef get_video_frames(video: np.ndarray, num_frames: int = 200):\n    video_length = video.shape[0]\n    video_idx = np.linspace(0, video_length - 1, num_frames, dtype=int)\n    video = video[video_idx, ...]\n    return video\n\n\ndef random_audio_video_clip(\n    audio: np.ndarray, video: np.ndarray, fps: float, sample_rate: int = 16000, duration: int = 5, num_frames: int = 20\n):\n    \"\"\"\n    Random sample video clips with duration\n    \"\"\"\n    video_length = video.shape[0]\n    audio_length = audio.shape[-1]\n    av_duration = int(video_length / fps)\n   
 assert av_duration >= duration, f\"video duration {av_duration} is less than {duration}\"\n\n    # random sample start time\n    start_time = random.uniform(0, av_duration - duration)\n    end_time = start_time + duration\n\n    start_idx, end_idx = start_time / av_duration, end_time / av_duration\n\n    video_start_frame, video_end_frame = video_length * start_idx, video_length * end_idx\n    audio_start_frame, audio_end_frame = audio_length * start_idx, audio_length * end_idx\n\n    # print(f\"time_idx : {start_time}:{end_time}\")\n    # print(f\"video_idx: {video_start_frame}:{video_end_frame}\")\n    # print(f\"audio_idx: {audio_start_frame}:{audio_end_frame}\")\n\n    audio_idx = np.linspace(audio_start_frame, audio_end_frame, sample_rate * duration, dtype=int)\n    video_idx = np.linspace(video_start_frame, video_end_frame, num_frames, dtype=int)\n\n    audio = audio[..., audio_idx]\n    video = video[video_idx, ...]\n\n    return audio, video\n\n\ndef get_full_indices(reader: Union[decord.VideoReader, decord.AudioReader]) -> np.ndarray:\n    if isinstance(reader, decord.VideoReader):\n        return np.linspace(0, len(reader) - 1, len(reader), dtype=int)\n    elif isinstance(reader, decord.AudioReader):\n        return np.linspace(0, reader.shape[-1] - 1, reader.shape[-1], dtype=int)\n\n\ndef get_frames(video_path: str, onset_list, frame_nums=1024):\n    video = decord.VideoReader(video_path)\n    video_frame = len(video)\n\n    frames_list = []\n    for start, end in onset_list:\n        video_start = int(start / frame_nums * video_frame)\n        video_end = int(end / frame_nums * video_frame)\n\n        frames_list.extend(range(video_start, video_end))\n    frames = video.get_batch(frames_list).asnumpy()\n    return frames\n\n\ndef get_frames_in_video(video_path: str, onset_list, frame_nums=1024, audio_length_in_s=10):\n    # this function consider the video length\n    video = decord.VideoReader(video_path)\n    video_frame = len(video)\n    duration = 
video_frame / video.get_avg_fps()\n    frames_list = []\n    video_onset_list = []\n    for start, end in onset_list:\n        if int(start / frame_nums * duration) >= audio_length_in_s:\n            continue\n        video_start = int(start / audio_length_in_s * duration / frame_nums * video_frame)\n        if video_start >= video_frame:\n            continue\n        video_end = int(end / audio_length_in_s * duration / frame_nums * video_frame)\n        video_onset_list.append([int(start / audio_length_in_s * duration), int(end / audio_length_in_s * duration)])\n        frames_list.extend(range(video_start, video_end))\n    frames = video.get_batch(frames_list).asnumpy()\n    return frames, video_onset_list\n\n\ndef save_multimodal(video, audio, output_path, audio_fps: int = 16000, video_fps: int = 8, remove_audio: bool = True):\n    imgs = list(video)\n    # if audio.shape[0] == 1 or audio.shape[0] == 2:\n    #     audio = audio.T #[len, channel]\n    # audio = np.repeat(audio, 2, axis=1)\n    output_dir = osp.dirname(output_path)\n    try:\n        wavfile.write(osp.join(output_dir, \"audio.wav\"), audio_fps, audio)\n    except Exception:\n        sf.write(osp.join(output_dir, \"audio.wav\"), audio, audio_fps)\n    audio_clip = AudioFileClip(osp.join(output_dir, \"audio.wav\"))\n    # audio_clip = AudioArrayClip(audio, fps=audio_fps)\n    video_clip = ImageSequenceClip(imgs, fps=video_fps)\n    video_clip = video_clip.set_audio(audio_clip)\n    video_clip.write_videofile(output_path, video_fps, audio=True, audio_fps=audio_fps)\n    if remove_audio:\n        os.remove(osp.join(output_dir, \"audio.wav\"))\n    return\n\n\ndef save_multimodal_by_frame(video, audio, output_path, audio_fps: int = 16000):\n    imgs = list(video)\n    # if audio.shape[0] == 1 or audio.shape[0] == 2:\n    #     audio = audio.T #[len, channel]\n    # audio = np.repeat(audio, 2, axis=1)\n    # output_dir = osp.dirname(output_path)\n    output_dir = output_path\n    
wavfile.write(osp.join(output_dir, \"audio.wav\"), audio_fps, audio)\n    # audio_clip = AudioFileClip(osp.join(output_dir, \"audio.wav\"))\n    # audio_clip = AudioArrayClip(audio, fps=audio_fps)\n    os.makedirs(osp.join(output_dir, \"frames\"), exist_ok=True)\n    for num, img in enumerate(imgs):\n        if isinstance(img, np.ndarray):\n            img = Image.fromarray(img.astype(np.uint8))\n        img.save(osp.join(output_dir, \"frames\", f\"{num}.jpg\"))\n    return\n\n\ndef sanity_check(data: dict, save_path: str = \"sanity_check\", batch_size: int = 4, sample_rate: int = 16000):\n    video_path = osp.join(save_path, \"video\")\n    audio_path = osp.join(save_path, \"audio\")\n    av_path = osp.join(save_path, \"av\")\n\n    video, audio, text = data[\"pixel_values\"], data[\"audio\"], data[\"text\"]\n    video = (video / 2 + 0.5).clamp(0, 1)\n\n    zero_rank_print(f\"Saving {text} audio: {audio[0].shape} video: {video[0].shape}\")\n\n    for bsz in range(batch_size):\n        os.makedirs(video_path, exist_ok=True)\n        os.makedirs(audio_path, exist_ok=True)\n        os.makedirs(av_path, exist_ok=True)\n        # save_videos_grid(video[bsz:bsz+1,...], f\"{osp.join(video_path, str(bsz) + '.mp4')}\")\n        bsz_audio = audio[bsz, ...].permute(1, 0).cpu().numpy()\n        bsz_video = video_tensor_to_np(video[bsz, ...])\n        sf.write(f\"{osp.join(audio_path, str(bsz) + '.wav')}\", bsz_audio, sample_rate)\n        save_multimodal(bsz_video, bsz_audio, osp.join(av_path, str(bsz) + \".mp4\"))\n\n\ndef video_tensor_to_np(video: torch.Tensor, rescale: bool = True, scale: bool = False):\n    if scale:\n        video = (video / 2 + 0.5).clamp(0, 1)\n    # c f h w -> f h w c\n    if video.shape[0] == 3:\n        video = video.permute(1, 2, 3, 0).detach().cpu().numpy()\n    elif video.shape[1] == 3:\n        video = video.permute(0, 2, 3, 1).detach().cpu().numpy()\n    if rescale:\n        video = video * 255\n    return video\n\n\ndef 
composite_audio_video(video: str, audio: str, path: str, video_fps: int = 7, audio_sample_rate: int = 16000):\n    video = decord.VideoReader(video)\n    audio = decord.AudioReader(audio, sample_rate=audio_sample_rate)\n    audio = audio.get_batch(get_full_indices(audio)).asnumpy()\n    video = video.get_batch(get_full_indices(video)).asnumpy()\n    save_multimodal(video, audio, path, audio_fps=audio_sample_rate, video_fps=video_fps)\n    return\n\n\n# for video pipeline\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef resize_with_antialiasing(input, size, interpolation=\"bicubic\", align_corners=True):\n    h, w = input.shape[-2:]\n    factors = (h / size[0], w / size[1])\n\n    # First, we have to determine sigma\n    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171\n    sigmas = (\n        max((factors[0] - 1.0) / 2.0, 0.001),\n        max((factors[1] - 1.0) / 2.0, 0.001),\n    )\n\n    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma\n    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206\n    # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now\n    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))\n\n    # Make sure it is odd\n    if (ks[0] % 2) == 0:\n        ks = ks[0] + 1, ks[1]\n\n    if (ks[1] % 2) == 0:\n        ks = ks[0], ks[1] + 1\n\n    input = _gaussian_blur2d(input, ks, sigmas)\n\n    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)\n    return output\n\n\ndef _gaussian_blur2d(input, kernel_size, sigma):\n    if isinstance(sigma, tuple):\n        sigma = torch.tensor([sigma], dtype=input.dtype)\n    else:\n        sigma = sigma.to(dtype=input.dtype)\n\n    ky, kx = int(kernel_size[0]), int(kernel_size[1])\n    bs = sigma.shape[0]\n    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))\n    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))\n    out_x = _filter2d(input, kernel_x[..., None, :])\n    out = _filter2d(out_x, kernel_y[..., None])\n\n    return out\n\n\ndef _filter2d(input, kernel):\n    # prepare kernel\n    b, c, h, w = input.shape\n    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)\n\n    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)\n\n    height, width = tmp_kernel.shape[-2:]\n\n    padding_shape: list[int] = _compute_padding([height, width])\n    input = torch.nn.functional.pad(input, padding_shape, mode=\"reflect\")\n\n    # kernel and input tensor reshape to align element-wise or batch-wise params\n    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)\n    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))\n\n    # convolve the tensor with the kernel.\n    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)\n\n    out = output.view(b, c, h, w)\n    return out\n\n\ndef _gaussian(window_size: int, sigma):\n    if isinstance(sigma, float):\n        sigma = torch.tensor([[sigma]])\n\n    batch_size = sigma.shape[0]\n\n    x = (torch.arange(window_size, 
device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)\n\n    if window_size % 2 == 0:\n        x = x + 0.5\n\n    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))\n\n    return gauss / gauss.sum(-1, keepdim=True)\n\n\ndef _compute_padding(kernel_size):\n    \"\"\"Compute padding tuple.\"\"\"\n    # 4 or 6 ints:  (padding_left, padding_right,padding_top,padding_bottom)\n    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad\n    if len(kernel_size) < 2:\n        raise AssertionError(kernel_size)\n    computed = [k - 1 for k in kernel_size]\n\n    # for even kernels we need to do asymmetric padding :(\n    out_padding = 2 * len(kernel_size) * [0]\n\n    for i in range(len(kernel_size)):\n        computed_tmp = computed[-(i + 1)]\n\n        pad_front = computed_tmp // 2\n        pad_rear = computed_tmp - pad_front\n\n        out_padding[2 * i + 0] = pad_front\n        out_padding[2 * i + 1] = pad_rear\n\n    return out_padding\n\n\ndef print_gpu_memory_usage(info: str, cuda_id: int = 0):\n    print(f\">>> {info} <<<\")\n    reserved = torch.cuda.memory_reserved(cuda_id) / 1024**3\n    used = torch.cuda.memory_allocated(cuda_id) / 1024**3\n\n    print(\"total: \", reserved, \"G\")\n    print(\"used: \", used, \"G\")\n    print(\"available: \", reserved - used, \"G\")\n\n\n# use for dsp mel2spec\n@dataclass(frozen=True)\nclass SpectrogramParams:\n    \"\"\"\n    Parameters for the conversion from audio to spectrograms to images and back.\n\n    Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored\n    within spectrogram images.\n\n    To understand what these parameters do and to customize them, read `spectrogram_converter.py`\n    and the linked torchaudio documentation.\n    \"\"\"\n\n    # Whether the audio is stereo or mono\n    stereo: bool = False\n\n    # FFT parameters\n    sample_rate: int = 44100\n    step_size_ms: int = 10\n    window_duration_ms: int = 100\n    
padded_duration_ms: int = 400\n\n    # Mel scale parameters\n    num_frequencies: int = 200\n    # TODO(hayk): Set these to [20, 20000] for newer models\n    min_frequency: int = 0\n    max_frequency: int = 10000\n    mel_scale_norm: T.Optional[str] = None\n    mel_scale_type: str = \"htk\"\n    max_mel_iters: int = 200\n\n    # Griffin Lim parameters\n    num_griffin_lim_iters: int = 32\n\n    # Image parameterization\n    power_for_image: float = 0.25\n\n    class ExifTags(Enum):\n        \"\"\"\n        Custom EXIF tags for the spectrogram image.\n        \"\"\"\n\n        SAMPLE_RATE = 11000\n        STEREO = 11005\n        STEP_SIZE_MS = 11010\n        WINDOW_DURATION_MS = 11020\n        PADDED_DURATION_MS = 11030\n\n        NUM_FREQUENCIES = 11040\n        MIN_FREQUENCY = 11050\n        MAX_FREQUENCY = 11060\n\n        POWER_FOR_IMAGE = 11070\n        MAX_VALUE = 11080\n\n    @property\n    def n_fft(self) -> int:\n        \"\"\"\n        The number of samples in each STFT window, with padding.\n        \"\"\"\n        return int(self.padded_duration_ms / 1000.0 * self.sample_rate)\n\n    @property\n    def win_length(self) -> int:\n        \"\"\"\n        The number of samples in each STFT window.\n        \"\"\"\n        return int(self.window_duration_ms / 1000.0 * self.sample_rate)\n\n    @property\n    def hop_length(self) -> int:\n        \"\"\"\n        The number of samples between each STFT window.\n        \"\"\"\n        return int(self.step_size_ms / 1000.0 * self.sample_rate)\n\n    def to_exif(self) -> T.Dict[int, T.Any]:\n        \"\"\"\n        Return a dictionary of EXIF tags for the current values.\n        \"\"\"\n        return {\n            self.ExifTags.SAMPLE_RATE.value: self.sample_rate,\n            self.ExifTags.STEREO.value: self.stereo,\n            self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,\n            self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,\n            self.ExifTags.PADDED_DURATION_MS.value: 
self.padded_duration_ms,\n            self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,\n            self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,\n            self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,\n            self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),\n        }\n\n\nclass SpectrogramImageConverter:\n    \"\"\"\n    Convert between spectrogram images and audio segments.\n\n    This is a wrapper around SpectrogramConverter that additionally converts from spectrograms\n    to images and back. The real audio processing lives in SpectrogramConverter.\n    \"\"\"\n\n    def __init__(self, params: SpectrogramParams, device: str = \"cuda\"):\n        self.p = params\n        self.device = device\n        self.converter = SpectrogramConverter(params=params, device=device)\n\n    def spectrogram_image_from_audio(\n        self,\n        segment: pydub.AudioSegment,\n    ) -> Image.Image:\n        \"\"\"\n        Compute a spectrogram image from an audio segment.\n\n        Args:\n            segment: Audio segment to convert\n\n        Returns:\n            Spectrogram image (in pillow format)\n        \"\"\"\n        assert int(segment.frame_rate) == self.p.sample_rate, \"Sample rate mismatch\"\n\n        if self.p.stereo:\n            if segment.channels == 1:\n                print(\"WARNING: Mono audio but stereo=True, cloning channel\")\n                segment = segment.set_channels(2)\n            elif segment.channels > 2:\n                print(\"WARNING: Multi channel audio, reducing to stereo\")\n                segment = segment.set_channels(2)\n        else:\n            if segment.channels > 1:\n                print(\"WARNING: Stereo audio but stereo=False, setting to mono\")\n                segment = segment.set_channels(1)\n\n        spectrogram = self.converter.spectrogram_from_audio(segment)\n\n        image = image_from_spectrogram(\n            spectrogram,\n            
power=self.p.power_for_image,\n        )\n\n        # Store conversion params in exif metadata of the image\n        exif_data = self.p.to_exif()\n        exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))\n        exif = image.getexif()\n        exif.update(exif_data.items())\n\n        return image\n\n    def audio_from_spectrogram_image(\n        self,\n        image: Image.Image,\n        apply_filters: bool = True,\n        max_value: float = 30e6,\n    ) -> pydub.AudioSegment:\n        \"\"\"\n        Reconstruct an audio segment from a spectrogram image.\n\n        Args:\n            image: Spectrogram image (in pillow format)\n            apply_filters: Apply post-processing to improve the reconstructed audio\n            max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.\n        \"\"\"\n        spectrogram = spectrogram_from_image(\n            image,\n            max_value=max_value,\n            power=self.p.power_for_image,\n            stereo=self.p.stereo,\n        )\n\n        segment = self.converter.audio_from_spectrogram(\n            spectrogram,\n            apply_filters=apply_filters,\n        )\n\n        return segment\n\n\ndef image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:\n    \"\"\"\n    Compute a spectrogram image from a spectrogram magnitude array.\n\n    This is the inverse of spectrogram_from_image, except for discretization error from\n    quantizing to uint8.\n\n    Args:\n        spectrogram: (channels, frequency, time)\n        power: A power curve to apply to the spectrogram to preserve contrast\n\n    Returns:\n        image: (frequency, time, channels)\n    \"\"\"\n    # Rescale to 0-1\n    max_value = np.max(spectrogram)\n    data = spectrogram / max_value\n\n    # Apply the power curve\n    data = np.power(data, power)\n\n    # Rescale to 0-255\n    data = data * 255\n\n    # Invert\n    data = 255 - data\n\n    # Convert to uint8\n    
    data = data.astype(np.uint8)

    # Munge channels into a PIL image
    if data.shape[0] == 1:
        # TODO(hayk): Do we want to write single channel to disk instead?
        image = Image.fromarray(data[0], mode="L").convert("RGB")
    elif data.shape[0] == 2:
        # Stereo: pack the two channels into G and B, with a zeroed R channel.
        data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
        image = Image.fromarray(data, mode="RGB")
    else:
        raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")

    # Flip Y (undone by spectrogram_from_image on the way back)
    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)

    return image


def spectrogram_from_image(
    image: Image.Image,
    power: float = 0.25,
    stereo: bool = False,
    max_value: float = 30e6,
) -> np.ndarray:
    """
    Compute a spectrogram magnitude array from a spectrogram image.

    This is the inverse of image_from_spectrogram, except for discretization error from
    quantizing to uint8.

    Args:
        image: (frequency, time, channels)
        power: The power curve applied to the spectrogram
        stereo: Whether the spectrogram encodes stereo data
        max_value: The max value of the original spectrogram. In practice doesn't matter.

    Returns:
        spectrogram: (channels, frequency, time)
    """
    # Convert to RGB if single channel
    if image.mode in ("P", "L"):
        image = image.convert("RGB")

    # Flip Y (undo the flip applied by image_from_spectrogram)
    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)

    # Munge channels into a numpy array of (channels, frequency, time)
    data = np.array(image).transpose(2, 0, 1)
    if stereo:
        # Take the G and B channels as done in image_from_spectrogram
        data = data[[1, 2], :, :]
    else:
        data = data[0:1, :, :]

    # Convert to floats
    data = data.astype(np.float32)

    # Invert
    data = 255 - data

    # Rescale to 0-1
    data = data / 255

    # Reverse the power curve
    data = np.power(data, 1 / power)

    # Rescale to max value
    data = data * max_value

    return data


class SpectrogramConverter:
    """
    Convert between audio segments and spectrogram tensors using torchaudio.

    In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
    that represent the amplitude of the frequency at that time bucket (in the frequency domain).
    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
    used in some functions is "mel amplitudes".

    The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
    returns the amplitude, because the phase is chaotic and hard to learn.
    The function
    `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
    approximates the phase information using the Griffin-Lim algorithm.

    Each channel in the audio is treated independently, and the spectrogram has a batch dimension
    equal to the number of channels in the input audio segment.

    Both the Griffin Lim algorithm and the Mel scaling process are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params

        self.device = check_device(device)

        # MPS does not support the audio transforms constructed below, so run them on CPU.
        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        # power=None keeps the complex-valued STFT; the magnitude is taken separately in
        # mel_amplitudes_from_waveform.
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            pad=0,
            window_fn=torch.hann_window,
            power=None,
            normalized=False,
            wkwargs=None,
            center=True,
            pad_mode="reflect",
            onesided=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_fft=params.n_fft,
            n_iter=params.num_griffin_lim_iters,
            win_length=params.win_length,
            hop_length=params.hop_length,
            window_fn=torch.hann_window,
            power=1.0,
            wkwargs=None,
            momentum=0.99,
            length=None,
            rand_init=True,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=params.n_fft // 2 + 1,
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            # max_iter=params.max_mel_iters, # for higher version of torchaudio
            # tolerance_loss=1e-5, # for higher version of torchaudio
            # tolerance_change=1e-8, # for higher version of torchaudio
            # sgdargs=None, # for higher version of torchaudio
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        ).to(self.device)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute a spectrogram from an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time)
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # Get the samples as a numpy array in (batch, samples) shape
        waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])

        # Convert to floats if necessary
        if waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        waveform_tensor = torch.from_numpy(waveform).to(self.device)
        amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
        return amplitudes_mel.cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Reconstruct an audio segment from a spectrogram.

        Args:
            spectrogram: (batch, frequency, time)
            apply_filters: Post-process with normalization and compression

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        # Move to device
        amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)

        # Reconstruct the waveform
        waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)

        # Convert to audio segment
        segment = audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            # Normalize the waveform to the range [-1, 1]
            normalize=True,
        )

        # Optionally apply post-processing filters
        if apply_filters:
            segment = apply_filters_func(
                segment,
                compression=False,
            )

        return segment

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to compute Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # Compute the complex-valued spectrogram
        spectrogram_complex = self.spectrogram_func(waveform)

        # Take the magnitude
        amplitudes = torch.abs(spectrogram_complex)

        # Convert to mel scale
        return self.mel_scaler(amplitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Convert from mel scale to linear
        amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)

        # Run the approximate algorithm to compute the phase and recover the waveform
        return self.inverse_spectrogram_func(amplitudes_linear)


def check_device(device: str, backup: str = "cpu") -> str:
    """
    Check that the device is valid and available. If not, warn and fall back to `backup`.
    """
    cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
    mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()

    if cuda_not_found or mps_not_found:
        warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
        return backup

    return device


def audio_from_waveform(samples: np.ndarray, sample_rate: int, normalize: bool = False) -> pydub.AudioSegment:
    """
    Convert a numpy array of samples of a waveform to an audio segment.

    Args:
        samples: (channels, samples) array
        sample_rate: Sample rate of the waveform, in Hz
        normalize: If True, scale peak amplitude to the int16 maximum before conversion
    """
    # Normalize volume to fit in int16.
    # NOTE(review): this scales `samples` in place, mutating the caller's array, and
    # assumes a float dtype with a nonzero peak — TODO confirm callers always satisfy this.
    if normalize:
        samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))

    # Transpose and convert to int16
    samples = samples.transpose(1, 0)
    samples = samples.astype(np.int16)

    # Write to the bytes of a WAV file
    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, samples)
    wav_bytes.seek(0)

    # Read into pydub
    return pydub.AudioSegment.from_wav(wav_bytes)


def apply_filters_func(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
    """
    Apply post-processing filters to the audio segment: optional dynamic range compression
    (targeting -10 dBFS), then gain to -12 dBFS and a final peak normalization.
    """
    # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
    # TODO(hayk): Is this going to make audio unbalanced between sequential clips?

    if compression:
        segment = pydub.effects.normalize(
            segment,
            headroom=0.1,
        )

        segment = segment.apply_gain(-10 - segment.dBFS)

        # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
        segment = pydub.effects.compress_dynamic_range(
            segment,
            threshold=-20.0,
            ratio=4.0,
            attack=5.0,
            release=50.0,
        )

    desired_db = -12
    segment = segment.apply_gain(desired_db - segment.dBFS)

    segment = pydub.effects.normalize(
        segment,
        headroom=0.1,
    )

    return segment


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments.
    Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Returns a list of {"old": ..., "new": ...} dicts mapping each original key
    to its renamed counterpart.
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    VAE variant: only the nin_shortcut -> conv_shortcut rename applies.
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).

    NOTE(review): all replacement rules below are commented out, so this currently
    produces an identity mapping ("old" == "new" for every key); the actual renaming
    happens later via `additional_replacements` in assign_to_checkpoint.
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        #         new_item = new_item.replace('norm.weight', 'group_norm.weight')
        #         new_item = new_item.replace('norm.bias', 'group_norm.bias')

        #         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        #         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        #         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).

    VAE variant: renames norm -> group_norm, q/k/v -> to_q/to_k/to_v, and
    proj_out -> to_out.0.
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "to_q.weight")
        new_item = new_item.replace("q.bias", "to_q.bias")

        new_item = new_item.replace("k.weight", "to_k.weight")
        new_item = new_item.replace("k.bias", "to_k.bias")

        new_item = new_item.replace("v.weight", "to_v.weight")
        new_item = new_item.replace("v.bias", "to_v.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})
    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them.
    It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # The fused qkv tensor has 3x the per-projection channel count along dim 0.
            channels = old_tensor.shape[0] // 3

            # NOTE(review): for the non-3D case this is the scalar -1, not a tuple;
            # reshape accepts both forms so behavior is unaffected.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif "to_out.0.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
        elif any(qkv in new_path for qkv in ["to_q", "to_k", "to_v"]):
            checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    """
    In-place: strip the trailing spatial dims from attention conv weights so they
    become linear-layer weights (q/k/v: drop both 1x1 dims; proj_attn: drop one).
    """
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        unet_params = original_config.model.params.unet_config.params

    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    # Track the downsampling resolution so blocks listed in attention_resolutions
    # get cross-attention variants.
    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config


def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config


def create_diffusers_schedular(original_config):
    """
    Creates a diffusers DDIMScheduler from the LDM model's training schedule params.
    """
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA weights are looked up under flattened names: the segments after
                # "model." joined without dots — presumably matching how LDM stores
                # EMA params. NOTE(review): verify against the source checkpoint format.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down blocks: block 0 is conv_in (handled above), so start at 1.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block: resnet, attention, resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up blocks.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # NOTE(review): resnet_0_paths here duplicates `paths` and is not used in
            # this branch.
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding

        orig_index = 0

        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        orig_index += 2

        diffusers_index = 0

        # input_hint_block interleaves convs with activations, hence the stride of 2.
        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint[\"decoder.conv_norm_out.weight\"] = vae_state_dict[\"decoder.norm_out.weight\"]\n    new_checkpoint[\"decoder.conv_norm_out.bias\"] = vae_state_dict[\"decoder.norm_out.bias\"]\n\n    new_checkpoint[\"quant_conv.weight\"] = vae_state_dict[\"quant_conv.weight\"]\n    new_checkpoint[\"quant_conv.bias\"] = vae_state_dict[\"quant_conv.bias\"]\n    new_checkpoint[\"post_quant_conv.weight\"] = vae_state_dict[\"post_quant_conv.weight\"]\n    new_checkpoint[\"post_quant_conv.bias\"] = vae_state_dict[\"post_quant_conv.bias\"]\n\n    # Retrieves the keys for the encoder down blocks only\n    num_down_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"encoder.down\" in layer})\n    down_blocks = {\n        layer_id: [key for key in vae_state_dict if f\"down.{layer_id}\" in key] for layer_id in range(num_down_blocks)\n    }\n\n    # Retrieves the keys for the decoder up blocks only\n    num_up_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"decoder.up\" in layer})\n    up_blocks = {\n        layer_id: [key for key in vae_state_dict if f\"up.{layer_id}\" in key] for layer_id in range(num_up_blocks)\n    }\n\n    for i in range(num_down_blocks):\n        resnets = [key for key in down_blocks[i] if f\"down.{i}\" in key and f\"down.{i}.downsample\" not in key]\n\n        if f\"encoder.down.{i}.downsample.conv.weight\" in vae_state_dict:\n            new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.weight\"] = vae_state_dict.pop(\n                f\"encoder.down.{i}.downsample.conv.weight\"\n            )\n            new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.bias\"] = vae_state_dict.pop(\n                f\"encoder.down.{i}.downsample.conv.bias\"\n            )\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"down.{i}.block\", \"new\": f\"down_blocks.{i}.resnets\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, 
additional_replacements=[meta_path], config=config)\n\n    mid_resnets = [key for key in vae_state_dict if \"encoder.mid.block\" in key]\n    num_mid_res_blocks = 2\n    for i in range(1, num_mid_res_blocks + 1):\n        resnets = [key for key in mid_resnets if f\"encoder.mid.block_{i}\" in key]\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_attentions = [key for key in vae_state_dict if \"encoder.mid.attn\" in key]\n    paths = renew_vae_attention_paths(mid_attentions)\n    meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n    conv_attn_to_linear(new_checkpoint)\n\n    for i in range(num_up_blocks):\n        block_id = num_up_blocks - 1 - i\n        resnets = [\n            key for key in up_blocks[block_id] if f\"up.{block_id}\" in key and f\"up.{block_id}.upsample\" not in key\n        ]\n\n        if f\"decoder.up.{block_id}.upsample.conv.weight\" in vae_state_dict:\n            new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.weight\"] = vae_state_dict[\n                f\"decoder.up.{block_id}.upsample.conv.weight\"\n            ]\n            new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.bias\"] = vae_state_dict[\n                f\"decoder.up.{block_id}.upsample.conv.bias\"\n            ]\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"up.{block_id}.block\", \"new\": f\"up_blocks.{i}.resnets\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_resnets = [key for key in vae_state_dict if \"decoder.mid.block\" in key]\n    num_mid_res_blocks = 2\n    for i 
in range(1, num_mid_res_blocks + 1):\n        resnets = [key for key in mid_resnets if f\"decoder.mid.block_{i}\" in key]\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_attentions = [key for key in vae_state_dict if \"decoder.mid.attn\" in key]\n    paths = renew_vae_attention_paths(mid_attentions)\n    meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n    conv_attn_to_linear(new_checkpoint)\n\n    if only_decoder:\n        new_checkpoint = {\n            k: v for k, v in new_checkpoint.items() if k.startswith(\"decoder\") or k.startswith(\"post_quant\")\n        }\n    elif only_encoder:\n        new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith(\"encoder\") or k.startswith(\"quant\")}\n\n    return new_checkpoint\n\n\ndef convert_ldm_clip_checkpoint(checkpoint):\n    keys = list(checkpoint.keys())\n\n    text_model_dict = {}\n    for key in keys:\n        if key.startswith(\"cond_stage_model.transformer\"):\n            text_model_dict[key[len(\"cond_stage_model.transformer.\") :]] = checkpoint[key]\n\n    return text_model_dict\n\n\ndef convert_lora_model_level(\n    state_dict, unet, text_encoder=None, LORA_PREFIX_UNET=\"lora_unet\", LORA_PREFIX_TEXT_ENCODER=\"lora_te\", alpha=0.6\n):\n    \"\"\"convert lora in model level instead of pipeline leval\"\"\"\n\n    visited = []\n\n    # directly update weight in diffusers model\n    for key in state_dict:\n        # it is suggested to print out the key, it usually will be something like below\n        # \"lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight\"\n\n        # as we have set the alpha beforehand, so 
just skip\n        if \".alpha\" in key or key in visited:\n            continue\n\n        if \"text\" in key:\n            layer_infos = key.split(\".\")[0].split(LORA_PREFIX_TEXT_ENCODER + \"_\")[-1].split(\"_\")\n            assert text_encoder is not None, \"text_encoder must be passed since lora contains text encoder layers\"\n            curr_layer = text_encoder\n        else:\n            layer_infos = key.split(\".\")[0].split(LORA_PREFIX_UNET + \"_\")[-1].split(\"_\")\n            curr_layer = unet\n\n        # find the target layer\n        temp_name = layer_infos.pop(0)\n        while len(layer_infos) > -1:\n            try:\n                curr_layer = curr_layer.__getattr__(temp_name)\n                if len(layer_infos) > 0:\n                    temp_name = layer_infos.pop(0)\n                elif len(layer_infos) == 0:\n                    break\n            except Exception:\n                if len(temp_name) > 0:\n                    temp_name += \"_\" + layer_infos.pop(0)\n                else:\n                    temp_name = layer_infos.pop(0)\n\n        pair_keys = []\n        if \"lora_down\" in key:\n            pair_keys.append(key.replace(\"lora_down\", \"lora_up\"))\n            pair_keys.append(key)\n        else:\n            pair_keys.append(key)\n            pair_keys.append(key.replace(\"lora_up\", \"lora_down\"))\n\n        # update weight\n        # NOTE: load lycon, maybe have bugs :(\n        if \"conv_in\" in pair_keys[0]:\n            weight_up = state_dict[pair_keys[0]].to(torch.float32)\n            weight_down = state_dict[pair_keys[1]].to(torch.float32)\n            weight_up = weight_up.view(weight_up.size(0), -1)\n            weight_down = weight_down.view(weight_down.size(0), -1)\n            shape = list(curr_layer.weight.data.shape)\n            shape[1] = 4\n            curr_layer.weight.data[:, :4, ...] 
+= alpha * (weight_up @ weight_down).view(*shape)\n        elif \"conv\" in pair_keys[0]:\n            weight_up = state_dict[pair_keys[0]].to(torch.float32)\n            weight_down = state_dict[pair_keys[1]].to(torch.float32)\n            weight_up = weight_up.view(weight_up.size(0), -1)\n            weight_down = weight_down.view(weight_down.size(0), -1)\n            shape = list(curr_layer.weight.data.shape)\n            curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape)\n        elif len(state_dict[pair_keys[0]].shape) == 4:\n            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)\n            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)\n            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(\n                curr_layer.weight.data.device\n            )\n        else:\n            weight_up = state_dict[pair_keys[0]].to(torch.float32)\n            weight_down = state_dict[pair_keys[1]].to(torch.float32)\n            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)\n\n        # update visited list\n        for item in pair_keys:\n            visited.append(item)\n\n    return unet, text_encoder\n\n\ndef denormalize_spectrogram(\n    data: torch.Tensor,\n    max_value: float = 200,\n    min_value: float = 1e-5,\n    power: float = 1,\n    inverse: bool = False,\n) -> torch.Tensor:\n    max_value = np.log(max_value)\n    min_value = np.log(min_value)\n\n    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner\n    data = torch.flip(data, [1])\n\n    assert len(data.shape) == 3, \"Expected 3 dimensions, got {}\".format(len(data.shape))\n\n    if data.shape[0] == 1:\n        data = data.repeat(3, 1, 1)\n\n    assert data.shape[0] == 3, \"Expected 3 channels, got {}\".format(data.shape[0])\n    data = data[0]\n\n    # Reverse the 
power curve\n    data = torch.pow(data, 1 / power)\n\n    # Invert\n    if inverse:\n        data = 1 - data\n\n    # Rescale to max value\n    spectrogram = data * (max_value - min_value) + min_value\n\n    return spectrogram\n\n\nclass ToTensor1D(torchvision.transforms.ToTensor):\n    def __call__(self, tensor: np.ndarray):\n        tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis])\n\n        return tensor_2d.squeeze_(0)\n\n\ndef scale(old_value, old_min, old_max, new_min, new_max):\n    old_range = old_max - old_min\n    new_range = new_max - new_min\n    new_value = (((old_value - old_min) * new_range) / old_range) + new_min\n\n    return new_value\n\n\ndef read_frames_with_moviepy(video_path, max_frame_nums=None):\n    clip = VideoFileClip(video_path)\n    duration = clip.duration\n    frames = []\n    for frame in clip.iter_frames():\n        frames.append(frame)\n    if max_frame_nums is not None:\n        frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int)\n    return np.array(frames)[frames_idx, ...], duration\n\n\ndef read_frames_with_moviepy_resample(video_path, save_path):\n    vision_transform_list = [\n        transforms.Resize((128, 128)),\n        transforms.CenterCrop((112, 112)),\n        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n    ]\n    video_transform = transforms.Compose(vision_transform_list)\n    os.makedirs(save_path, exist_ok=True)\n    command = f'ffmpeg -v quiet -y -i \"{video_path}\" -f image2 -vf \"scale=-1:360,fps=15\" -qscale:v 3 \"{save_path}\"/frame%06d.jpg'\n    os.system(command)\n    frame_list = glob.glob(f\"{save_path}/*.jpg\")\n    frame_list.sort()\n    convert_tensor = transforms.ToTensor()\n    frame_list = [convert_tensor(np.array(Image.open(frame))) for frame in frame_list]\n    imgs = torch.stack(frame_list, dim=0)\n    imgs = video_transform(imgs)\n    imgs = imgs.permute(1, 0, 2, 3)\n    return imgs\n"
  },
  {
    "path": "inference.py",
    "content": "import argparse\nimport glob\nimport os\nimport os.path as osp\nfrom pathlib import Path\n\nimport soundfile as sf\nimport torch\nimport torchvision\nfrom huggingface_hub import snapshot_download\nfrom moviepy.editor import AudioFileClip, VideoFileClip\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom foleycrafter.models.onset import torch_utils\nfrom foleycrafter.models.time_detector.model import VideoOnsetNet\nfrom foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram\nfrom foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy\n\n\nvision_transform_list = [\n    torchvision.transforms.Resize((128, 128)),\n    torchvision.transforms.CenterCrop((112, 112)),\n    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n]\nvideo_transform = torchvision.transforms.Compose(vision_transform_list)\n\n\ndef args_parse():\n    config = argparse.ArgumentParser()\n    config.add_argument(\"--prompt\", type=str, default=\"\", help=\"prompt for audio generation\")\n    config.add_argument(\"--nprompt\", type=str, default=\"\", help=\"negative prompt for audio generation\")\n    config.add_argument(\"--seed\", type=int, default=42, help=\"ramdom seed\")\n    config.add_argument(\"--semantic_scale\", type=float, default=1.0, help=\"visual content scale\")\n    config.add_argument(\"--temporal_scale\", type=float, default=0.2, help=\"temporal align scale\")\n    config.add_argument(\"--input\", type=str, default=\"examples/sora\", help=\"input video folder path\")\n    config.add_argument(\"--ckpt\", type=str, default=\"checkpoints/\", help=\"checkpoints folder path\")\n    config.add_argument(\"--save_dir\", type=str, default=\"output/\", help=\"generation result save path\")\n    config.add_argument(\n        \"--pretrain\",\n        type=str,\n        default=\"auffusion/auffusion-full-no-adapter\",\n        help=\"audio generator pretrained 
checkpoint path\",\n    )\n    config.add_argument(\"--device\", type=str, default=\"cuda\")\n    config = config.parse_args()\n    return config\n\n\ndef build_models(config):\n    # download ckpt\n    pretrained_model_name_or_path = config.pretrain\n    if not os.path.isdir(pretrained_model_name_or_path):\n        pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)\n\n    fc_ckpt = \"ymzhang319/FoleyCrafter\"\n    if not os.path.isdir(fc_ckpt):\n        fc_ckpt = snapshot_download(fc_ckpt, local_dir=config.ckpt)\n\n    # ckpt path\n    temporal_ckpt_path = osp.join(config.ckpt, \"temporal_adapter.ckpt\")\n\n    # load vocoder\n    vocoder_config_path = fc_ckpt\n    vocoder = Generator.from_pretrained(vocoder_config_path, subfolder=\"vocoder\").to(config.device)\n\n    # load time_detector\n    time_detector_ckpt = osp.join(osp.join(config.ckpt, \"timestamp_detector.pth.tar\"))\n    time_detector = VideoOnsetNet(False)\n    time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, device=config.device, strict=True)\n\n    # load adapters\n    pipe = build_foleycrafter().to(config.device)\n    ckpt = torch.load(temporal_ckpt_path)\n\n    # load temporal adapter\n    if \"state_dict\" in ckpt.keys():\n        ckpt = ckpt[\"state_dict\"]\n    load_gligen_ckpt = {}\n    for key, value in ckpt.items():\n        if key.startswith(\"module.\"):\n            load_gligen_ckpt[key[len(\"module.\") :]] = value\n        else:\n            load_gligen_ckpt[key] = value\n    m, u = pipe.controlnet.load_state_dict(load_gligen_ckpt, strict=False)\n    print(f\"### Control Net missing keys: {len(m)}; \\n### unexpected keys: {len(u)};\")\n\n    # load semantic adapter\n    pipe.load_ip_adapter(\n        osp.join(config.ckpt, \"semantic\"), subfolder=\"\", weight_name=\"semantic_adapter.bin\", image_encoder_folder=None\n    )\n    ip_adapter_weight = config.semantic_scale\n    pipe.set_ip_adapter_scale(ip_adapter_weight)\n\n    return 
pipe, vocoder, time_detector\n\n\ndef run_inference(config, pipe, vocoder, time_detector):\n    controlnet_conditioning_scale = config.temporal_scale\n    os.makedirs(config.save_dir, exist_ok=True)\n\n    input_list = glob.glob(f\"{config.input}/*.mp4\")\n    assert len(input_list) != 0, \"input list is empty!\"\n\n    generator = torch.Generator(device=config.device)\n    generator.manual_seed(config.seed)\n    image_processor = CLIPImageProcessor()\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        \"h94/IP-Adapter\", subfolder=\"models/image_encoder\"\n    ).to(config.device)\n    input_list.sort()\n    with torch.no_grad():\n        for input_video in input_list:\n            print(f\" >>> Begin Inference: {input_video} <<< \")\n            frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=150)\n\n            time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)\n            time_frames = video_transform(time_frames)\n            time_frames = {\"frames\": time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}\n            preds = time_detector(time_frames)\n            preds = torch.sigmoid(preds)\n\n            # duration\n            # import ipdb; ipdb.set_trace()\n            time_condition = [\n                -1 if preds[0][int(i / (1024 / 10 * duration) * 150)] < 0.5 else 1\n                for i in range(int(1024 / 10 * duration))\n            ]\n            time_condition = time_condition + [-1] * (1024 - len(time_condition))\n            # w -> b c h w\n            time_condition = (\n                torch.FloatTensor(time_condition)\n                .unsqueeze(0)\n                .unsqueeze(0)\n                .unsqueeze(0)\n                .repeat(1, 1, 256, 1)\n                .to(\"cuda\")\n            )\n            images = image_processor(images=frames, return_tensors=\"pt\").to(\"cuda\")\n            image_embeddings = image_encoder(**images).image_embeds\n            image_embeddings = 
torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)\n            neg_image_embeddings = torch.zeros_like(image_embeddings)\n            image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)\n\n            name = Path(input_video).stem\n            name = name.replace(\"+\", \" \")\n\n            sample = pipe(\n                prompt=config.prompt,\n                negative_prompt=config.nprompt,\n                ip_adapter_image_embeds=image_embeddings,\n                image=time_condition,\n                # audio_length_in_s=10,\n                controlnet_conditioning_scale=controlnet_conditioning_scale,\n                num_inference_steps=25,\n                height=256,\n                width=1024,\n                output_type=\"pt\",\n                generator=generator,\n                # guidance_scale=0,\n            )\n            audio_img = sample.images[0]\n            audio = denormalize_spectrogram(audio_img)\n            audio = vocoder.inference(audio, lengths=160000)[0]\n            audio_save_path = osp.join(config.save_dir, \"audio\")\n            video_save_path = osp.join(config.save_dir, \"video\")\n            os.makedirs(audio_save_path, exist_ok=True)\n            os.makedirs(video_save_path, exist_ok=True)\n            audio = audio[: int(duration * 16000)]\n\n            save_path = osp.join(audio_save_path, f\"{name}.wav\")\n            sf.write(save_path, audio, 16000)\n\n            audio = AudioFileClip(osp.join(audio_save_path, f\"{name}.wav\"))\n            video = VideoFileClip(input_video)\n            audio = audio.subclip(0, duration)\n            video.audio = audio\n            video = video.subclip(0, duration)\n            os.makedirs(video_save_path, exist_ok=True)\n            video.write_videofile(osp.join(video_save_path, f\"{name}.mp4\"))\n\n\nif __name__ == \"__main__\":\n    config = args_parse()\n    pipe, vocoder, time_detector = build_models(config)\n    
run_inference(config, pipe, vocoder, time_detector)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.ruff]\n# Never enforce `E501` (line length violations).\nignore = [\"C901\", \"E501\", \"E741\", \"F402\", \"F823\"]\nselect = [\"C\", \"E\", \"F\", \"I\", \"W\"]\nline-length = 119\n\n# Ignore import violations in all `__init__.py` files.\n[tool.ruff.per-file-ignores]\n\"__init__.py\" = [\"E402\", \"F401\", \"F403\", \"F811\"]\n\"src/diffusers/utils/dummy_*.py\" = [\"F401\"]\n\n[tool.ruff.isort]\nlines-after-imports = 2\nknown-first-party = [\"diffusers\"]\n\n[tool.ruff.format]\n# Like Black, use double quotes for strings.\nquote-style = \"double\"\n\n# Like Black, indent with spaces, rather than tabs.\nindent-style = \"space\"\n\n# Like Black, respect magic trailing commas.\nskip-magic-trailing-comma = false\n\n# Like Black, automatically detect the appropriate line ending.\nline-ending = \"auto\"\n"
  },
  {
    "path": "requirements/environment.yaml",
    "content": "name: foleycrafter\nchannels:\n  - pytorch\n  - nvidia\ndependencies:\n  - python=3.10\n  - pytorch=2.2.0\n  - torchvision=0.17.0\n  - pytorch-cuda=11.8\n  - pip\n  - pip:\n    - diffusers==0.25.1\n    - transformers==4.30.2\n    - xformers\n    - imageio==2.33.1\n    - decord==0.6.0\n    - einops\n    - omegaconf\n    - safetensors\n    - gradio\n    - tqdm==4.66.1\n    - soundfile==0.12.1\n    - wandb\n    - moviepy==1.0.3\n    - kornia==0.7.1\n    - h5py==3.7.0\n"
  }
]