Full Code of open-mmlab/FoleyCrafter

Repository: open-mmlab/FoleyCrafter
Branch: main
Commit: c2399fa3cb7e
Files: 37
Total size: 869.2 KB

Directory structure:
FoleyCrafter/

├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── app.py
├── foleycrafter/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── video_transforms.py
│   ├── models/
│   │   ├── adapters/
│   │   │   ├── attention_processor.py
│   │   │   ├── ip_adapter.py
│   │   │   ├── resampler.py
│   │   │   ├── transformer.py
│   │   │   └── utils.py
│   │   ├── auffusion/
│   │   │   ├── attention.py
│   │   │   ├── attention_processor.py
│   │   │   ├── dual_transformer_2d.py
│   │   │   ├── loaders/
│   │   │   │   ├── ip_adapter.py
│   │   │   │   └── unet.py
│   │   │   ├── resnet.py
│   │   │   ├── transformer_2d.py
│   │   │   └── unet_2d_blocks.py
│   │   ├── auffusion_unet.py
│   │   ├── onset/
│   │   │   ├── __init__.py
│   │   │   ├── r2plus1d_18.py
│   │   │   ├── resnet.py
│   │   │   ├── torch_utils.py
│   │   │   └── video_onset_net.py
│   │   └── time_detector/
│   │       ├── model.py
│   │       └── resnet.py
│   ├── pipelines/
│   │   ├── auffusion_pipeline.py
│   │   └── pipeline_controlnet.py
│   └── utils/
│       ├── converter.py
│       ├── spec_to_mel.py
│       └── util.py
├── inference.py
├── pyproject.toml
└── requirements/
    └── environment.yaml

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.ckpt
*.pt
*.pyc
*.safetensors

__pycache__/
output/
checkpoints/
train/
configs/

*.wav
*.mp3
*.gif
*.jpg
*.png
*.log
*.json

*.csv
*.txt
*.bin


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.3.5
    hooks:
      # Run the linter.
      - id: ruff
        args: [ --fix ]
      # Run the formatter.
      - id: ruff-format
  - repo: https://github.com/codespell-project/codespell
    rev: v2.2.1
    hooks:
      - id: codespell
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]


================================================
FILE: LICENSE
================================================
                               Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<p align="center">
<img src='assets/foleycrafter.png' style="text-align: center; width: 134px" >
</p>

<div align="center">

[![arXiv](https://img.shields.io/badge/arXiv-2407.01494-b31b1b.svg)](https://arxiv.org/abs/2407.01494)
[![Project Page](https://img.shields.io/badge/FoleyCrafter-Website-green)](https://foleycrafter.github.io)
<a target="_blank" href="https://huggingface.co/spaces/ymzhang319/FoleyCrafter">
  <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HuggingFace"/>
</a>
[![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/ymzhang319/FoleyCrafter)
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/ymzhang319/FoleyCrafter)

</div>

# FoleyCrafter

Sound effects are the unsung heroes of cinema and gaming, enhancing realism, impact, and emotional depth for an immersive audiovisual experience. **FoleyCrafter** is a video-to-audio generation framework that produces realistic sound effects which are semantically relevant to, and synchronized with, the input video.

**Your star is our fuel! <img alt="" width="30" src="https://camo.githubusercontent.com/2f4f0d02cdf79dc1ff8d2b053b4410b13bc2e39cbc8a96fcdc6f06538a3d6d2b/68747470733a2f2f656d2d636f6e74656e742e7a6f626a2e6e65742f736f757263652f616e696d617465642d6e6f746f2d636f6c6f722d656d6f6a692f3335362f736d696c696e672d666163652d776974682d6865617274735f31663937302e676966"> We're revving up the engines with it! <img alt="" width="30" src="https://camo.githubusercontent.com/028a75f875b8c3aa1b3c80bbf7dd27973c4bb654fffcf0bdc0b6f1b0674ce481/68747470733a2f2f656d2d636f6e74656e742e7a6f626a2e6e65742f736f757263652f74656c656772616d2f3338362f737061726b6c65735f323732382e77656270">**


[FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds]()

[Yiming Zhang](https://github.com/ymzhang0319),
[Yicheng Gu](https://github.com/VocodexElysium),
[Yanhong Zeng†](https://zengyh1900.github.io/),
[Zhening Xing](https://github.com/LeoXing1996/),
[Yuancheng Wang](https://github.com/HeCheng0625),
[Zhizheng Wu](https://drwuz.com/),
[Kai Chen†](https://chenkai.site/)

(†Corresponding Author)


## What's New
- [ ] A more powerful one :stuck_out_tongue_closed_eyes:.
- [ ] Release training code.
- [x] `2024/07/01` Release the model and code of FoleyCrafter.

## Setup

### Prepare Environment
Use the following commands to install dependencies:
```bash
# install conda environment
conda env create -f requirements/environment.yaml
conda activate foleycrafter

# install GIT LFS for checkpoints download
conda install git-lfs
git lfs install
```
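After activating the environment, a quick stdlib-only sanity check (an illustration, not part of the repository) can report which of the key dependencies failed to install:

```python
import importlib.util


def missing_modules(names):
    """Return the subset of `names` that cannot be imported in this environment."""
    return [n for n in names if importlib.util.find_spec(n) is None]


# Key packages imported by FoleyCrafter's inference and demo scripts.
missing = missing_modules(["torch", "torchvision", "diffusers", "transformers", "gradio", "moviepy"])
print("missing packages:", missing or "none")
```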

### Download Checkpoints
The checkpoints will be downloaded automatically by running `inference.py`.

You can also download them manually with the following commands.
<li> Download the text-to-audio base model. We use Auffusion.</li>

```bash
git clone https://huggingface.co/auffusion/auffusion-full-no-adapter checkpoints/auffusion
```

<li> Download FoleyCrafter</li>

```bash
git clone https://huggingface.co/ymzhang319/FoleyCrafter checkpoints/
```

Put the checkpoints as follows:
```
└── checkpoints
    ├── semantic
    │   └── semantic_adapter.bin
    ├── vocoder
    │   ├── vocoder.pt
    │   └── config.json
    ├── temporal_adapter.ckpt
    └── timestamp_detector.pth.tar
```
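A small helper (illustrative only, not shipped with the repository) can verify this layout before running inference; the relative paths mirror the tree above:

```python
from pathlib import Path

# Relative paths expected under the checkpoints root (mirrors the tree above).
EXPECTED_CHECKPOINTS = [
    "semantic/semantic_adapter.bin",
    "vocoder/vocoder.pt",
    "vocoder/config.json",
    "temporal_adapter.ckpt",
    "timestamp_detector.pth.tar",
]


def missing_checkpoints(root="checkpoints"):
    """Return the expected checkpoint files that are absent under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED_CHECKPOINTS if not (root / rel).is_file()]


if __name__ == "__main__":
    print("missing checkpoints:", missing_checkpoints() or "none")
```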

## Gradio demo

You can launch the Gradio interface for FoleyCrafter by running the following command:

```bash
python app.py --share
```



## Inference
### Video To Audio Generation
```bash
python inference.py --save_dir=output/sora/
```

Results:
<table class='center'>
<tr>
  <td><p style="text-align: center">Input Video</p></td>
  <td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309262-d7c89984-c567-4ca7-8e2d-8f49d84bda4a.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122032Z&X-Amz-Expires=300&X-Amz-Signature=5b13f216056dedca2705233038dbb22f73023d2c1deaf3b03972d7b91c1bbab5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

 </td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309725-0dfa72a2-1466-46e6-9611-3e1cbff707fe.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122123Z&X-Amz-Expires=300&X-Amz-Signature=314648ed216620b2d926395d34602c70da500eb9e865e839de6907ed1b0d0bd1&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309166-16206bb8-9c5e-4e9d-9d73-bc251e5658fd.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122143Z&X-Amz-Expires=300&X-Amz-Signature=43c3e2c687846eb3ba118237628b747a78c403ea4f21739fe2d423724f7b426c&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309768-90c42af6-0d24-4a05-98d4-64e23467c4bb.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122213Z&X-Amz-Expires=300&X-Amz-Signature=cfed43cd2710bf73b84b6c3ebe8debd1e0b098bdc24a1a14f6531499d01c278e&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309601-e711b7c5-1614-4d39-8b1e-c54e28eec809.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122221Z&X-Amz-Expires=300&X-Amz-Signature=4c4680eb6c541433e4505fb2b5f5a5cc8d3e5708d9f2675a98cbb556cd5d59f5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309802-2db7f130-0c25-45c2-ad4d-bf86c5468b1f.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122243Z&X-Amz-Expires=300&X-Amz-Signature=2c32069318f60f03ee9a3185a7e2833c534ea601a02a5ff025b46c2abbc5b120&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309637-6c2f106d-6b98-41ac-80ba-734636321f8c.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122305Z&X-Amz-Expires=300&X-Amz-Signature=04adb2cd80785a245ce704837ba9932d3646d375c51c53e7b70e6861ec7f6b4a&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309836-77391524-9b31-4602-ad42-0876e0c16794.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122317Z&X-Amz-Expires=300&X-Amz-Signature=3020c12591592a106efcb1aaa22237093737afaa55ede3880d3cbf9cd80b7482&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
</table>

- Temporal Alignment with Visual Cues
```bash
python inference.py \
--temporal_align \
--input=input/avsync \
--save_dir=output/avsync/
```

Results:
<table class='center'>
<tr>
  <td><p style="text-align: center">Ground Truth</p></td>
  <td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310778-bcc0f16d-6d1b-468d-a775-81b8f2d98ea6.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122327Z&X-Amz-Expires=300&X-Amz-Signature=205b8e190a428b3ddee41fe2549080b4f50fd8bb10ef78d650fc05add85ccbab&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310418-8433e05c-8600-4cd6-8a68-ead536159204.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122337Z&X-Amz-Expires=300&X-Amz-Signature=3fc37a305511c1c8b7bdfc9b9b5bd0485fd584400af087939a4c08218ab33538&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310801-3d6fd80d-de6b-4815-ac6a-f81772709e4c.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122349Z&X-Amz-Expires=300&X-Amz-Signature=fc31f139a1f9c7606657fa457f1271a8a44cb39a13454939c030ccdafe2d3068&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310491-dfaf41e7-487e-47ff-8e8a-fe7cb4fb1942.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122356Z&X-Amz-Expires=300&X-Amz-Signature=6353a935194a08bc081fa873e3c6582fb175874d3112f8f8f96614a5e542ef03&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310825-6834f00f-95e8-4a2c-b864-b4fe57801836.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122406Z&X-Amz-Expires=300&X-Amz-Signature=c67b2f113b0db790a8495d1a4ab4c0d230db5f53a9062a497bcca3e57f9600aa&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310543-5a2c363b-623c-4329-be0e-a151e5bb56a6.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122415Z&X-Amz-Expires=300&X-Amz-Signature=d8b0bfc28716e0e03694b3e590aca29450fa7788aa39e748130ee12a15d614e9&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
</table>

### Text-based Video to Audio Generation

- Using Prompt

```bash
# case1
python inference.py \
--input=input/PromptControl/case1/ \
--seed=10201304011203481429 \
--save_dir=output/PromptControl/case1/

python inference.py \
--input=input/PromptControl/case1/ \
--seed=10201304011203481429 \
--prompt='noisy, people talking' \
--save_dir=output/PromptControl/case1_prompt/

# case2
python inference.py \
--input=input/PromptControl/case2/ \
--seed=10021049243103289113 \
--save_dir=output/PromptControl/case2/

python inference.py \
--input=input/PromptControl/case2/ \
--seed=10021049243103289113 \
--prompt='seagulls' \
--save_dir=output/PromptControl/case2_prompt/
```
Results:
<table class='center'>
<tr>
  <td><p style="text-align: center">Generated Audio</p></td>
  <td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
  <td><p style="text-align: center">Without Prompt</p></td>
  <td><p style="text-align: center">Prompt: <b>noisy, people talking</b></p></td>
</tr>
<tr>
  <td>


https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311425-8dd543cb-0df2-441e-b6d0-86048dbeb73d.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122426Z&X-Amz-Expires=300&X-Amz-Signature=b872b8eaf51a5022aee1daf0283d92e53a70e109c6b9f1e6a4da238a3708ea45&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188


</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311493-62a08024-581c-4716-a030-aef194beddc5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122439Z&X-Amz-Expires=300&X-Amz-Signature=647eb0dc32bf7c0d739ccbe875826b1a67f54e0dc84e0be70e0e128ae2fdb73d&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td><p style="text-align: center">Without Prompt</p></td>
  <td><p style="text-align: center">Prompt: <b>seagulls</b></p></td>
</tr>
<tr>
  <td>



https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311538-1f81f91e-efc0-41ed-bdcb-c5c6ff976c5b.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122447Z&X-Amz-Expires=300&X-Amz-Signature=d19761788775893e77e42b9312b4c26cb85aedd7c6dc249eaf68ff1f650e1942&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188



</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311595-695668ed-46a1-47b2-b5fd-3aa4286d695e.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122500Z&X-Amz-Expires=300&X-Amz-Signature=bf7d69ab8c74154ee8ac5682f64ce29a310ea2d0365f620893a45899d62a3f80&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
</table>

- Using Negative Prompt
```bash
# case 3
python inference.py \
--input=input/PromptControl/case3/ \
--seed=10041042941301238011 \
--save_dir=output/PromptControl/case3/

python inference.py \
--input=input/PromptControl/case3/ \
--seed=10041042941301238011 \
--nprompt='river flows' \
--save_dir=output/PromptControl/case3_nprompt/

# case4
python inference.py \
--input=input/PromptControl/case4/ \
--seed=10014024412012338096 \
--save_dir=output/PromptControl/case4/

python inference.py \
--input=input/PromptControl/case4/ \
--seed=10014024412012338096 \
--nprompt='noisy, wind noise' \
--save_dir=output/PromptControl/case4_nprompt/

```
Results:
<table class='center'>
<tr>
  <td><p style="text-align: center">Generated Audio</p></td>
  <td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
  <td><p style="text-align: center">Without Prompt</p></td>
  <td><p style="text-align: center">Negative Prompt: <b>river flows</b></p></td>
</tr>
<tr>
  <td>



https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311656-cdc69cf1-88f8-4861-b888-bdb82358b9c5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122517Z&X-Amz-Expires=300&X-Amz-Signature=1731e655b2f0bb7f4a7af737ee065a01f98fccb1c54ef48ae775ad65ec67eda5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188



</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311702-cd259522-84f4-44cb-862f-c4dcfb57e5c4.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122527Z&X-Amz-Expires=300&X-Amz-Signature=d846ac60df25ff18de1861daeff380b7bc8ca21c04e7dce139ed44abbf9aaa22&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
<tr>
  <td><p style="text-align: center">Without Prompt</p></td>
  <td><p style="text-align: center">Negative Prompt: <b>noisy, wind noise</b></p></td>
</tr>
<tr>
  <td>


https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311785-5ca9c050-a928-4dc2-b620-d843a3ae72f5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122533Z&X-Amz-Expires=300&X-Amz-Signature=151fe4da521f9f48ff245ef5bd7c6964f1dfc652be0fe8de4c151ba59e87d2d6&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188


</td>
  <td>

https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311844-28d6abe3-d5a8-4a7f-9f4d-3cc8411affba.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122544Z&X-Amz-Expires=300&X-Amz-Signature=39a259ca76a12a57d47ad74c8cf92af12af9e3db34084f5f82c79b6a62356e9a&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188

</td>
</tr>
</table>

### Commandline Usage Parameters
```console
options:
  -h, --help            show this help message and exit
  --prompt PROMPT       prompt for audio generation
  --nprompt NPROMPT     negative prompt for audio generation
  --seed SEED           random seed
  --temporal_align TEMPORAL_ALIGN
                        use temporal adapter or not
  --temporal_scale TEMPORAL_SCALE
                        temporal align scale
  --semantic_scale SEMANTIC_SCALE
                        visual content scale
  --input INPUT         input video folder path
  --ckpt CKPT           checkpoints folder path
  --save_dir SAVE_DIR   generation result save path
  --pretrain PRETRAIN   generator checkpoint path
  --device DEVICE
```
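The option list above can be sketched as an `argparse` parser. This is an illustration, not the repository's actual parser: the default values here are guesses, and the real `--temporal_align` takes a value (per the help text above) whereas this sketch uses a simple boolean flag:

```python
from argparse import ArgumentParser


def build_parser():
    """Sketch of a parser for the options listed above; defaults are illustrative."""
    p = ArgumentParser(description="FoleyCrafter inference (sketch)")
    p.add_argument("--prompt", type=str, default="", help="prompt for audio generation")
    p.add_argument("--nprompt", type=str, default="", help="negative prompt for audio generation")
    p.add_argument("--seed", type=int, default=42, help="random seed")
    p.add_argument("--temporal_align", action="store_true", help="use the temporal adapter")
    p.add_argument("--temporal_scale", type=float, default=0.2, help="temporal align scale")
    p.add_argument("--semantic_scale", type=float, default=1.0, help="visual content scale")
    p.add_argument("--input", type=str, default="input/sora", help="input video folder path")
    p.add_argument("--ckpt", type=str, default="checkpoints/", help="checkpoints folder path")
    p.add_argument("--save_dir", type=str, default="output/", help="generation result save path")
    p.add_argument("--pretrain", type=str, default=None, help="generator checkpoint path")
    p.add_argument("--device", type=str, default="cuda")
    return p


args = build_parser().parse_args(["--temporal_align", "--prompt", "seagulls", "--seed", "7"])
print(args.temporal_align, args.prompt, args.seed)
```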


## BibTeX
```
@misc{zhang2024foleycrafter,
  title={FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds},
  author={Yiming Zhang and Yicheng Gu and Yanhong Zeng and Zhening Xing and Yuancheng Wang and Zhizheng Wu and Kai Chen},
  year={2024},
  eprint={2407.01494},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```


## Contact Us

**Yiming Zhang**: zhangyiming@pjlab.org.cn

**Yicheng Gu**: yichenggu@link.cuhk.edu.cn

**Yanhong Zeng**: zengyanhong@pjlab.org.cn

## LICENSE
Please check [LICENSE](./LICENSE) for details on the FoleyCrafter code.
If you use it for commercial purposes, please also check the license of [Auffusion](https://github.com/happylittlecat2333/Auffusion).

## Acknowledgements
The code is built upon [Auffusion](https://github.com/happylittlecat2333/Auffusion), [CondFoleyGen](https://github.com/XYPB/CondFoleyGen) and [SpecVQGAN](https://github.com/v-iashin/SpecVQGAN).

We also recommend [Amphion](https://github.com/open-mmlab/Amphion), a toolkit for audio, music, and speech generation :gift_heart:.


================================================
FILE: app.py
================================================
import os
import os.path as osp
import random
from argparse import ArgumentParser
from datetime import datetime

import gradio as gr
import soundfile as sf
import torch
import torchvision
from huggingface_hub import snapshot_download
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from foleycrafter.models.onset import torch_utils
from foleycrafter.models.time_detector.model import VideoOnsetNet
from foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram
from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy


os.environ["GRADIO_TEMP_DIR"] = "./tmp"

sample_idx = 0
scheduler_dict = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}

css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

parser = ArgumentParser()
parser.add_argument("--config", type=str, default="example/config/base.yaml")
parser.add_argument("--server-name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true")

parser.add_argument("--save-path", default="samples")
parser.add_argument("--ckpt", type=str, default="checkpoints/")

args = parser.parse_args()


N_PROMPT = ""


class FoleyController:
    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.model_dir = os.path.join(self.basedir, args.ckpt)
        self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.pipeline = None

        self.loaded = False

        self.load_model()

    def load_model(self):
        gr.Info("Loading models...")
        print("Loading models...")

        # download ckpt
        pretrained_model_name_or_path = "auffusion/auffusion-full-no-adapter"
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(
                pretrained_model_name_or_path, local_dir=osp.join(self.model_dir, "auffusion")
            )

        fc_ckpt = "ymzhang319/FoleyCrafter"
        if not os.path.isdir(fc_ckpt):
            fc_ckpt = snapshot_download(fc_ckpt, local_dir=self.model_dir)

        # set model config
        temporal_ckpt_path = osp.join(self.model_dir, "temporal_adapter.ckpt")

        # load vocoder
        vocoder_config_path = osp.join(self.model_dir, "auffusion")
        self.vocoder = Generator.from_pretrained(vocoder_config_path, subfolder="vocoder")

        # load time detector
        time_detector_ckpt = osp.join(self.model_dir, "timestamp_detector.pth.tar")
        time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)

        self.pipeline = build_foleycrafter()
        ckpt = torch.load(temporal_ckpt_path, map_location="cpu")

        # load temporal adapter
        if "state_dict" in ckpt.keys():
            ckpt = ckpt["state_dict"]
        load_gligen_ckpt = {}
        for key, value in ckpt.items():
            if key.startswith("module."):
                load_gligen_ckpt[key[len("module.") :]] = value
            else:
                load_gligen_ckpt[key] = value
        m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter", subfolder="models/image_encoder"
        )

        self.pipeline.load_ip_adapter(
            fc_ckpt, subfolder="semantic", weight_name="semantic_adapter.bin", image_encoder_folder=None
        )

        gr.Info("Models loaded!")
        print("Models loaded!")
        self.loaded = True

        return "Load"

    def foley(
        self,
        input_video,
        prompt_textbox,
        negative_prompt_textbox,
        ip_adapter_scale,
        temporal_scale,
        sampler_dropdown,
        sample_step_slider,
        cfg_scale_slider,
        seed_textbox,
    ):
        device = "cuda"
        # move to gpu
        self.time_detector = self.time_detector.to(device)
        self.pipeline = self.pipeline.to(device)
        self.vocoder = self.vocoder.to(device)
        self.image_encoder = self.image_encoder.to(device)
        vision_transform_list = [
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.CenterCrop((112, 112)),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        video_transform = torchvision.transforms.Compose(vision_transform_list)
        # if not self.loaded:
        #     raise gr.Error("Error with loading model")
        generator = torch.Generator()
        if seed_textbox != "":
            torch.manual_seed(int(seed_textbox))
            generator.manual_seed(int(seed_textbox))
        max_frame_nums = 150
        frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
        if duration >= 10:
            duration = 10
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(device)
        time_frames = video_transform(time_frames)
        time_frames = {"frames": time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
        preds = self.time_detector(time_frames)
        preds = torch.sigmoid(preds)

        # map per-frame onset predictions onto the spectrogram time axis
        # (10 s of audio corresponds to 1024 spectrogram columns), thresholding at 0.5
        time_condition = [
            -1 if preds[0][int(i / (1024 / 10 * duration) * max_frame_nums)] < 0.5 else 1
            for i in range(int(1024 / 10 * duration))
        ]
        time_condition = time_condition + [-1] * (1024 - len(time_condition))
        # 1D time axis w -> spectrogram-shaped condition (b, c, h, w)
        time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)

        # Note that CLIP needs fewer frames
        frames = frames[::10]
        images = self.image_processor(images=frames, return_tensors="pt").to(device)
        image_embeddings = self.image_encoder(**images).image_embeds
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)
        self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
        sample = self.pipeline(
            prompt=prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            ip_adapter_image_embeds=image_embeddings,
            image=time_condition,
            controlnet_conditioning_scale=float(temporal_scale),
            num_inference_steps=sample_step_slider,
            height=256,
            width=1024,
            output_type="pt",
            generator=generator,
        )
        name = "output"
        audio_img = sample.images[0]
        audio = denormalize_spectrogram(audio_img)
        audio = self.vocoder.inference(audio, lengths=160000)[0]
        audio_save_path = osp.join(self.savedir_sample, "audio")
        os.makedirs(audio_save_path, exist_ok=True)
        audio = audio[: int(duration * 16000)]

        save_path = osp.join(audio_save_path, f"{name}.wav")
        sf.write(save_path, audio, 16000)

        audio = AudioFileClip(osp.join(audio_save_path, f"{name}.wav"))
        video = VideoFileClip(input_video)
        audio = audio.subclip(0, duration)
        video.audio = audio
        video = video.subclip(0, duration)
        video.write_videofile(osp.join(self.savedir_sample, f"{name}.mp4"))
        save_sample_path = os.path.join(self.savedir_sample, f"{name}.mp4")

        return save_sample_path


controller = FoleyController()
device = "cuda" if torch.cuda.is_available() else "cpu"

with gr.Blocks(css=css) as demo:
    gr.HTML(
        '<h1 style="height: 136px; display: flex; align-items: center; justify-content: space-around;"><span style="height: 100%; width:136px;"><img src="file/assets/foleycrafter.png" alt="logo" style="height: 100%; width:auto; object-fit: contain; margin: 0px 0px; padding: 0px 0px;"></span><strong style="font-size: 36px;">FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</strong></h1>'
    )
    gr.HTML(
        '<p id="authors" style="text-align:center; font-size:24px;"> \
        <a href="https://github.com/ymzhang0319">Yiming Zhang</a><sup>1</sup>,&nbsp \
        <a href="https://github.com/VocodexElysium">Yicheng Gu</a><sup>2</sup>,&nbsp \
        <a href="https://zengyh1900.github.io/">Yanhong Zeng</a><sup>1 †</sup>,&nbsp \
        <a href="https://github.com/LeoXing1996/">Zhening Xing</a><sup>1</sup>,&nbsp \
        <a href="https://github.com/HeCheng0625">Yuancheng Wang</a><sup>2</sup>,&nbsp \
        <a href="https://drwuz.com/">Zhizheng Wu</a><sup>2</sup>,&nbsp \
        <a href="https://chenkai.site/">Kai Chen</a><sup>1 †</sup>\
        <br>\
        <span>\
            <sup>1</sup>Shanghai AI Laboratory &nbsp;&nbsp;&nbsp;\
            <sup>2</sup>Chinese University of Hong Kong, Shenzhen &nbsp;&nbsp;&nbsp;\
            †Corresponding author\
        </span>\
    </p>'
    )
    with gr.Row():
        gr.Markdown(
            "<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a> &ensp;"  # noqa
            "<a href='https://arxiv.org/abs/2407.01494/'>Paper</a> &ensp;"
            "<a href='https://github.com/open-mmlab/foleycrafter'>Code</a> &ensp;"
            "<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>"
        )

    with gr.Column(variant="panel"):
        with gr.Row(equal_height=False):
            with gr.Column():
                with gr.Row():
                    init_img = gr.Video(label="Input Video")
                with gr.Row():
                    prompt_textbox = gr.Textbox(value="", label="Prompt", lines=1)
                with gr.Row():
                    negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)

                with gr.Row():
                    ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
                    temporal_scale = gr.Slider(label="Temporal Align Scale", value=0.2, minimum=0.0, maximum=1.0)

                with gr.Accordion("Sampling Settings", open=False):
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(
                            label="Sampling method",
                            choices=list(scheduler_dict.keys()),
                            value=list(scheduler_dict.keys())[0],
                        )
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=25, minimum=10, maximum=100, step=1
                        )
                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)

                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=42)
                    seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
                seed_button.click(fn=lambda: random.randint(1, int(1e8)), outputs=[seed_textbox], queue=False)

                generate_button = gr.Button(value="Generate", variant="primary")

            with gr.Column():
                result_video = gr.Video(label="Generated Audio", interactive=False)
                with gr.Row():
                    gr.Markdown(
                        "<div style='word-spacing: 6px;'><font size='5'><b>Tips</b>: <br> \
                        1. With strong temporal visual cues in input video, you can scale up the <b>Temporal Align Scale</b>. <br>\
                        2. <b>Visual content scale</b> is the level of semantic alignment with visual content.</font></div> \
                    "
                    )

        generate_button.click(
            fn=controller.foley,
            inputs=[
                init_img,
                prompt_textbox,
                negative_prompt_textbox,
                ip_adapter_scale,
                temporal_scale,
                sampler_dropdown,
                sample_step_slider,
                cfg_scale_slider,
                seed_textbox,
            ],
            outputs=[result_video],
        )

        gr.Examples(
            examples=[
                ["examples/gen3/case1.mp4", "", "", 1.0, 0.2, "DDIM", 25, 7.5, 33817921],
                ["examples/gen3/case3.mp4", "", "", 1.0, 0.2, "DDIM", 25, 7.5, 94667578],
                ["examples/gen3/case5.mp4", "", "", 0.75, 0.2, "DDIM", 25, 7.5, 92890876],
                ["examples/gen3/case6.mp4", "", "", 1.0, 0.2, "DDIM", 25, 7.5, 77015909],
            ],
            inputs=[
                init_img,
                prompt_textbox,
                negative_prompt_textbox,
                ip_adapter_scale,
                temporal_scale,
                sampler_dropdown,
                sample_step_slider,
                cfg_scale_slider,
                seed_textbox,
            ],
            cache_examples=True,
            outputs=[result_video],
            fn=controller.foley,
        )

    demo.queue(10)
    demo.launch(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
        allowed_paths=["./assets/foleycrafter.png"],
    )


================================================
FILE: foleycrafter/data/__init__.py
================================================
from .dataset import AudioSetStrong, CPU_Unpickler, VGGSound, dynamic_range_compression, get_mel, zero_rank_print
from .video_transforms import (
    CenterCropVideo,
    KineticsRandomCropResizeVideo,
    NormalizeVideo,
    RandomHorizontalFlipVideo,
    TemporalRandomCrop,
    ToTensorVideo,
    UCFCenterCropVideo,
)


__all__ = [
    "zero_rank_print",
    "get_mel",
    "dynamic_range_compression",
    "CPU_Unpickler",
    "AudioSetStrong",
    "VGGSound",
    "UCFCenterCropVideo",
    "KineticsRandomCropResizeVideo",
    "CenterCropVideo",
    "NormalizeVideo",
    "ToTensorVideo",
    "RandomHorizontalFlipVideo",
    "TemporalRandomCrop",
]


================================================
FILE: foleycrafter/data/dataset.py
================================================
import glob
import io
import pickle
import random

import numpy as np
import torch
import torch.distributed as dist
import torchaudio
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset


def zero_rank_print(s):
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
        print("### " + s, flush=True)


@torch.no_grad()
def get_mel(audio_data, audio_cfg):
    # mel shape: (n_mels, T)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg["sample_rate"],
        n_fft=audio_cfg["window_size"],
        win_length=audio_cfg["window_size"],
        hop_length=audio_cfg["hop_size"],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=64,
        f_min=audio_cfg["fmin"],
        f_max=audio_cfg["fmax"],
    ).to(audio_data.device)
    mel = mel(audio_data)
    # we use log mel spectrogram as input
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel  # (n_mels, T)


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)
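# Illustrative sketch (not part of the original file): with the default
# clip_val=1e-5, silent bins are floored before the log, so outputs are
# bounded below by log(1e-5) ~= -11.51. For example:
#   dynamic_range_compression(torch.tensor([0.0, 1.0]))  # ~= [-11.5129, 0.0]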


class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        else:
            return super().find_class(module, name)
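# Usage sketch (illustrative; the file name is hypothetical): deserialize a
# pickle that contains CUDA tensors on a CPU-only machine.
#   with open("features.pkl", "rb") as f:
#       features = CPU_Unpickler(f).load()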


class AudioSetStrong(Dataset):
    # read feature and audio
    def __init__(
        self,
        data_path="data/AudioSetStrong/train/feature",
        video_path="data/AudioSetStrong/train/video",
    ):
        super().__init__()
        self.data_path = data_path
        # each .pt file under `data_path` is expected to hold a pre-extracted
        # feature dict (keys: "mel", "audio_info", "text_embeds")
        self.data_list = [torch.load(p, map_location="cpu") for p in sorted(glob.glob(f"{self.data_path}/*.pt"))]
        self.length = len(self.data_list)
        # get video feature
        self.video_path = video_path
        vision_transform_list = [
            transforms.Resize((128, 128)),
            transforms.CenterCrop((112, 112)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.video_transform = transforms.Compose(vision_transform_list)

    def get_batch(self, idx):
        embeds = self.data_list[idx]
        mel = embeds["mel"]
        save_bsz = mel.shape[0]
        audio_info = embeds["audio_info"]
        text_embeds = embeds["text_embeds"]

        # audio_info['label_list'] = np.array(audio_info['label_list'])
        audio_info_array = np.array(audio_info["label_list"])
        prompts = []
        for i in range(save_bsz):
            prompts.append(", ".join(audio_info_array[i, : audio_info["event_num"][i]].tolist()))

        return mel, audio_info, text_embeds, prompts

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                mel, audio_info, text_embeds, prompts = self.get_batch(idx)
                break
            except Exception:
                zero_rank_print(" >>> load error <<<")
                idx = random.randint(0, self.length - 1)
        sample = {
            "mel": mel,
            "audio_info": audio_info,
            "text_embeds": text_embeds,
            "prompts": prompts,
        }
        return sample


class VGGSound(Dataset):
    # read feature and audio
    def __init__(
        self,
        data_path="data/VGGSound/train/video",
        visual_data_path="data/VGGSound/train/feature",
    ):
        super().__init__()
        self.data_path = data_path
        self.visual_data_path = visual_data_path
        self.embeds_list = glob.glob(f"{self.data_path}/*.pt")
        self.visual_list = glob.glob(f"{self.visual_data_path}/*.pt")
        self.length = len(self.embeds_list)

    def get_batch(self, idx):
        embeds = torch.load(self.embeds_list[idx], map_location="cpu")
        visual_embeds = torch.load(self.visual_list[idx], map_location="cpu")

        # audio_embeds  = embeds['audio_embeds']
        visual_embeds = visual_embeds["visual_embeds"]
        # video_name = embeds["video_name"]
        text = embeds["text"]
        mel = embeds["mel"]

        audio = mel

        return visual_embeds, audio, text

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                visual_embeds, audio, text = self.get_batch(idx)
                break
            except Exception:
                zero_rank_print("load error")
                idx = random.randint(0, self.length - 1)
        sample = {"visual_embeds": visual_embeds, "audio": audio, "text": text}
        return sample


================================================
FILE: foleycrafter/data/video_transforms.py
================================================
import numbers
import random

import torch


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]


def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)


def resize_scale(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    _, _, H, W = clip.shape
    scale_ = target_size[0] / min(H, W)
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def random_shift_crop(clip):
    """
    Slide along the long edge, with the short edge as crop size
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)

    if h <= w:
        # long_edge = w
        short_edge = h
    else:
        # long_edge = h
        short_edge = w

    th, tw = short_edge, short_edge

    i = torch.randint(0, h - th + 1, size=(1,)).item()
    j = torch.randint(0, w - tw + 1, size=(1,)).item()
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float and divide values by 255.0
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    # return clip.float().permute(3, 0, 1, 2) / 255.0
    return clip.float() / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return clip.flip(-1)


class RandomCropVideo:
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        i, j, h, w = self.get_params(clip)
        return crop(clip, i, j, h, w)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        th, tw = self.size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()

        return i, j, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class UCFCenterCropVideo:
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class KineticsRandomCropResizeVideo:
    """
    Slide along the long edge, with the short edge as crop size, then resize to the desired size.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        clip_random_crop = random_shift_crop(clip)
        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
        return clip_resize


class CenterCropVideo:
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"


class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float and divide values by 255.0
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__


class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W)
        """
        if random.random() < self.p:
            clip = hflip(clip)
        return clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


#  ------------------------------------------------------------
#  ---------------------  Sampling  ---------------------------
#  ------------------------------------------------------------
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired number of frames to be seen by the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        rand_end = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index


if __name__ == "__main__":
    import os

    import numpy as np
    import torchvision.io as io
    from torchvision import transforms
    from torchvision.utils import save_image

    vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")

    trans = transforms.Compose(
        [
            ToTensorVideo(),
            RandomHorizontalFlipVideo(),
            UCFCenterCropVideo(512),
            # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # print(start_frame_ind)
    # print(end_frame_ind)
    assert end_frame_ind - start_frame_ind >= target_video_len
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)

    select_vframes = vframes[frame_indice]

    select_vframes_trans = trans(select_vframes)

    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)

    io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)

    os.makedirs("./test000", exist_ok=True)  # ensure the output directory exists
    for i in range(target_video_len):
        save_image(
            select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
        )


================================================
FILE: foleycrafter/models/adapters/attention_processor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapater.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0WithProjection(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.before_proj_size = 1024
        self.after_proj_size = 768
        self.visual_proj = nn.Linear(self.before_proj_size, self.after_proj_size)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states
        # encoder_hidden_states = self.visual_proj(encoder_hidden_states)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        with torch.no_grad():
            # parentheses matter: compute the score matrix (query @ key^T)
            # first, then softmax over the key dimension
            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
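The split at `end_pos` assumes the image (IP) tokens were concatenated after the text tokens, as `IPAdapter.forward` does. A minimal sketch with hypothetical lengths, assuming a 77-token text context and `num_tokens=4`:

```python
# Hypothetical sequence: 77 text tokens followed by 4 image (IP) tokens.
num_tokens = 4
seq = [f"txt{i}" for i in range(77)] + [f"img{i}" for i in range(num_tokens)]

end_pos = len(seq) - num_tokens  # same split as in IPAttnProcessor
text_tokens, ip_tokens = seq[:end_pos], seq[end_pos:]

assert len(text_tokens) == 77
assert ip_tokens == ["img0", "img1", "img2", "img3"]
```

Text tokens go through the frozen `to_k`/`to_v` projections, while the trailing IP tokens feed the adapter's `to_k_ip`/`to_v_ip`.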


## for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


================================================
FILE: foleycrafter/models/adapters/ip_adapter.py
================================================
import torch
import torch.nn as nn


class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
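A minimal sketch of the projection-and-reshape above, using hypothetical dimensions (a CLIP image embedding of dim 1024 mapped to 4 context tokens of dim 768):

```python
import torch

# Hypothetical sizes: batch 2, CLIP dim 1024, 4 extra tokens, cross-attn dim 768.
B, clip_dim, tokens, cross_dim = 2, 1024, 4, 768
proj = torch.nn.Linear(clip_dim, tokens * cross_dim)
norm = torch.nn.LayerNorm(cross_dim)

image_embeds = torch.randn(B, clip_dim)
# (B, clip_dim) -> (B, tokens * cross_dim) -> (B, tokens, cross_dim)
ctx = norm(proj(image_embeds).reshape(-1, tokens, cross_dim))
assert ctx.shape == (B, tokens, cross_dim)
```

The resulting `(B, tokens, cross_dim)` tensor is what gets concatenated to the text encoder states before cross-attention.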


class MLPProjModel(torch.nn.Module):
    """MLP projection model for image prompt embeddings (SD with image prompt)"""

    @staticmethod
    def zero_initialize(module):
        for param in module.parameters():
            param.data.zero_()

    @staticmethod
    def zero_initialize_last_layer(module):
        # find the last nn.Linear layer and zero its weights and bias
        last_layer = None
        for _module_name, layer in module.named_modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer

        if last_layer is not None:
            last_layer.weight.data.zero_()
            last_layer.bias.data.zero_()

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )
        # zero initialize the last layer
        # self.zero_initialize_last_layer()

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class V2AMapperMLP(torch.nn.Module):
    def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
        super().__init__()
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class TimeProjModel(torch.nn.Module):
    def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums: int = 64):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.position_dim = frame_nums

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        elif feature_type == "text-image":
            self.linears_text = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.linears_image = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
    ):
        masks = masks.unsqueeze(-1)

        # # embedding position (it may includes padding as placeholder)
        # xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C

        # # learnable null embedding
        # xyxy_null = self.null_position_feature.view(1, 1, -1)

        # # replace padding with learnable null embedding
        # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        time_embeds = boxes

        # positionet with text only information
        if positive_embeddings is not None:
            # learnable null embedding
            positive_null = self.null_positive_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))

        # positionet with text and image information
        else:
            raise NotImplementedError

        return objs
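The forward pass above replaces padded positions with the learnable null embedding via the mask. A minimal, self-contained sketch of that masking pattern (the tensors and sizes here are hypothetical stand-ins for `positive_embeddings` and `self.null_positive_feature`):

```python
import torch

# hypothetical batch: 1 sample, 3 tokens, feature dim 4; token index 2 is padding
positive = torch.ones(1, 3, 4)
masks = torch.tensor([[1.0, 1.0, 0.0]]).unsqueeze(-1)  # (B, N, 1), as in masks.unsqueeze(-1) above
null = torch.full((1, 1, 4), 9.0)                      # stands in for the learnable null embedding

# valid tokens keep their features; padded tokens take the null embedding
mixed = positive * masks + (1 - masks) * null
```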


================================================
FILE: foleycrafter/models/adapters/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # reshape makes the transposed tensor contiguous; shape stays (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
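The `1 / sqrt(sqrt(dim_head))` factor above is applied to both `q` and `k`, so their product equals the standard `q @ k^T / sqrt(dim_head)` while keeping intermediate magnitudes smaller, which is friendlier to f16. A sketch verifying the equivalence (random tensors, illustrative shapes):

```python
import math
import torch

dim_head = 64
q = torch.randn(2, 4, 10, dim_head)
k = torch.randn(2, 4, 12, dim_head)

scale = 1 / math.sqrt(math.sqrt(dim_head))
split_scaled = (q * scale) @ (k * scale).transpose(-2, -1)       # scale applied to each operand
direct_scaled = (q @ k.transpose(-2, -1)) / math.sqrt(dim_head)  # conventional scaling
```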


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
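`masked_mean` averages only over unmasked positions. A dependency-free restatement (with `rearrange(mask, "b n -> b n 1")` replaced by `unsqueeze(-1)`, and an explicit float cast added for clarity) plus a worked value:

```python
import torch

def masked_mean_sketch(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)
    denom = mask.sum(dim=dim, keepdim=True).float().clamp(min=1e-5)  # count of valid tokens
    masked_t = t.masked_fill(~mask.unsqueeze(-1), 0.0)               # zero out masked positions
    return masked_t.sum(dim=dim) / denom

t = torch.tensor([[[1.0], [3.0], [100.0]]])    # (b=1, n=3, d=1)
mask = torch.tensor([[True, True, False]])     # last token is ignored
out = masked_mean_sketch(t, dim=1, mask=mask)  # mean of 1.0 and 3.0
```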


================================================
FILE: foleycrafter/models/adapters/transformer.py
================================================
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint


class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
        super().__init__()
        self.embed_dim = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim

        self.scale = self.head_dim**-0.5
        self.dropout = attention_dropout

        self.inner_dim = self.head_dim * self.num_heads

        self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights has to be reshaped
            # twice and has to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class MLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, mult=4):
        super().__init__()
        self.activation_fn = nn.SiLU()
        self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)
        self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Transformer(nn.Module):
    def __init__(self, depth=12):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock() for _ in range(depth)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )

        return hidden_states


class TransformerBlock(nn.Module):
    def __init__(
        self,
        hidden_size=512,
        num_attention_heads=12,
        attention_head_dim=64,
        attention_dropout=0.0,
        dropout=0.0,
        eps=1e-5,
    ):
        super().__init__()
        self.embed_dim = hidden_size
        self.self_attn = Attention(
            hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs[0]


class DiffusionTransformerBlock(nn.Module):
    def __init__(
        self,
        hidden_size=512,
        num_attention_heads=12,
        attention_head_dim=64,
        attention_dropout=0.0,
        dropout=0.0,
        eps=1e-5,
    ):
        super().__init__()
        self.embed_dim = hidden_size
        self.self_attn = Attention(
            hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.output_token = nn.Parameter(torch.randn(1, hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        output_token = self.output_token.unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)
        hidden_states = torch.cat([output_token, hidden_states], dim=1)
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs[0][:, 0:1, ...]
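`DiffusionTransformerBlock` prepends a learnable readout token and returns only that token's final state, a CLS-token-style pooling. The concatenate-then-slice mechanics in isolation (sizes are arbitrary; `output_token` stands in for the `nn.Parameter`):

```python
import torch

batch, seq, dim = 2, 7, 16
output_token = torch.zeros(1, dim)            # stands in for self.output_token
hidden = torch.randn(batch, seq, dim)

tok = output_token.unsqueeze(0).repeat(batch, 1, 1)  # (batch, 1, dim)
hidden = torch.cat([tok, hidden], dim=1)             # (batch, 1 + seq, dim)
readout = hidden[:, 0:1, ...]                        # only the prepended token survives
```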


class V2AMapperMLP(nn.Module):
    def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim * expansion_rate)
        self.silu = nn.SiLU()
        self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)
        self.linear2 = nn.Linear(input_dim * expansion_rate, output_dim)

    def forward(self, x):
        x = self.linear(x)
        x = self.silu(x)
        x = self.layer_norm(x)
        x = self.linear2(x)

        return x


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.zero_initialize_last_layer()

    def zero_initialize_last_layer(self):
        # zero-initialize the final Linear layer so its output starts at zero
        last_layer = None
        for _module_name, layer in self.named_modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer

        if last_layer is not None:
            last_layer.weight.data.zero_()
            last_layer.bias.data.zero_()

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class VisionAudioAdapter(torch.nn.Module):
    def __init__(
        self,
        embedding_size=768,
        expand_dim=4,
        token_num=4,
    ):
        super().__init__()

        self.mapper = V2AMapperMLP(
            embedding_size,
            embedding_size,
            expansion_rate=expand_dim,
        )

        self.proj = ImageProjModel(
            cross_attention_dim=embedding_size,
            clip_embeddings_dim=embedding_size,
            clip_extra_context_tokens=token_num,
        )

    def forward(self, image_embeds):
        image_embeds = self.mapper(image_embeds)
        image_embeds = self.proj(image_embeds)
        return image_embeds
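`ImageProjModel` turns one global embedding into several cross-attention context tokens with a single linear layer and a reshape. A standalone sketch of that expansion (dimensions mirror the defaults but are otherwise arbitrary):

```python
import torch

clip_dim, cross_dim, n_tokens = 768, 768, 4
proj = torch.nn.Linear(clip_dim, n_tokens * cross_dim)
norm = torch.nn.LayerNorm(cross_dim)

image_embeds = torch.randn(2, clip_dim)  # one global embedding per image
tokens = norm(proj(image_embeds).reshape(-1, n_tokens, cross_dim))  # (2, 4, 768) context tokens
```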


================================================
FILE: foleycrafter/models/adapters/utils.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


attn_maps = {}


def hook_fn(name):
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook


def register_cross_attention_hook(unet):
    for name, module in unet.named_modules():
        if name.split(".")[-1].startswith("attn2"):
            module.register_forward_hook(hook_fn(name))

    return unet


def upscale(attn_map, target_size):
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    for i in range(0, 5):
        scale = 2**i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32), size=target_size, mode="bilinear", align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps


def attnmaps2images(net_attn_maps):
    # total_attn_scores = 0
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        # total_attn_scores += attn_map.mean().item()

        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        # print("norm: ", normalized_attn_map.shape)
        image = Image.fromarray(normalized_attn_map)

        # image = fix_save_attn_map(attn_map)
        images.append(image)

    # print(total_attn_scores)
    return images


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


================================================
FILE: foleycrafter/models/auffusion/attention.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
from diffusers.utils import USE_PEFT_BACKEND
from diffusers.utils.torch_utils import maybe_allow_in_graph
from foleycrafter.models.auffusion.attention_processor import Attention


def _chunked_feed_forward(
    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    if lora_scale is None:
        ff_output = torch.cat(
            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
            dim=chunk_dim,
        )
    else:
        # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
        ff_output = torch.cat(
            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
            dim=chunk_dim,
        )

    return ff_output


@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that combines visual features and object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # we need a linear projection since we need cat visual feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_continuous:
            self.norm1 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if (only_cross_attention and not double_self_attention) else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
            # the second cross attention block.
            if self.use_ada_layer_norm:
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif self.use_ada_layer_norm_continuous:
                self.norm2 = AdaLayerNormContinuous(
                    dim,
                    ada_norm_continous_conditioning_embedding_dim,
                    norm_elementwise_affine,
                    norm_eps,
                    ada_norm_bias,
                    "rms_norm",
                )
            else:
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if self.use_ada_layer_norm_continuous:
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )
        elif not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.use_ada_layer_norm_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.use_ada_layer_norm_continuous:
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and not self.use_ada_layer_norm_single:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
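
The `_chunked_feed_forward` path above exists purely to cap peak activation memory: because the feed-forward acts on every position independently, splitting the input along `self._chunk_dim` and concatenating the chunk outputs is mathematically identical to one full pass. A minimal pure-Python sketch (the `ff` here is a hypothetical stand-in, not the module above):

```python
def ff(tokens):
    # hypothetical per-token feed-forward: any function applied elementwise
    return [2 * t + 1 for t in tokens]

def chunked_feed_forward(ff, tokens, chunk_size):
    # apply `ff` chunk-by-chunk along the sequence and concatenate;
    # exact because `ff` never mixes information across positions
    out = []
    for start in range(0, len(tokens), chunk_size):
        out.extend(ff(tokens[start:start + chunk_size]))
    return out

assert chunked_feed_forward(ff, list(range(10)), 3) == ff(list(range(10)))
```

The trade-off is extra kernel launches for a smaller peak intermediate tensor, which is why the chunk size defaults to `None` (off).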


@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        self.is_res = dim == time_mix_inner_dim

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)

        return hidden_states
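
The reshape/permute dance at the top and bottom of `forward` swaps which axis attention mixes over: rows arrive ordered `(batch * frames, seq_len)` and are regrouped to `(batch * seq_len, frames)` so that `attn1` attends across time at each spatial position. A shape-only sketch with plain Python lists (scalars standing in for channel vectors, not the actual tensor ops):

```python
def spatial_to_temporal(x, num_frames):
    # (batch*frames, seq) -> (batch*seq, frames)
    batch_size = len(x) // num_frames
    seq_length = len(x[0])
    return [
        [x[b * num_frames + f][s] for f in range(num_frames)]
        for b in range(batch_size)
        for s in range(seq_length)
    ]

def temporal_to_spatial(x, seq_length):
    # inverse: (batch*seq, frames) -> (batch*frames, seq)
    num_frames = len(x[0])
    batch_size = len(x) // seq_length
    return [
        [x[b * seq_length + s][f] for s in range(seq_length)]
        for b in range(batch_size)
        for f in range(num_frames)
    ]

grid = [[(bf, s) for s in range(2)] for bf in range(6)]  # batch=2, frames=3
assert temporal_to_spatial(spatial_to_temporal(grid, 3), 2) == grid
```

The round trip being the identity is what lets the block return hidden states in the same layout it received them.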


class SkipFFTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        kv_input_dim: int,
        kv_input_dim_proj_use_bias: bool,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        if kv_input_dim != dim:
            self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
        else:
            self.kv_mapper = None

        self.norm1 = RMSNorm(dim, 1e-06)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            out_bias=attention_out_bias,
        )

        self.norm2 = RMSNorm(dim, 1e-06)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            **cross_attention_kwargs,
        )

        hidden_states = attn_output + hidden_states

        norm_hidden_states = self.norm2(hidden_states)

        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            **cross_attention_kwargs,
        )

        hidden_states = attn_output + hidden_states

        return hidden_states
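
Unlike the LayerNorm-based blocks above, `SkipFFTransformerBlock` normalizes with `RMSNorm(dim, 1e-06)`, which rescales by the root-mean-square of the features without re-centering them. A scalar-level sketch of the computation (the learned per-channel weight is omitted here for brevity):

```python
import math

def rms_norm(x, eps=1e-6):
    # divide each entry by the RMS of the vector; no mean subtraction
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

out = rms_norm([3.0, 4.0])
# the normalized vector has (approximately) unit mean square
assert abs(sum(v * v for v in out) / len(out) - 1.0) < 1e-5
```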


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
        for module in self.net:
            if isinstance(module, compatible_cls):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states

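
The default `"geglu"` path in `FeedForward` projects to twice the inner width, splits the result, and gates one half with a GELU of the other. A minimal sketch on plain Python lists (a hand-rolled matrix product standing in for the actual `GEGLU` module's linear layer):

```python
import math

def gelu(v):
    # exact erf-based GELU
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))

def geglu(x, weight, bias):
    # weight rows give 2*inner_dim projections; the first half is the
    # value branch and the second half is the gate branch
    proj = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]
    inner = len(proj) // 2
    value, gate = proj[:inner], proj[inner:]
    return [v * gelu(g) for v, g in zip(value, gate)]

# dim=2, inner_dim=1: value = x[0], gate = x[1]
out = geglu([2.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
assert out == [0.0]  # gelu(0) == 0, so the gate zeroes the value
```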

================================================
FILE: foleycrafter/models/auffusion/attention_processor.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from importlib import import_module
from typing import Callable, List, Optional, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import maybe_allow_in_graph


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
            Set to `True` if the attention block is loaded from a deprecated state dict.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
            `AttnProcessor` otherwise.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        if USE_PEFT_BACKEND:
            linear_cls = nn.Linear
        else:
            linear_cls = LoRACompatibleLinear

        self.linear_cls = linear_cls
        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        self.set_processor(processor)
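
With `scale_qk=True`, `self.scale = dim_head**-0.5` is the usual 1/sqrt(d) temperature applied to the query-key dot products before the softmax. A single-head, unmasked pure-Python sketch of what the attention processors ultimately compute (an illustration of the math, not the library's batched implementation):

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def single_head_attention(q, k, v, scale):
    # q: (Tq, d), k: (Tk, d), v: (Tk, dv); weights = softmax(scale * q @ k^T)
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(scores)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out

# identical keys -> uniform weights -> output is the mean of the values
out = single_head_attention([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]],
                            [[1.0, 2.0], [3.0, 4.0]], scale=2 ** -0.5)
assert all(abs(a - b) < 1e-9 for a, b in zip(out[0], [2.0, 3.0]))
```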

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ) -> None:
        r"""
        Set whether to use memory efficient attention from `xformers` or not.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                Whether to use memory efficient attention from `xformers` or not.
            attention_op (`Callable`, *optional*):
                The attention operation to use. Defaults to `None` which uses the default attention operation from
                `xformers`.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor,
            LORA_ATTENTION_PROCESSORS,
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor,
            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
        )
        is_added_kv_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessor2_0,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
                LoRAAttnAddedKVProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                )
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                attn_processor_class = (
                    CustomDiffusionAttnProcessor2_0
                    if hasattr(F, "scaled_dot_product_attention")
                    else CustomDiffusionAttnProcessor
                )
                processor = attn_processor_class(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                # set attention processor
                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
                processor = (
                    AttnProcessor2_0()
                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                    else AttnProcessor()
                )

        self.set_processor(processor)

    def set_attention_slice(self, slice_size: int) -> None:
        r"""
        Set the slice size for attention computation.

        Args:
            slice_size (`int`, *optional*):
                The slice size for attention computation. If `None`, slicing is disabled and a non-sliced processor is
                restored.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            # set attention processor
            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
            _remove_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to remove LoRA layers from the model.
        """
        if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
            deprecate(
                "set_processor to offload LoRA",
                "0.26.0",
                "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
            )
            # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
            # We need to remove all LoRA layers
            # Don't forget to remove ALL `_remove_lora` from the codebase
            for module in self.modules():
                if hasattr(module, "set_lora_layer"):
                    module.set_lora_layer(None)

        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
        # serialization format for LoRA Attention Processors. It should be deleted once the integration
        # with PEFT is completed.
        is_lora_activated = {
            name: module.lora_layer is not None
            for name, module in self.named_modules()
            if hasattr(module, "lora_layer")
        }

        # 1. if no layer has a LoRA activated we can return the processor as usual
        if not any(is_lora_activated.values()):
            return self.processor

        # LoRA is never applied to `add_k_proj` or `add_v_proj`, so ignore them here
        is_lora_activated.pop("add_k_proj", None)
        is_lora_activated.pop("add_v_proj", None)
        # 2. else all remaining layers must have LoRA activated; a mixed state is not supported
        if not all(is_lora_activated.values()):
            raise ValueError(
                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
            )

        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
        non_lora_processor_cls_name = self.processor.__class__.__name__
        lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)

        hidden_size = self.inner_dim

        # now create a LoRA attention processor from the LoRA layers
        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size * heads, seq_len, dim // heads]` (or to
        `[batch_size, heads, seq_len, dim // heads]` when `out_dim` is `4`). `heads` is the number of heads initialized
        while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor
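As an aside, the two reshape helpers above are inverses of one another. Below is a minimal pure-Python sketch of the index bookkeeping (no torch; nested lists stand in for tensors, and `heads=2` is an assumed example value), not the module's implementation:

```python
# Pure-Python sketch of head_to_batch_dim / batch_to_head_dim, assuming heads=2.
# Tensors are modeled as nested lists of shape [batch, seq, dim].
def head_to_batch_dim(x, heads):
    batch, seq, dim = len(x), len(x[0]), len(x[0][0])
    sub = dim // heads
    # [batch, seq, dim] -> [batch * heads, seq, dim // heads]
    return [
        [[x[b][s][h * sub + d] for d in range(sub)] for s in range(seq)]
        for b in range(batch)
        for h in range(heads)
    ]

def batch_to_head_dim(x, heads):
    bh, seq, sub = len(x), len(x[0]), len(x[0][0])
    batch = bh // heads
    # [batch * heads, seq, dim // heads] -> [batch, seq, dim]
    return [
        [[x[b * heads + h][s][d] for h in range(heads) for d in range(sub)]
         for s in range(seq)]
        for b in range(batch)
    ]

x = [[[0, 1, 2, 3], [4, 5, 6, 7]]]       # shape [1, 2, 4]
split = head_to_batch_dim(x, heads=2)    # shape [2, 2, 2]
assert split == [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
assert batch_to_head_dim(split, heads=2) == x  # round-trips back
```

The head axis is folded into the batch axis so that per-head attention can run as a plain batched matmul.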

    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs
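The `torch.baddbmm` call above computes `beta * input + alpha * (query @ key^T)`; with a mask present, the mask is the `input` term (`beta=1`), so the scores are `mask + scale * Q @ K^T`, followed by a row-wise softmax. A hedged pure-Python sketch of that computation for a single batch element (no torch, no dtype upcasting), with made-up example vectors:

```python
import math

def attention_probs(query, key, scale, mask=None):
    """softmax(mask + scale * Q @ K^T), one batch element, lists as tensors."""
    scores = [
        [
            (mask[i][j] if mask is not None else 0.0)
            + scale * sum(q * k for q, k in zip(query[i], key[j]))
            for j in range(len(key))
        ]
        for i in range(len(query))
    ]
    probs = []
    for row in scores:
        m = max(row)                      # subtract max for numerical stability
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        probs.append([e / z for e in exps])
    return probs

probs = attention_probs([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], scale=1.0)
assert abs(sum(probs[0]) - 1.0) < 1e-9   # rows sum to 1
assert probs[0][0] > probs[0][1]         # query aligns with the first key
```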

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask
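A small pure-Python sketch of the two steps performed above for the default `out_dim=3` case: right-padding the mask's last dimension with zeros, then repeating each batch row once per head so the mask matches the `[batch * heads, ...]` layout produced by `head_to_batch_dim`. Note this sketch pads to exactly `target_length` (the intended behavior described in the TODO above); the upstream code currently pads *by* `target_length` instead. `heads=2` is an assumed example value:

```python
def prepare_mask(mask, target_length, heads):
    # (1) right-pad the last dimension out to target_length with zeros
    padded = [row + [0.0] * (target_length - len(row)) for row in mask]
    # (2) repeat_interleave along the batch dimension, once per head
    return [list(row) for row in padded for _ in range(heads)]

mask = [[0.0, -10000.0]]  # one batch element, current length 2
out = prepare_mask(mask, target_length=4, heads=2)
assert out == [[0.0, -10000.0, 0.0, 0.0]] * 2
```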

    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            raise ValueError(f"Unknown cross-attention norm type: {self.norm_cross.__class__.__name__}")

        return encoder_hidden_states

    @torch.no_grad()
    def fuse_projections(self, fuse=True):
        r"""
        Fuse the query/key/value (or key/value for cross-attention) projection weights into a single linear layer so
        they can be computed with one matrix multiplication.

        Args:
            fuse (`bool`, defaults to `True`): Whether to mark the projections as fused.
        """
        is_cross_attention = self.cross_attention_dim != self.query_dim
        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if not is_cross_attention:
            # fetch weight matrices.
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            # create a new single projection layer and copy over the weights.
            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
            self.to_qkv.weight.copy_(concatenated_weights)

        else:
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)

        self.fused_projections = fuse
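The fusion above relies on a basic fact of matrix algebra: concatenating weight matrices along the output (row) dimension yields one matrix whose product with the input equals the individual projections concatenated. A pure-Python sketch with tiny assumed 2x2 weights and no bias terms:

```python
def matvec(w, x):
    """Multiply a weight matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[2.0, 0.0], [0.0, 2.0]]
w_v = [[0.0, 1.0], [1.0, 0.0]]
w_qkv = w_q + w_k + w_v  # mirrors torch.cat along the output (row) dimension

x = [3.0, 4.0]
fused = matvec(w_qkv, x)
q, k, v = fused[:2], fused[2:4], fused[4:]
# one matmul reproduces the three separate projections
assert (q, k, v) == (matvec(w_q, x), matvec(w_k, x), matvec(w_v, x))
```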


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train the query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv: bool = True,
        train_q_out: bool = True,
        hidden_size: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
        else:
            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
            key = key.to(attn.to_q.weight.dtype)
            value = value.to(attn.to_q.weight.dtype)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            # Stop gradients through the key/value of the first (start-of-text) token: the mask is 1 for
            # every token except the first, so only the first token's key/value use the detached copies.
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
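The `detach` mask used in the cross-attention branch above leaves the key/value tensors numerically unchanged and only gates autograd: the first token's key/value come from the detached copies. A pure-Python sketch of the value-level arithmetic (gradients are not modeled here; lists stand in for `[seq, dim]` tensors):

```python
key = [[1.0, 2.0], [3.0, 4.0]]     # assumed [seq, dim] example
detach = [[1.0] * len(row) for row in key]
detach[0] = [0.0] * len(key[0])    # zero the mask for the first token

# mixes k with its (notionally detached) copy; values are identical either way
mixed = [
    [d * k + (1 - d) * k for d, k in zip(drow, krow)]
    for drow, krow in zip(detach, key)
]
assert mixed == key  # values unchanged; only gradient flow would differ
```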


class AttnAddedKVProcessor:
    r"""
    Processor for performing attention-related computations with extra learnable key and value matrices for the text
    encoder.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states, *args)
            value = attn.to_v(hidden_states, *args)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class AttnAddedKVProcessor2_0:
    r"""
    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
    learnable key and value matrices for the text encoder.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)
        query = attn.head_to_batch_dim(query, out_dim=4)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states, *args)
            value = attn.to_v(hidden_states, *args)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
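
# Aside (not part of the repository): every processor in this file ultimately
# evaluates the same softmax(Q K^T / sqrt(d)) V kernel, whether through
# F.scaled_dot_product_attention or xformers. A minimal single-head,
# stdlib-only sketch of that computation, with hypothetical names:

```python
import math


def sdp_attention(query, key, value):
    """softmax(Q K^T / sqrt(head_dim)) V for a single head, on nested lists."""
    head_dim = len(query[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for q in query:
        # scaled dot-product scores of this query against every key
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in key]
        # numerically stable softmax over the key axis
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]
        # attention-weighted sum of the value rows
        out.append([sum(w * v[j] for w, v in zip(weights, value)) for j in range(len(value[0]))])
    return out


# A query aligned with the first key attends almost entirely to the first value row.
attended = sdp_attention([[1.0, 0.0]], [[10.0, 0.0], [0.0, 10.0]], [[1.0, 0.0], [0.0, 1.0]])
```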


class XFormersAttnAddedKVProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class XFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
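
# Aside (not part of the repository): the in-line comment in
# XFormersAttnProcessor.__call__ notes that xformers does not broadcast the
# bias's singleton query dimension, so the processor expands it explicitly.
# A hypothetical, stdlib-only sketch of what that expand does on nested lists:

```python
def expand_attention_mask(mask, query_tokens):
    """Expand a bias of shape [batch*heads, 1, key_tokens] to
    [batch*heads, query_tokens, key_tokens] by repeating the single row,
    mirroring attention_mask.expand(-1, query_tokens, -1)."""
    return [[list(row[0]) for _ in range(query_tokens)] for row in mask]


# One batch*head slot, one query slot, three keys (the middle key is masked out);
# after expansion every query row carries the same key bias.
expanded = expand_attention_mask([[[0.0, -1e9, 0.0]]], 2)
```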


class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

================================================
SYMBOL INDEX (700 symbols across 29 files)
================================================

FILE: app.py
  class FoleyController (line 55) | class FoleyController:
    method __init__ (line 56) | def __init__(self):
    method load_model (line 70) | def load_model(self):
    method foley (line 127) | def foley(

FILE: foleycrafter/data/dataset.py
  function zero_rank_print (line 14) | def zero_rank_print(s):
  function get_mel (line 20) | def get_mel(audio_data, audio_cfg):
  function dynamic_range_compression (line 42) | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=...
  class CPU_Unpickler (line 51) | class CPU_Unpickler(pickle.Unpickler):
    method find_class (line 52) | def find_class(self, module, name):
  class AudioSetStrong (line 59) | class AudioSetStrong(Dataset):
    method __init__ (line 61) | def __init__(
    method get_batch (line 79) | def get_batch(self, idx):
    method __len__ (line 94) | def __len__(self):
    method __getitem__ (line 97) | def __getitem__(self, idx):
  class VGGSound (line 115) | class VGGSound(Dataset):
    method __init__ (line 117) | def __init__(
    method get_batch (line 129) | def get_batch(self, idx):
    method __len__ (line 143) | def __len__(self):
    method __getitem__ (line 146) | def __getitem__(self, idx):

FILE: foleycrafter/data/video_transforms.py
  function _is_tensor_video_clip (line 7) | def _is_tensor_video_clip(clip):
  function crop (line 17) | def crop(clip, i, j, h, w):
  function resize (line 27) | def resize(clip, target_size, interpolation_mode):
  function resize_scale (line 33) | def resize_scale(clip, target_size, interpolation_mode):
  function resized_crop (line 41) | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
  function center_crop (line 61) | def center_crop(clip, crop_size):
  function random_shift_crop (line 74) | def random_shift_crop(clip):
  function to_tensor (line 96) | def to_tensor(clip):
  function normalize (line 112) | def normalize(clip, mean, std, inplace=False):
  function hflip (line 132) | def hflip(clip):
  class RandomCropVideo (line 144) | class RandomCropVideo:
    method __init__ (line 145) | def __init__(self, size):
    method __call__ (line 151) | def __call__(self, clip):
    method get_params (line 162) | def get_params(self, clip):
    method __repr__ (line 177) | def __repr__(self) -> str:
  class UCFCenterCropVideo (line 181) | class UCFCenterCropVideo:
    method __init__ (line 182) | def __init__(
    method __call__ (line 196) | def __call__(self, clip):
    method __repr__ (line 208) | def __repr__(self) -> str:
  class KineticsRandomCropResizeVideo (line 212) | class KineticsRandomCropResizeVideo:
    method __init__ (line 217) | def __init__(
    method __call__ (line 231) | def __call__(self, clip):
  class CenterCropVideo (line 237) | class CenterCropVideo:
    method __init__ (line 238) | def __init__(
    method __call__ (line 252) | def __call__(self, clip):
    method __repr__ (line 263) | def __repr__(self) -> str:
  class NormalizeVideo (line 267) | class NormalizeVideo:
    method __init__ (line 276) | def __init__(self, mean, std, inplace=False):
    method __call__ (line 281) | def __call__(self, clip):
    method __repr__ (line 288) | def __repr__(self) -> str:
  class ToTensorVideo (line 292) | class ToTensorVideo:
    method __init__ (line 298) | def __init__(self):
    method __call__ (line 301) | def __call__(self, clip):
    method __repr__ (line 310) | def __repr__(self) -> str:
  class RandomHorizontalFlipVideo (line 314) | class RandomHorizontalFlipVideo:
    method __init__ (line 321) | def __init__(self, p=0.5):
    method __call__ (line 324) | def __call__(self, clip):
    method __repr__ (line 335) | def __repr__(self) -> str:
  class TemporalRandomCrop (line 342) | class TemporalRandomCrop(object):
    method __init__ (line 349) | def __init__(self, size):
    method __call__ (line 352) | def __call__(self, total_frames):

FILE: foleycrafter/models/adapters/attention_processor.py
  class AttnProcessor (line 11) | class AttnProcessor(nn.Module):
    method __init__ (line 16) | def __init__(
    method __call__ (line 23) | def __call__(
  class IPAttnProcessor (line 84) | class IPAttnProcessor(nn.Module):
    method __init__ (line 98) | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, n...
    method __call__ (line 109) | def __call__(
  class AttnProcessor2_0 (line 191) | class AttnProcessor2_0(torch.nn.Module):
    method __init__ (line 196) | def __init__(
    method __call__ (line 205) | def __call__(
  class AttnProcessor2_0WithProjection (line 280) | class AttnProcessor2_0WithProjection(torch.nn.Module):
    method __init__ (line 285) | def __init__(
    method __call__ (line 297) | def __call__(
  class IPAttnProcessor2_0 (line 373) | class IPAttnProcessor2_0(torch.nn.Module):
    method __init__ (line 387) | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, n...
    method __call__ (line 401) | def __call__(
  class CNAttnProcessor (line 505) | class CNAttnProcessor:
    method __init__ (line 510) | def __init__(self, num_tokens=4):
    method __call__ (line 513) | def __call__(self, attn, hidden_states, encoder_hidden_states=None, at...
  class CNAttnProcessor2_0 (line 570) | class CNAttnProcessor2_0:
    method __init__ (line 575) | def __init__(self, num_tokens=4):
    method __call__ (line 580) | def __call__(

FILE: foleycrafter/models/adapters/ip_adapter.py
  class IPAdapter (line 5) | class IPAdapter(torch.nn.Module):
    method __init__ (line 8) | def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=...
    method forward (line 17) | def forward(self, noisy_latents, timesteps, encoder_hidden_states, ima...
    method load_from_checkpoint (line 24) | def load_from_checkpoint(self, ckpt_path: str):
  class ImageProjModel (line 46) | class ImageProjModel(torch.nn.Module):
    method __init__ (line 49) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method forward (line 57) | def forward(self, image_embeds):
  class MLPProjModel (line 66) | class MLPProjModel(torch.nn.Module):
    method zero_initialize (line 69) | def zero_initialize(module):
    method zero_initialize_last_layer (line 73) | def zero_initialize_last_layer(module):
    method __init__ (line 83) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
    method forward (line 95) | def forward(self, image_embeds):
  class V2AMapperMLP (line 100) | class V2AMapperMLP(torch.nn.Module):
    method __init__ (line 101) | def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, m...
    method forward (line 110) | def forward(self, image_embeds):
  class TimeProjModel (line 115) | class TimeProjModel(torch.nn.Module):
    method __init__ (line 116) | def __init__(self, positive_len, out_dim, feature_type="text-only", fr...
    method forward (line 156) | def forward(

FILE: foleycrafter/models/adapters/resampler.py
  function FeedForward (line 13) | def FeedForward(dim, mult=4):
  function reshape_tensor (line 23) | def reshape_tensor(x, heads):
  class PerceiverAttention (line 34) | class PerceiverAttention(nn.Module):
    method __init__ (line 35) | def __init__(self, *, dim, dim_head=64, heads=8):
    method forward (line 49) | def forward(self, x, latents):
  class Resampler (line 81) | class Resampler(nn.Module):
    method __init__ (line 82) | def __init__(
    method forward (line 127) | def forward(self, x):
  function masked_mean (line 150) | def masked_mean(t, *, dim, mask=None):

FILE: foleycrafter/models/adapters/transformer.py
  class Attention (line 8) | class Attention(nn.Module):
    method __init__ (line 11) | def __init__(self, hidden_size, num_attention_heads, attention_head_di...
    method _shape (line 27) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 30) | def forward(
  class MLP (line 109) | class MLP(nn.Module):
    method __init__ (line 110) | def __init__(self, hidden_size, intermediate_size, mult=4):
    method forward (line 116) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  class Transformer (line 123) | class Transformer(nn.Module):
    method __init__ (line 124) | def __init__(self, depth=12):
    method forward (line 128) | def forward(
  class TransformerBlock (line 156) | class TransformerBlock(nn.Module):
    method __init__ (line 157) | def __init__(
    method forward (line 175) | def forward(
  class DiffusionTransformerBlock (line 216) | class DiffusionTransformerBlock(nn.Module):
    method __init__ (line 217) | def __init__(
    method forward (line 236) | def forward(
  class V2AMapperMLP (line 279) | class V2AMapperMLP(nn.Module):
    method __init__ (line 280) | def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
    method forward (line 287) | def forward(self, x):
  class ImageProjModel (line 296) | class ImageProjModel(torch.nn.Module):
    method __init__ (line 299) | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024,...
    method zero_initialize_last_layer (line 309) | def zero_initialize_last_layer(module):
    method forward (line 319) | def forward(self, image_embeds):
  class VisionAudioAdapter (line 328) | class VisionAudioAdapter(torch.nn.Module):
    method __init__ (line 329) | def __init__(
    method forward (line 349) | def forward(self, image_embeds):

FILE: foleycrafter/models/adapters/utils.py
  function hook_fn (line 10) | def hook_fn(name):
  function register_cross_attention_hook (line 19) | def register_cross_attention_hook(unet):
  function upscale (line 27) | def upscale(attn_map, target_size):
  function get_net_attn_map (line 50) | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=Fals...
  function attnmaps2images (line 65) | def attnmaps2images(net_attn_maps):
  function is_torch2_available (line 85) | def is_torch2_available():

FILE: foleycrafter/models/auffusion/attention.py
  function _chunked_feed_forward (line 29) | def _chunked_feed_forward(
  class GatedSelfAttentionDense (line 55) | class GatedSelfAttentionDense(nn.Module):
    method __init__ (line 66) | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_h...
    method forward (line 83) | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
  class BasicTransformerBlock (line 97) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 132) | def __init__(
    method set_chunk_feed_forward (line 279) | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int =...
    method forward (line 284) | def forward(
  class TemporalBasicTransformerBlock (line 408) | class TemporalBasicTransformerBlock(nn.Module):
    method __init__ (line 420) | def __init__(
    method set_chunk_feed_forward (line 474) | def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
    method forward (line 480) | def forward(
  class SkipFFTransformerBlock (line 538) | class SkipFFTransformerBlock(nn.Module):
    method __init__ (line 539) | def __init__(
    method forward (line 581) | def forward(self, hidden_states, encoder_hidden_states, cross_attentio...
  class FeedForward (line 610) | class FeedForward(nn.Module):
    method __init__ (line 624) | def __init__(
    method forward (line 661) | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> ...

FILE: foleycrafter/models/auffusion/attention_processor.py
  class Attention (line 40) | class Attention(nn.Module):
    method __init__ (line 91) | def __init__(
    method set_use_memory_efficient_attention_xformers (line 216) | def set_use_memory_efficient_attention_xformers(
    method set_attention_slice (line 350) | def set_attention_slice(self, slice_size: int) -> None:
    method set_processor (line 378) | def set_processor(self, processor: "AttnProcessor", _remove_lora: bool...
    method get_processor (line 413) | def get_processor(self, return_deprecated_lora: bool = False) -> "Atte...
    method forward (line 503) | def forward(
    method batch_to_head_dim (line 537) | def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
    method head_to_batch_dim (line 554) | def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) ->...
    method get_attention_scores (line 577) | def get_attention_scores(
    method prepare_attention_mask (line 624) | def prepare_attention_mask(
    method norm_encoder_hidden_states (line 671) | def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tens...
    method fuse_projections (line 701) | def fuse_projections(self, fuse=True):
  class AttnProcessor (line 727) | class AttnProcessor:
    method __call__ (line 732) | def __call__(
  class CustomDiffusionAttnProcessor (line 796) | class CustomDiffusionAttnProcessor(nn.Module):
    method __init__ (line 815) | def __init__(
    method __call__ (line 841) | def __call__(
  class AttnAddedKVProcessor (line 900) | class AttnAddedKVProcessor:
    method __call__ (line 906) | def __call__(
  class AttnAddedKVProcessor2_0 (line 964) | class AttnAddedKVProcessor2_0:
    method __init__ (line 970) | def __init__(self):
    method __call__ (line 976) | def __call__(
  class XFormersAttnAddedKVProcessor (line 1037) | class XFormersAttnAddedKVProcessor:
    method __init__ (line 1049) | def __init__(self, attention_op: Optional[Callable] = None):
    method __call__ (line 1052) | def __call__(
  class XFormersAttnProcessor (line 1108) | class XFormersAttnProcessor:
    method __init__ (line 1120) | def __init__(self, attention_op: Optional[Callable] = None):
    method __call__ (line 1123) | def __call__(
  class AttnProcessor2_0 (line 1199) | class AttnProcessor2_0:
    method __init__ (line 1204) | def __init__(self):
    method __call__ (line 1208) | def __call__(
  class FusedAttnProcessor2_0 (line 1285) | class FusedAttnProcessor2_0:
    method __init__ (line 1298) | def __init__(self):
    method __call__ (line 1304) | def __call__(
  class CustomDiffusionXFormersAttnProcessor (line 1382) | class CustomDiffusionXFormersAttnProcessor(nn.Module):
    method __init__ (line 1405) | def __init__(
    method __call__ (line 1433) | def __call__(
  class CustomDiffusionAttnProcessor2_0 (line 1498) | class CustomDiffusionAttnProcessor2_0(nn.Module):
    method __init__ (line 1518) | def __init__(
    method __call__ (line 1544) | def __call__(
  class SlicedAttnProcessor (line 1612) | class SlicedAttnProcessor:
    method __init__ (line 1622) | def __init__(self, slice_size: int):
    method __call__ (line 1625) | def __call__(
  class SlicedAttnAddedKVProcessor (line 1699) | class SlicedAttnAddedKVProcessor:
    method __init__ (line 1709) | def __init__(self, slice_size):
    method __call__ (line 1712) | def __call__(
  class SpatialNorm (line 1791) | class SpatialNorm(nn.Module):
    method __init__ (line 1802) | def __init__(
    method forward (line 1812) | def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torc...
  class LoRAAttnProcessor (line 1821) | class LoRAAttnProcessor(nn.Module):
    method __init__ (line 1838) | def __init__(
    method __call__ (line 1872) | def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, ...
  class LoRAAttnProcessor2_0 (line 1893) | class LoRAAttnProcessor2_0(nn.Module):
    method __init__ (line 1911) | def __init__(
    method __call__ (line 1947) | def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, ...
  class LoRAXFormersAttnProcessor (line 1968) | class LoRAXFormersAttnProcessor(nn.Module):
    method __init__ (line 1990) | def __init__(
    method __call__ (line 2026) | def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, ...
  class LoRAAttnAddedKVProcessor (line 2047) | class LoRAAttnAddedKVProcessor(nn.Module):
    method __init__ (line 2065) | def __init__(
    method __call__ (line 2085) | def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, ...
  class IPAdapterAttnProcessor (line 2106) | class IPAdapterAttnProcessor(nn.Module):
    method __init__ (line 2121) | def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4...
    method __call__ (line 2132) | def __call__(
  class VPTemporalAdapterAttnProcessor2_0 (line 2216) | class VPTemporalAdapterAttnProcessor2_0(torch.nn.Module):
    method __init__ (line 2238) | def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(...
    method __call__ (line 2266) | def __call__(
  class IPAdapterAttnProcessor2_0 (line 2450) | class IPAdapterAttnProcessor2_0(torch.nn.Module):
    method __init__ (line 2465) | def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(...
    method __call__ (line 2492) | def __call__(

FILE: foleycrafter/models/auffusion/dual_transformer_2d.py
  class DualTransformer2DModel (line 21) | class DualTransformer2DModel(nn.Module):
    method __init__ (line 48) | def __init__(
    method forward (line 97) | def forward(

FILE: foleycrafter/models/auffusion/loaders/ip_adapter.py
  class IPAdapterMixin (line 51) | class IPAdapterMixin:
    method load_ip_adapter (line 55) | def load_ip_adapter(
    method set_ip_adapter_scale (line 235) | def set_ip_adapter_scale(self, scale):
    method unload_ip_adapter (line 257) | def unload_ip_adapter(self):
  class VPAdapterMixin (line 289) | class VPAdapterMixin:
    method load_ip_adapter (line 293) | def load_ip_adapter(
    method set_ip_adapter_scale (line 473) | def set_ip_adapter_scale(self, scale):
    method unload_ip_adapter (line 495) | def unload_ip_adapter(self):

FILE: foleycrafter/models/auffusion/loaders/unet.py
  class VPAdapterImageProjection (line 53) | class VPAdapterImageProjection(nn.Module):
    method __init__ (line 54) | def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Modul...
    method forward (line 58) | def forward(self, image_embeds: List[torch.FloatTensor]):
  class MultiIPAdapterImageProjection (line 88) | class MultiIPAdapterImageProjection(nn.Module):
    method __init__ (line 89) | def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Modul...
    method forward (line 93) | def forward(self, image_embeds: List[torch.FloatTensor]):
  class UNet2DConditionLoadersMixin (line 132) | class UNet2DConditionLoadersMixin:
    method load_attn_procs (line 141) | def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union...
    method convert_state_dict_legacy_attn_format (line 454) | def convert_state_dict_legacy_attn_format(self, state_dict, network_al...
    method save_attn_procs (line 481) | def save_attn_procs(
    method fuse_lora (line 583) | def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=N...
    method _fuse_lora_apply (line 588) | def _fuse_lora_apply(self, module, adapter_names=None):
    method unfuse_lora (line 600) | def unfuse_lora(self):
    method _unfuse_lora_apply (line 603) | def _unfuse_lora_apply(self, module):
    method set_adapters (line 608) | def set_adapters(
    method disable_lora (line 656) | def disable_lora(self):
    method enable_lora (line 679) | def enable_lora(self):
    method delete_adapters (line 702) | def delete_adapters(self, adapter_names: Union[List[str], str]):
    method _convert_ip_adapter_image_proj_to_diffusers (line 738) | def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_...
    method _convert_ip_adapter_attn_to_diffusers_VPAdapter (line 786) | def _convert_ip_adapter_attn_to_diffusers_VPAdapter(self, state_dicts,...
    method _convert_ip_adapter_attn_to_diffusers (line 873) | def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_m...
    method _load_ip_adapter_weights (line 958) | def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
    method _load_ip_adapter_weights_VPAdapter (line 975) | def _load_ip_adapter_weights_VPAdapter(self, state_dicts, low_cpu_mem_...

FILE: foleycrafter/models/auffusion/resnet.py
  class ResnetBlock2D (line 45) | class ResnetBlock2D(nn.Module):
    method __init__ (line 76) | def __init__(
    method forward (line 183) | def forward(
  function rearrange_dims (line 265) | def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
  class Conv1dBlock (line 276) | class Conv1dBlock(nn.Module):
    method __init__ (line 288) | def __init__(
    method forward (line 302) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  class ResidualTemporalBlock1D (line 312) | class ResidualTemporalBlock1D(nn.Module):
    method __init__ (line 324) | def __init__(
    method forward (line 343) | def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
  class TemporalConvLayer (line 359) | class TemporalConvLayer(nn.Module):
    method __init__ (line 370) | def __init__(
    method forward (line 411) | def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) ->...
  class TemporalResnetBlock (line 430) | class TemporalResnetBlock(nn.Module):
    method __init__ (line 442) | def __init__(
    method forward (line 496) | def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTe...
  class SpatioTemporalResBlock (line 523) | class SpatioTemporalResBlock(nn.Module):
    method __init__ (line 541) | def __init__(
    method forward (line 574) | def forward(
  class AlphaBlender (line 607) | class AlphaBlender(nn.Module):
    method __init__ (line 621) | def __init__(
    method get_alpha (line 641) | def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) ->...
    method forward (line 672) | def forward(
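The `AlphaBlender` listed above merges the spatial and temporal branches used by `SpatioTemporalResBlock`. A minimal scalar sketch of the usual blending rule, reconstructed from the method names rather than copied from the source (the real module operates on torch tensors and also supports an image-only indicator; `strategy` values here are illustrative):

```python
import math

def get_alpha(mix_factor: float, strategy: str = "learned") -> float:
    """Mixing weight for an AlphaBlender-style module.

    "fixed" uses mix_factor directly; "learned" squashes it through a
    sigmoid so the blend weight stays in (0, 1).
    """
    if strategy == "fixed":
        return mix_factor
    if strategy == "learned":
        return 1.0 / (1.0 + math.exp(-mix_factor))
    raise ValueError(f"unknown strategy: {strategy}")

def blend(spatial: float, temporal: float, mix_factor: float,
          strategy: str = "learned") -> float:
    """alpha weights the spatial branch, (1 - alpha) the temporal one."""
    alpha = get_alpha(mix_factor, strategy)
    return alpha * spatial + (1.0 - alpha) * temporal
```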

FILE: foleycrafter/models/auffusion/transformer_2d.py
  class Transformer2DModelOutput (line 31) | class Transformer2DModelOutput(BaseOutput):
  class Transformer2DModel (line 44) | class Transformer2DModel(ModelMixin, ConfigMixin):
    method __init__ (line 75) | def __init__(
    method _set_gradient_checkpointing (line 242) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 246) | def forward(

FILE: foleycrafter/models/auffusion/unet_2d_blocks.py
  function get_down_block (line 42) | def get_down_block(
  function get_up_block (line 251) | def get_up_block(
  class AutoencoderTinyBlock (line 476) | class AutoencoderTinyBlock(nn.Module):
    method __init__ (line 492) | def __init__(self, in_channels: int, out_channels: int, act_fn: str):
    method forward (line 509) | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
  class UNetMidBlock2D (line 513) | class UNetMidBlock2D(nn.Module):
    method __init__ (line 544) | def __init__(
    method forward (line 628) | def forward(self, hidden_states: torch.FloatTensor, temb: Optional[tor...
  class UNetMidBlock2DCrossAttn (line 638) | class UNetMidBlock2DCrossAttn(nn.Module):
    method __init__ (line 639) | def __init__(
    method forward (line 732) | def forward(
  class UNetMidBlock2DSimpleCrossAttn (line 784) | class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    method __init__ (line 785) | def __init__(
    method forward (line 869) | def forward(
  class AttnDownBlock2D (line 908) | class AttnDownBlock2D(nn.Module):
    method __init__ (line 909) | def __init__(
    method forward (line 1000) | def forward(
  class CrossAttnDownBlock2D (line 1031) | class CrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1032) | def __init__(
    method forward (line 1124) | def forward(
  class DownBlock2D (line 1193) | class DownBlock2D(nn.Module):
    method __init__ (line 1194) | def __init__(
    method forward (line 1245) | def forward(
  class DownEncoderBlock2D (line 1281) | class DownEncoderBlock2D(nn.Module):
    method __init__ (line 1282) | def __init__(
    method forward (line 1330) | def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0...
  class AttnDownEncoderBlock2D (line 1341) | class AttnDownEncoderBlock2D(nn.Module):
    method __init__ (line 1342) | def __init__(
    method forward (line 1413) | def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0...
  class AttnSkipDownBlock2D (line 1426) | class AttnSkipDownBlock2D(nn.Module):
    method __init__ (line 1427) | def __init__(
    method forward (line 1507) | def forward(
  class SkipDownBlock2D (line 1534) | class SkipDownBlock2D(nn.Module):
    method __init__ (line 1535) | def __init__(
    method forward (line 1594) | def forward(
  class ResnetDownsampleBlock2D (line 1619) | class ResnetDownsampleBlock2D(nn.Module):
    method __init__ (line 1620) | def __init__(
    method forward (line 1683) | def forward(
  class SimpleCrossAttnDownBlock2D (line 1719) | class SimpleCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1720) | def __init__(
    method forward (line 1814) | def forward(
  class KDownBlock2D (line 1879) | class KDownBlock2D(nn.Module):
    method __init__ (line 1880) | def __init__(
    method forward (line 1925) | def forward(
  class KCrossAttnDownBlock2D (line 1959) | class KCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1960) | def __init__(
    method forward (line 2024) | def forward(
  class AttnUpBlock2D (line 2086) | class AttnUpBlock2D(nn.Module):
    method __init__ (line 2087) | def __init__(
    method forward (line 2178) | def forward(
  class CrossAttnUpBlock2D (line 2206) | class CrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2207) | def __init__(
    method forward (line 2298) | def forward(
  class UpBlock2D (line 2380) | class UpBlock2D(nn.Module):
    method __init__ (line 2381) | def __init__(
    method forward (line 2430) | def forward(
  class UpDecoderBlock2D (line 2490) | class UpDecoderBlock2D(nn.Module):
    method __init__ (line 2491) | def __init__(
    method forward (line 2537) | def forward(
  class AttnUpDecoderBlock2D (line 2550) | class AttnUpDecoderBlock2D(nn.Module):
    method __init__ (line 2551) | def __init__(
    method forward (line 2621) | def forward(
  class AttnSkipUpBlock2D (line 2636) | class AttnSkipUpBlock2D(nn.Module):
    method __init__ (line 2637) | def __init__(
    method forward (line 2730) | def forward(
  class SkipUpBlock2D (line 2766) | class SkipUpBlock2D(nn.Module):
    method __init__ (line 2767) | def __init__(
    method forward (line 2838) | def forward(
  class ResnetUpsampleBlock2D (line 2871) | class ResnetUpsampleBlock2D(nn.Module):
    method __init__ (line 2872) | def __init__(
    method forward (line 2940) | def forward(
  class SimpleCrossAttnUpBlock2D (line 2980) | class SimpleCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2981) | def __init__(
    method forward (line 3079) | def forward(
  class KUpBlock2D (line 3146) | class KUpBlock2D(nn.Module):
    method __init__ (line 3147) | def __init__(
    method forward (line 3196) | def forward(
  class KCrossAttnUpBlock2D (line 3235) | class KCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 3236) | def __init__(
    method forward (line 3321) | def forward(
  class KAttentionBlock (line 3383) | class KAttentionBlock(nn.Module):
    method __init__ (line 3407) | def __init__(
    method _to_3d (line 3450) | def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight...
    method _to_4d (line 3453) | def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight...
    method forward (line 3456) | def forward(

FILE: foleycrafter/models/auffusion_unet.py
  class UNet2DConditionOutput (line 64) | class UNet2DConditionOutput(BaseOutput):
  class UNet2DConditionModel (line 76) | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
    method __init__ (line 173) | def __init__(
    method load_attention (line 639) | def load_attention(self):
    method get_writer_feature (line 652) | def get_writer_feature(self):
    method clear_writer_feature (line 655) | def clear_writer_feature(self):
    method disable_feature_adapters (line 658) | def disable_feature_adapters(self):
    method set_reader_feature (line 661) | def set_reader_feature(self, features: list):
    method attn_processors (line 665) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 688) | def set_attn_processor(
    method set_default_attn_processor (line 724) | def set_default_attn_processor(self):
    method set_attention_slice (line 739) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 804) | def _set_gradient_checkpointing(self, module, value=False):
    method enable_freeu (line 808) | def enable_freeu(self, s1, s2, b1, b2):
    method disable_freeu (line 832) | def disable_freeu(self):
    method fuse_qkv_projections (line 840) | def fuse_qkv_projections(self):
    method unfuse_qkv_projections (line 863) | def unfuse_qkv_projections(self):
    method forward (line 876) | def forward(

FILE: foleycrafter/models/onset/r2plus1d_18.py
  class r2plus1d18KeepTemp (line 9) | class r2plus1d18KeepTemp(nn.Module):
    method __init__ (line 10) | def __init__(self, pretrained=True):
    method forward (line 39) | def forward(self, x):

FILE: foleycrafter/models/onset/resnet.py
  class Conv3DSimple (line 15) | class Conv3DSimple(nn.Conv3d):
    method __init__ (line 16) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
    method get_downsample_stride (line 27) | def get_downsample_stride(stride):
  class Conv2Plus1D (line 31) | class Conv2Plus1D(nn.Sequential):
    method __init__ (line 32) | def __init__(self, in_planes, out_planes, midplanes, stride=1, padding...
    method get_downsample_stride (line 55) | def get_downsample_stride(stride):
  class Conv3DNoTemporal (line 59) | class Conv3DNoTemporal(nn.Conv3d):
    method __init__ (line 60) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
    method get_downsample_stride (line 71) | def get_downsample_stride(stride):
  class BasicBlock (line 75) | class BasicBlock(nn.Module):
    method __init__ (line 78) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
    method forward (line 90) | def forward(self, x):
  class Bottleneck (line 104) | class Bottleneck(nn.Module):
    method __init__ (line 107) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
    method forward (line 129) | def forward(self, x):
  class BasicStem (line 145) | class BasicStem(nn.Sequential):
    method __init__ (line 148) | def __init__(self):
  class R2Plus1dStem (line 156) | class R2Plus1dStem(nn.Sequential):
    method __init__ (line 159) | def __init__(self):
  class VideoResNet (line 170) | class VideoResNet(nn.Module):
    method __init__ (line 171) | def __init__(self, block, conv_makers, layers, stem, num_classes=400, ...
    method forward (line 202) | def forward(self, x):
    method _make_layer (line 221) | def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
    method _initialize_weights (line 239) | def _initialize_weights(self):
  function _video_resnet (line 253) | def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
  function r3d_18 (line 262) | def r3d_18(pretrained=False, progress=True, **kwargs):
  function mc3_18 (line 284) | def mc3_18(pretrained=False, progress=True, **kwargs):
  function r2plus1d_18 (line 305) | def r2plus1d_18(pretrained=False, progress=True, **kwargs):
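`Conv2Plus1D` above factorizes a 3x3x3 convolution into a 2D spatial conv followed by a 1D temporal conv, with the hidden width `midplanes` chosen so the pair has roughly the parameter count of the full 3D kernel it replaces. This is the torchvision R(2+1)D convention that this file adapts; the formula as a standalone helper:

```python
def r2plus1d_midplanes(in_planes: int, out_planes: int) -> int:
    """Hidden width for a (2+1)D factorized conv: parameter count of the
    2D-spatial + 1D-temporal pair approximately matches a full 3x3x3 conv."""
    return (in_planes * out_planes * 3 * 3 * 3) // (in_planes * 3 * 3 + 3 * out_planes)
```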

FILE: foleycrafter/models/onset/torch_utils.py
  function load_model (line 14) | def load_model(cp_path, net, device=None, strict=True):
  function binary_acc (line 42) | def binary_acc(pred, target, threshold):
  function calc_acc (line 48) | def calc_acc(prob, labels, k):
  function get_dataloader (line 57) | def get_dataloader(args, pr, split="train", shuffle=False, drop_last=Fal...
  function make_optimizer (line 81) | def make_optimizer(model, args):
  function adjust_learning_rate (line 105) | def adjust_learning_rate(optimizer, epoch, args):
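`binary_acc` in the listing above scores onset predictions against binary labels. A sketch of the usual implementation, assuming `pred` holds probabilities and `target` holds 0/1 labels (the repo version works on torch tensors; this list-based form is a reconstruction):

```python
def binary_acc(pred, target, threshold=0.5):
    """Fraction of predictions whose thresholded label matches the target."""
    labels = [1 if p > threshold else 0 for p in pred]
    correct = sum(1 for lb, t in zip(labels, target) if lb == t)
    return correct / len(labels)
```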

FILE: foleycrafter/models/onset/video_onset_net.py
  class VideoOnsetNet (line 9) | class VideoOnsetNet(nn.Module):
    method __init__ (line 11) | def __init__(self, pretrained):
    method forward (line 16) | def forward(self, inputs, loss=False, evaluate=False):

FILE: foleycrafter/models/time_detector/model.py
  class TimeDetector (line 6) | class TimeDetector(nn.Module):
    method __init__ (line 7) | def __init__(self, video_length=150, audio_length=1024):
    method forward (line 13) | def forward(self, inputs):
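The `TimeDetector` constructor pairs `video_length=150` with `audio_length=1024`: onsets detected on the video frame axis must land on the mel-frame axis of the spectrogram. The model learns this mapping with its own head; purely to illustrate the index correspondence between the two axes, a hypothetical helper (not in the repo):

```python
def video_to_audio_index(frame_idx: int, video_length: int = 150,
                         audio_length: int = 1024) -> int:
    """Map a video frame index onto the audio (mel-frame) time axis
    by proportional rescaling."""
    return round(frame_idx * audio_length / video_length)
```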

FILE: foleycrafter/models/time_detector/resnet.py
  class Conv3DSimple (line 14) | class Conv3DSimple(nn.Conv3d):
    method __init__ (line 15) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
    method get_downsample_stride (line 26) | def get_downsample_stride(stride):
  class Conv2Plus1D (line 30) | class Conv2Plus1D(nn.Sequential):
    method __init__ (line 31) | def __init__(self, in_planes, out_planes, midplanes, stride=1, padding...
    method get_downsample_stride (line 54) | def get_downsample_stride(stride):
  class Conv3DNoTemporal (line 58) | class Conv3DNoTemporal(nn.Conv3d):
    method __init__ (line 59) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
    method get_downsample_stride (line 70) | def get_downsample_stride(stride):
  class BasicBlock (line 74) | class BasicBlock(nn.Module):
    method __init__ (line 77) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
    method forward (line 89) | def forward(self, x):
  class Bottleneck (line 103) | class Bottleneck(nn.Module):
    method __init__ (line 106) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
    method forward (line 128) | def forward(self, x):
  class BasicStem (line 144) | class BasicStem(nn.Sequential):
    method __init__ (line 147) | def __init__(self):
  class R2Plus1dStem (line 155) | class R2Plus1dStem(nn.Sequential):
    method __init__ (line 158) | def __init__(self):
  class VideoResNet (line 169) | class VideoResNet(nn.Module):
    method __init__ (line 170) | def __init__(self, block, conv_makers, layers, stem, num_classes=400, ...
    method forward (line 201) | def forward(self, x):
    method _make_layer (line 220) | def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
    method _initialize_weights (line 238) | def _initialize_weights(self):
  function _video_resnet (line 252) | def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
  function r3d_18 (line 261) | def r3d_18(pretrained=False, progress=True, **kwargs):
  function mc3_18 (line 283) | def mc3_18(pretrained=False, progress=True, **kwargs):
  function r2plus1d_18 (line 304) | def r2plus1d_18(pretrained=False, progress=True, **kwargs):

FILE: foleycrafter/pipelines/auffusion_pipeline.py
  function json_dump (line 66) | def json_dump(data_json, json_save_path):
  function json_load (line 72) | def json_load(json_path):
  function import_model_class_from_model_name_or_path (line 79) | def import_model_class_from_model_name_or_path(pretrained_model_name_or_...
  class ConditionAdapter (line 99) | class ConditionAdapter(nn.Module):
    method __init__ (line 100) | def __init__(self, config):
    method forward (line 107) | def forward(self, x):
    method from_pretrained (line 113) | def from_pretrained(cls, pretrained_model_name_or_path):
    method save_pretrained (line 122) | def save_pretrained(self, pretrained_model_name_or_path):
  function rescale_noise_cfg (line 131) | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
  class AttrDict (line 149) | class AttrDict(dict):
    method __init__ (line 150) | def __init__(self, *args, **kwargs):
  function get_config (line 155) | def get_config(config_path):
  function init_weights (line 161) | def init_weights(m, mean=0.0, std=0.01):
  function apply_weight_norm (line 167) | def apply_weight_norm(m):
  function get_padding (line 173) | def get_padding(kernel_size, dilation=1):
  class ResBlock1 (line 177) | class ResBlock1(torch.nn.Module):
    method __init__ (line 178) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
    method forward (line 232) | def forward(self, x):
    method remove_weight_norm (line 241) | def remove_weight_norm(self):
  class ResBlock2 (line 248) | class ResBlock2(torch.nn.Module):
    method __init__ (line 249) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
    method forward (line 278) | def forward(self, x):
    method remove_weight_norm (line 285) | def remove_weight_norm(self):
  class Generator (line 290) | class Generator(torch.nn.Module):
    method __init__ (line 291) | def __init__(self, h):
    method device (line 347) | def device(self) -> torch.device:
    method dtype (line 351) | def dtype(self):
    method forward (line 354) | def forward(self, x):
    method remove_weight_norm (line 372) | def remove_weight_norm(self):
    method from_pretrained (line 382) | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
    method inference (line 398) | def inference(self, mels, lengths=None):
  function normalize_spectrogram (line 411) | def normalize_spectrogram(
  function denormalize_spectrogram (line 432) | def denormalize_spectrogram(
  function pt_to_numpy (line 457) | def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
  function numpy_to_pil (line 466) | def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
  function image_add_color (line 482) | def image_add_color(spec_img):
  class PipelineOutput (line 492) | class PipelineOutput(BaseOutput):
  class AuffusionPipeline (line 506) | class AuffusionPipeline(DiffusionPipeline):
    method __init__ (line 551) | def __init__(
    method from_pretrained (line 588) | def from_pretrained(
    method to (line 642) | def to(self, device, dtype=None):
    method enable_vae_slicing (line 656) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 665) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 672) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 681) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 688) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 713) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 742) | def _execution_device(self):
    method _encode_prompt (line 759) | def _encode_prompt(
    method run_safety_checker (line 862) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 876) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 889) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 906) | def check_inputs(
    method prepare_latents (line 953) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 971) | def __call__(
  function retrieve_timesteps (line 1120) | def retrieve_timesteps(
  class AuffusionNoAdapterPipeline (line 1164) | class AuffusionNoAdapterPipeline(
    method __init__ (line 1205) | def __init__(
    method enable_vae_slicing (line 1297) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 1304) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 1311) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 1319) | def disable_vae_tiling(self):
    method _encode_prompt (line 1326) | def _encode_prompt(
    method encode_prompt (line 1358) | def encode_prompt(
    method prepare_ip_adapter_image_embeds (line 1529) | def prepare_ip_adapter_image_embeds(
    method encode_image (line 1580) | def encode_image(self, image, device, num_images_per_prompt, output_hi...
    method run_safety_checker (line 1604) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 1618) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 1629) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 1646) | def check_inputs(
    method prepare_latents (line 1698) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method enable_freeu (line 1715) | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    method disable_freeu (line 1737) | def disable_freeu(self):
    method fuse_qkv_projections (line 1742) | def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
    method unfuse_qkv_projections (line 1774) | def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
    method get_guidance_scale_embedding (line 1803) | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=tor...
    method guidance_scale (line 1832) | def guidance_scale(self):
    method guidance_rescale (line 1836) | def guidance_rescale(self):
    method clip_skip (line 1840) | def clip_skip(self):
    method do_classifier_free_guidance (line 1847) | def do_classifier_free_guidance(self):
    method cross_attention_kwargs (line 1851) | def cross_attention_kwargs(self):
    method num_timesteps (line 1855) | def num_timesteps(self):
    method interrupt (line 1859) | def interrupt(self):
    method __call__ (line 1863) | def __call__(
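`rescale_noise_cfg` above implements the guidance-rescale trick from "Common Diffusion Noise Schedules and Sample Steps are Flawed": the classifier-free-guidance output is rescaled so its standard deviation matches the text-conditional prediction, then interpolated back by `guidance_rescale`. A list-based sketch of that formula (the pipeline version is tensor-based and reduces std per sample):

```python
from statistics import pstdev

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Match the CFG output's std to the text-conditional prediction,
    then blend back toward the raw CFG output by guidance_rescale."""
    std_text = pstdev(noise_pred_text)
    std_cfg = pstdev(noise_cfg)
    rescaled = [x * (std_text / std_cfg) for x in noise_cfg]
    return [guidance_rescale * r + (1.0 - guidance_rescale) * x
            for r, x in zip(rescaled, noise_cfg)]
```

With `guidance_rescale=0.0` the function is a no-op; with `1.0` the output's std equals that of the text-conditional branch.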

FILE: foleycrafter/pipelines/pipeline_controlnet.py
  function retrieve_timesteps (line 97) | def retrieve_timesteps(
  class StableDiffusionControlNetPipeline (line 141) | class StableDiffusionControlNetPipeline(
    method __init__ (line 186) | def __init__(
    method enable_vae_slicing (line 239) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 247) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 255) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 264) | def disable_vae_tiling(self):
    method _encode_prompt (line 272) | def _encode_prompt(
    method encode_prompt (line 305) | def encode_prompt(
    method prepare_ip_adapter_image_embeds (line 486) | def prepare_ip_adapter_image_embeds(
    method encode_image (line 538) | def encode_image(self, image, device, num_images_per_prompt, output_hi...
    method run_safety_checker (line 563) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 578) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 590) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 607) | def check_inputs(
    method check_image (line 753) | def check_image(self, image, prompt, prompt_embeds):
    method prepare_image (line 790) | def prepare_image(
    method prepare_latents (line 821) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method enable_freeu (line 839) | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    method disable_freeu (line 862) | def disable_freeu(self):
    method get_guidance_scale_embedding (line 867) | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=tor...
    method guidance_scale (line 896) | def guidance_scale(self):
    method clip_skip (line 900) | def clip_skip(self):
    method do_classifier_free_guidance (line 907) | def do_classifier_free_guidance(self):
    method cross_attention_kwargs (line 911) | def cross_attention_kwargs(self):
    method num_timesteps (line 915) | def num_timesteps(self):
    method __call__ (line 920) | def __call__(
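`get_guidance_scale_embedding` above builds a sinusoidal embedding of the guidance weight, in the style diffusers uses for guidance-embedded (LCM-like) UNets. A scalar-list sketch under that assumption; the exact pre-scaling in the repo may differ:

```python
import math

def guidance_scale_embedding(w: float, embedding_dim: int = 512):
    """Sinusoidal embedding of the guidance scale: (w - 1) is scaled by
    1000, multiplied by log-spaced frequencies, then sin/cos concatenated."""
    w = (w - 1.0) * 1000.0
    half_dim = embedding_dim // 2
    emb_scale = math.log(10000.0) / (half_dim - 1)
    freqs = [math.exp(-emb_scale * i) for i in range(half_dim)]
    args = [w * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]
```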

FILE: foleycrafter/utils/converter.py
  function load_wav (line 23) | def load_wav(full_path):
  function dynamic_range_compression (line 28) | def dynamic_range_compression(x, C=1, clip_val=1e-5):
  function dynamic_range_decompression (line 32) | def dynamic_range_decompression(x, C=1):
  function dynamic_range_compression_torch (line 36) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
  function dynamic_range_decompression_torch (line 40) | def dynamic_range_decompression_torch(x, C=1):
  function spectral_normalize_torch (line 44) | def spectral_normalize_torch(magnitudes):
  function spectral_de_normalize_torch (line 49) | def spectral_de_normalize_torch(magnitudes):
  function mel_spectrogram (line 58) | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_siz...
  function spectrogram (line 97) | def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, f...
  function normalize_spectrogram (line 130) | def normalize_spectrogram(
  function denormalize_spectrogram (line 161) | def denormalize_spectrogram(
  function get_mel_spectrogram_from_audio (line 195) | def get_mel_spectrogram_from_audio(audio, device="cpu"):
  class AttrDict (line 222) | class AttrDict(dict):
    method __init__ (line 223) | def __init__(self, *args, **kwargs):
  function get_config (line 228) | def get_config(config_path):
  function init_weights (line 234) | def init_weights(m, mean=0.0, std=0.01):
  function apply_weight_norm (line 240) | def apply_weight_norm(m):
  function get_padding (line 246) | def get_padding(kernel_size, dilation=1):
  class ResBlock1 (line 250) | class ResBlock1(torch.nn.Module):
    method __init__ (line 251) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
    method forward (line 305) | def forward(self, x):
    method remove_weight_norm (line 314) | def remove_weight_norm(self):
  class ResBlock2 (line 321) | class ResBlock2(torch.nn.Module):
    method __init__ (line 322) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
    method forward (line 351) | def forward(self, x):
    method remove_weight_norm (line 358) | def remove_weight_norm(self):
  class Generator (line 363) | class Generator(torch.nn.Module):
    method __init__ (line 364) | def __init__(self, h):
    method forward (line 416) | def forward(self, x):
    method remove_weight_norm (line 434) | def remove_weight_norm(self):
    method from_pretrained (line 443) | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
    method inference (line 459) | def inference(self, mels, lengths=None):
  function normalize (line 472) | def normalize(images):
  function pad_spec (line 482) | def pad_spec(spec, spec_length, pad_value=0, random_crop=True):  # spec:...
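`get_padding` and the `dynamic_range_compression`/`dynamic_range_decompression` pair listed above follow the standard HiFi-GAN vocoder conventions. A self-contained scalar sketch (the repo versions operate on numpy arrays and torch tensors; the signatures mirror the listing but are reconstructed, not copied):

```python
import math

def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """'Same' padding for a stride-1 dilated 1D convolution."""
    return (kernel_size * dilation - dilation) // 2

def dynamic_range_compression(x: float, C: float = 1.0,
                              clip_val: float = 1e-5) -> float:
    """Log-compress a magnitude, clipping near-zero values for stability."""
    return math.log(max(x, clip_val) * C)

def dynamic_range_decompression(x: float, C: float = 1.0) -> float:
    """Inverse of the log compression."""
    return math.exp(x) / C
```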

FILE: foleycrafter/utils/spec_to_mel.py
  class STFT (line 14) | class STFT(torch.nn.Module):
    method __init__ (line 17) | def __init__(self, filter_length, hop_length, win_length, window="hann"):
    method transform (line 47) | def transform(self, input_data):
    method inverse (line 78) | def inverse(self, magnitude, phase):
    method forward (line 111) | def forward(self, input_data):
  function window_sumsquare (line 117) | def window_sumsquare(
  function griffin_lim (line 176) | def griffin_lim(magnitudes, stft_fn, n_iters=30):
  function dynamic_range_compression (line 195) | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=...
  function dynamic_range_decompression (line 204) | def dynamic_range_decompression(x, C=1):
  class TacotronSTFT (line 213) | class TacotronSTFT(torch.nn.Module):
    method __init__ (line 214) | def __init__(
    method spectral_normalize (line 232) | def spectral_normalize(self, magnitudes, normalize_fun):
    method spectral_de_normalize (line 236) | def spectral_de_normalize(self, magnitudes):
    method mel_spectrogram (line 240) | def mel_spectrogram(self, y, normalize_fun=torch.log):
  function pad_wav (line 264) | def pad_wav(waveform, segment_length):
  function normalize_wav (line 277) | def normalize_wav(waveform):
  function _pad_spec (line 283) | def _pad_spec(fbank, target_length=1024):
  function get_mel_from_wav (line 299) | def get_mel_from_wav(audio, _stft):
  function read_wav_file_io (line 309) | def read_wav_file_io(bytes):
  function load_audio (line 323) | def load_audio(bytes, sample_rate=16000):
  function read_wav_file (line 329) | def read_wav_file(filename):
  function norm_wav_tensor (line 343) | def norm_wav_tensor(waveform: torch.FloatTensor):
  function wav_to_fbank (line 352) | def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
  function wav_tensor_to_fbank (line 380) | def wav_tensor_to_fbank(waveform, target_length=512, fn_STFT=None):
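`pad_wav` above brings a waveform up to a fixed segment length before spectrogram extraction. A list-based sketch assuming zero-padding on the right (the repo version handles numpy arrays and may treat over-long inputs differently):

```python
def pad_wav(waveform, segment_length):
    """Zero-pad a waveform on the right up to segment_length; return it
    unchanged when it is already long enough."""
    if segment_length is None or len(waveform) >= segment_length:
        return waveform
    return list(waveform) + [0.0] * (segment_length - len(waveform))
```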

FILE: foleycrafter/utils/util.py
  function zero_rank_print (line 35) | def zero_rank_print(s):
  function build_foleycrafter (line 40) | def build_foleycrafter(
  function save_videos_grid (line 66) | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_r...
  function save_videos_from_pil_list (line 82) | def save_videos_from_pil_list(videos: list, path: str, fps=7):
  function seed_everything (line 89) | def seed_everything(seed: int) -> None:
  function get_video_frames (line 102) | def get_video_frames(video: np.ndarray, num_frames: int = 200):
  function random_audio_video_clip (line 109) | def random_audio_video_clip(
  function get_full_indices (line 142) | def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReade...
  function get_frames (line 149) | def get_frames(video_path: str, onset_list, frame_nums=1024):
  function get_frames_in_video (line 163) | def get_frames_in_video(video_path: str, onset_list, frame_nums=1024, au...
  function save_multimodal (line 183) | def save_multimodal(video, audio, output_path, audio_fps: int = 16000, v...
  function save_multimodal_by_frame (line 203) | def save_multimodal_by_frame(video, audio, output_path, audio_fps: int =...
  function sanity_check (line 221) | def sanity_check(data: dict, save_path: str = "sanity_check", batch_size...
  function video_tensor_to_np (line 242) | def video_tensor_to_np(video: torch.Tensor, rescale: bool = True, scale:...
  function composite_audio_video (line 255) | def composite_audio_video(video: str, audio: str, path: str, video_fps: ...
  function append_dims (line 265) | def append_dims(x, target_dims):
  function resize_with_antialiasing (line 273) | def resize_with_antialiasing(input, size, interpolation="bicubic", align...
  function _gaussian_blur2d (line 302) | def _gaussian_blur2d(input, kernel_size, sigma):
  function _filter2d (line 318) | def _filter2d(input, kernel):
  function _gaussian (line 341) | def _gaussian(window_size: int, sigma):
  function _compute_padding (line 357) | def _compute_padding(kernel_size):
  function print_gpu_memory_usage (line 380) | def print_gpu_memory_usage(info: str, cuda_id: int = 0):
  class SpectrogramParams (line 392) | class SpectrogramParams:
    class ExifTags (line 427) | class ExifTags(Enum):
    method n_fft (line 446) | def n_fft(self) -> int:
    method win_length (line 453) | def win_length(self) -> int:
    method hop_length (line 460) | def hop_length(self) -> int:
    method to_exif (line 466) | def to_exif(self) -> T.Dict[int, T.Any]:
  class SpectrogramImageConverter (line 483) | class SpectrogramImageConverter:
    method __init__ (line 491) | def __init__(self, params: SpectrogramParams, device: str = "cuda"):
    method spectrogram_image_from_audio (line 496) | def spectrogram_image_from_audio(
    method audio_from_spectrogram_image (line 538) | def audio_from_spectrogram_image(
  function image_from_spectrogram (line 567) | def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25)...
  function spectrogram_from_image (line 613) | def spectrogram_from_image(
  class SpectrogramConverter (line 667) | class SpectrogramConverter:
    method __init__ (line 689) | def __init__(self, params: SpectrogramParams, device: str = "cuda"):
    method spectrogram_from_audio (line 756) | def spectrogram_from_audio(
    method audio_from_spectrogram (line 782) | def audio_from_spectrogram(
    method mel_amplitudes_from_waveform (line 820) | def mel_amplitudes_from_waveform(
    method waveform_from_mel_amplitudes (line 842) | def waveform_from_mel_amplitudes(
  function check_device (line 862) | def check_device(device: str, backup: str = "cpu") -> str:
  function audio_from_waveform (line 876) | def audio_from_waveform(samples: np.ndarray, sample_rate: int, normalize...
  function apply_filters_func (line 900) | def apply_filters_func(segment: pydub.AudioSegment, compression: bool = ...
  function shave_segments (line 936) | def shave_segments(path, n_shave_prefix_segments=1):
  function renew_resnet_paths (line 946) | def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
  function renew_vae_resnet_paths (line 968) | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
  function renew_attention_paths (line 984) | def renew_attention_paths(old_list, n_shave_prefix_segments=0):
  function renew_vae_attention_paths (line 1005) | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
  function assign_to_checkpoint (line 1034) | def assign_to_checkpoint(
  function conv_attn_to_linear (line 1089) | def conv_attn_to_linear(checkpoint):
  function create_unet_diffusers_config (line 1101) | def create_unet_diffusers_config(original_config, image_size: int, contr...
  function create_vae_diffusers_config (line 1170) | def create_vae_diffusers_config(original_config, image_size: int):
  function create_diffusers_schedular (line 1194) | def create_diffusers_schedular(original_config):
  function convert_ldm_unet_checkpoint (line 1204) | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_e...
  function convert_ldm_vae_checkpoint (line 1435) | def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, o...
  function convert_ldm_clip_checkpoint (line 1550) | def convert_ldm_clip_checkpoint(checkpoint):
  function convert_lora_model_level (line 1561) | def convert_lora_model_level(
  function denormalize_spectrogram (line 1643) | def denormalize_spectrogram(
  class ToTensor1D (line 1677) | class ToTensor1D(torchvision.transforms.ToTensor):
    method __call__ (line 1678) | def __call__(self, tensor: np.ndarray):
  function scale (line 1684) | def scale(old_value, old_min, old_max, new_min, new_max):
  function read_frames_with_moviepy (line 1692) | def read_frames_with_moviepy(video_path, max_frame_nums=None):
  function read_frames_with_moviepy_resample (line 1703) | def read_frames_with_moviepy_resample(video_path, save_path):
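The index above lists a `scale(old_value, old_min, old_max, new_min, new_max)` helper in `foleycrafter/utils/util.py` without its body. A plausible implementation, assuming the standard min-max linear rescale (the signature comes from the index; the body is a guess, not the repository's actual code):

```python
def scale(old_value, old_min, old_max, new_min, new_max):
    """Linearly map old_value from [old_min, old_max] into [new_min, new_max]."""
    ratio = (old_value - old_min) / (old_max - old_min)
    return new_min + ratio * (new_max - new_min)

# Midpoint of [0, 10] maps to midpoint of [-1, 1].
print(scale(5, 0, 10, -1.0, 1.0))  # -> 0.0
```

This kind of helper typically appears next to spectrogram code to remap amplitude or pixel ranges before normalization.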

FILE: inference.py
  function args_parse (line 28) | def args_parse():
  function build_models (line 49) | def build_models(config):
  function run_inference (line 97) | def run_inference(config, pipe, vocoder, time_detector):
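The three `inference.py` entry points suggest a parse-build-run flow. A minimal sketch of how they plausibly chain together; the stub bodies and the fixed defaults below are hypothetical stand-ins for the real CLI parsing and model loading (only the function names and parameter lists come from the index):

```python
from types import SimpleNamespace


def args_parse():
    # Real version parses CLI flags; here we return fixed, hypothetical defaults.
    return SimpleNamespace(input="input/", output="output/", seed=42)


def build_models(config):
    # Real version loads the diffusion pipeline, vocoder, and time detector;
    # plain objects stand in for them in this sketch.
    pipe, vocoder, time_detector = object(), object(), object()
    return pipe, vocoder, time_detector


def run_inference(config, pipe, vocoder, time_detector):
    # Real version renders audio for each input video and writes .wav files.
    return f"wrote results to {config.output}"


config = args_parse()
pipe, vocoder, time_detector = build_models(config)
print(run_inference(config, pipe, vocoder, time_detector))
```

That `build_models` returns a `(pipe, vocoder, time_detector)` triple is an inference from `run_inference`'s parameter list, not something the index states directly.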
Condensed preview — 37 files, each entry showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 153,
    "preview": "*.ckpt\n*.pt\n*.pyc\n*.safetensors\n\n__pycache__/\noutput/\ncheckpoints/\ntrain/\nconfigs/\n\n*.wav\n*.mp3\n*.gif\n*.jpg\n*.png\n*.log\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 651,
    "preview": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    # Ruff version.\n    rev: v0.3.5\n    hooks:\n      # Run"
  },
  {
    "path": "LICENSE",
    "chars": 11355,
    "preview": "                               Apache License\n                           Version 2.0, January 2004\n                     "
  },
  {
    "path": "README.md",
    "chars": 18131,
    "preview": "<p align=\"center\">\n<img src='assets/foleycrafter.png' style=\"text-align: center; width: 134px\" >\n</p>\n\n<div align=\"cente"
  },
  {
    "path": "app.py",
    "chars": 13846,
    "preview": "import os\nimport os.path as osp\nimport random\nfrom argparse import ArgumentParser\nfrom datetime import datetime\n\nimport "
  },
  {
    "path": "foleycrafter/data/__init__.py",
    "chars": 656,
    "preview": "from .dataset import AudioSetStrong, CPU_Unpickler, VGGSound, dynamic_range_compression, get_mel, zero_rank_print\nfrom ."
  },
  {
    "path": "foleycrafter/data/dataset.py",
    "chars": 4801,
    "preview": "import glob\nimport io\nimport pickle\nimport random\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nimpo"
  },
  {
    "path": "foleycrafter/data/video_transforms.py",
    "chars": 12623,
    "preview": "import numbers\nimport random\n\nimport torch\n\n\ndef _is_tensor_video_clip(clip):\n    if not torch.is_tensor(clip):\n        "
  },
  {
    "path": "foleycrafter/models/adapters/attention_processor.py",
    "chars": 24309,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diffusers.utils import logging\n\n\nlogger = loggi"
  },
  {
    "path": "foleycrafter/models/adapters/ip_adapter.py",
    "chars": 7195,
    "preview": "import torch\nimport torch.nn as nn\n\n\nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n\n    def __init__(self, unet"
  },
  {
    "path": "foleycrafter/models/adapters/resampler.py",
    "chars": 5059,
    "preview": "# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py\n# and https://gith"
  },
  {
    "path": "foleycrafter/models/adapters/transformer.py",
    "chars": 13691,
    "preview": "from typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\n\nclass Attention(n"
  },
  {
    "path": "foleycrafter/models/adapters/utils.py",
    "chars": 2501,
    "preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\n\n\nattn_maps = {}\n\n\ndef hook_fn(nam"
  },
  {
    "path": "foleycrafter/models/auffusion/attention.py",
    "chars": 27845,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion/attention_processor.py",
    "chars": 116323,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion/dual_transformer_2d.py",
    "chars": 7688,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion/loaders/ip_adapter.py",
    "chars": 27447,
    "preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion/loaders/unet.py",
    "chars": 46482,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion/resnet.py",
    "chars": 26825,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The"
  },
  {
    "path": "foleycrafter/models/auffusion/transformer_2d.py",
    "chars": 24001,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion/unet_2d_blocks.py",
    "chars": 135768,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/models/auffusion_unet.py",
    "chars": 65826,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "foleycrafter/models/onset/__init__.py",
    "chars": 139,
    "preview": "from .r2plus1d_18 import r2plus1d18KeepTemp\nfrom .video_onset_net import VideoOnsetNet\n\n\n__all__ = [\"r2plus1d18KeepTemp\""
  },
  {
    "path": "foleycrafter/models/onset/r2plus1d_18.py",
    "chars": 1885,
    "preview": "# Copied from specvqgan/onset_baseline/models/r2plus1d_18.py\n\nimport torch\nimport torch.nn as nn\n\nfrom .resnet import r2"
  },
  {
    "path": "foleycrafter/models/onset/resnet.py",
    "chars": 10499,
    "preview": "# Copied from specvqgan/onset_baseline/models/resnet.py\nimport torch.nn as nn\nfrom torch.hub import load_state_dict_from"
  },
  {
    "path": "foleycrafter/models/onset/torch_utils.py",
    "chars": 3529,
    "preview": "# Copied from https://github.com/XYPB/CondFoleyGen/blob/main/specvqgan/onset_baseline/utils/torch_utils.py\nimport os\nimp"
  },
  {
    "path": "foleycrafter/models/onset/video_onset_net.py",
    "chars": 883,
    "preview": "# Copied from specvqgan/onset_baseline/models/video_onset_net.py\n\nimport torch\nimport torch.nn as nn\n\nfrom .r2plus1d_18 "
  },
  {
    "path": "foleycrafter/models/time_detector/model.py",
    "chars": 491,
    "preview": "import torch.nn as nn\n\nfrom ..onset import VideoOnsetNet\n\n\nclass TimeDetector(nn.Module):\n    def __init__(self, video_l"
  },
  {
    "path": "foleycrafter/models/time_detector/resnet.py",
    "chars": 10443,
    "preview": "import torch.nn as nn\nfrom torch.hub import load_state_dict_from_url\n\n\n__all__ = [\"r3d_18\", \"mc3_18\", \"r2plus1d_18\"]\n\nmo"
  },
  {
    "path": "foleycrafter/pipelines/auffusion_pipeline.py",
    "chars": 97874,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/pipelines/pipeline_controlnet.py",
    "chars": 68360,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "foleycrafter/utils/converter.py",
    "chars": 14768,
    "preview": "# Copy from https://github.com/happylittlecat2333/Auffusion/blob/main/converter.py\nimport json\nimport os\nimport random\n\n"
  },
  {
    "path": "foleycrafter/utils/spec_to_mel.py",
    "chars": 13197,
    "preview": "import librosa.util as librosa_util\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchaudio\nfr"
  },
  {
    "path": "foleycrafter/utils/util.py",
    "chars": 65887,
    "preview": "import glob\nimport io\nimport os\nimport os.path as osp\nimport random\nimport typing as T\nimport warnings\nfrom dataclasses "
  },
  {
    "path": "inference.py",
    "chars": 7717,
    "preview": "import argparse\nimport glob\nimport os\nimport os.path as osp\nfrom pathlib import Path\n\nimport soundfile as sf\nimport torc"
  },
  {
    "path": "pyproject.toml",
    "chars": 749,
    "preview": "[tool.ruff]\n# Never enforce `E501` (line length violations).\nignore = [\"C901\", \"E501\", \"E741\", \"F402\", \"F823\"]\nselect = "
  },
  {
    "path": "requirements/environment.yaml",
    "chars": 444,
    "preview": "name: foleycrafter\nchannels:\n  - pytorch\n  - nvidia\ndependencies:\n  - python=3.10\n  - pytorch=2.2.0\n  - torchvision=0.17"
  }
]

About this extraction

This page contains the full source code of the open-mmlab/FoleyCrafter GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 37 files (869.2 KB, approximately 201.2k tokens) and includes a symbol index of 700 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
