Repository: open-mmlab/FoleyCrafter
Branch: main
Commit: c2399fa3cb7e
Files: 37
Total size: 869.2 KB
Directory structure:
gitextract_36et0ept/
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── app.py
├── foleycrafter/
│ ├── data/
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── video_transforms.py
│ ├── models/
│ │ ├── adapters/
│ │ │ ├── attention_processor.py
│ │ │ ├── ip_adapter.py
│ │ │ ├── resampler.py
│ │ │ ├── transformer.py
│ │ │ └── utils.py
│ │ ├── auffusion/
│ │ │ ├── attention.py
│ │ │ ├── attention_processor.py
│ │ │ ├── dual_transformer_2d.py
│ │ │ ├── loaders/
│ │ │ │ ├── ip_adapter.py
│ │ │ │ └── unet.py
│ │ │ ├── resnet.py
│ │ │ ├── transformer_2d.py
│ │ │ └── unet_2d_blocks.py
│ │ ├── auffusion_unet.py
│ │ ├── onset/
│ │ │ ├── __init__.py
│ │ │ ├── r2plus1d_18.py
│ │ │ ├── resnet.py
│ │ │ ├── torch_utils.py
│ │ │ └── video_onset_net.py
│ │ └── time_detector/
│ │ ├── model.py
│ │ └── resnet.py
│ ├── pipelines/
│ │ ├── auffusion_pipeline.py
│ │ └── pipeline_controlnet.py
│ └── utils/
│ ├── converter.py
│ ├── spec_to_mel.py
│ └── util.py
├── inference.py
├── pyproject.toml
└── requirements/
└── environment.yaml
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.ckpt
*.pt
*.pyc
*.safetensors
__pycache__/
output/
checkpoints/
train/
configs/
*.wav
*.mp3
*.gif
*.jpg
*.png
*.log
*.ckpt
*.json
*.csv
*.txt
*.bin
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.5
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: check-yaml
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: fix-encoding-pragma
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<p align="center">
<img src='assets/foleycrafter.png' style="text-align: center; width: 134px" >
</p>
<div align="center">
[](https://arxiv.org/abs/2407.01494)
[](https://foleycrafter.github.io)
<a target="_blank" href="https://huggingface.co/spaces/ymzhang319/FoleyCrafter">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HuggingFace"/>
</a>
[](https://huggingface.co/ymzhang319/FoleyCrafter)
[](https://openxlab.org.cn/apps/detail/ymzhang319/FoleyCrafter)
</div>
# FoleyCrafter
Sound effects are the unsung heroes of cinema and gaming, enhancing realism, impact, and emotional depth for an immersive audiovisual experience. **FoleyCrafter** is a video-to-audio generation framework that produces realistic sound effects that are semantically relevant to, and temporally synchronized with, the input video.
**Your star is our fuel! <img alt="" width="30" src="https://camo.githubusercontent.com/2f4f0d02cdf79dc1ff8d2b053b4410b13bc2e39cbc8a96fcdc6f06538a3d6d2b/68747470733a2f2f656d2d636f6e74656e742e7a6f626a2e6e65742f736f757263652f616e696d617465642d6e6f746f2d636f6c6f722d656d6f6a692f3335362f736d696c696e672d666163652d776974682d6865617274735f31663937302e676966"> We're revving up the engines with it! <img alt="" width="30" src="https://camo.githubusercontent.com/028a75f875b8c3aa1b3c80bbf7dd27973c4bb654fffcf0bdc0b6f1b0674ce481/68747470733a2f2f656d2d636f6e74656e742e7a6f626a2e6e65742f736f757263652f74656c656772616d2f3338362f737061726b6c65735f323732382e77656270">**
[FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds](https://arxiv.org/abs/2407.01494)
[Yiming Zhang](https://github.com/ymzhang0319),
[Yicheng Gu](https://github.com/VocodexElysium),
[Yanhong Zeng†](https://zengyh1900.github.io/),
[Zhening Xing](https://github.com/LeoXing1996/),
[Yuancheng Wang](https://github.com/HeCheng0625),
[Zhizheng Wu](https://drwuz.com/),
[Kai Chen†](https://chenkai.site/)
(†Corresponding Author)
## What's New
- [ ] A more powerful one :stuck_out_tongue_closed_eyes:.
- [ ] Release training code.
- [x] `2024/07/01` Release the model and code of FoleyCrafter.
## Setup
### Prepare Environment
Use the following command to install dependencies:
```bash
# install conda environment
conda env create -f requirements/environment.yaml
conda activate foleycrafter
# install GIT LFS for checkpoints download
conda install git-lfs
git lfs install
```
### Download Checkpoints
The checkpoints will be downloaded automatically by running `inference.py`.
You can also download them manually using the following commands.
- Download the text-to-audio base model (we use Auffusion):
```bash
git clone https://huggingface.co/auffusion/auffusion-full-no-adapter checkpoints/auffusion
```
- Download FoleyCrafter:
```bash
git clone https://huggingface.co/ymzhang319/FoleyCrafter checkpoints/
```
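The two `git clone` steps above can also be scripted with `huggingface_hub` (the same helper the Gradio app uses for its automatic download). A minimal sketch; the repo IDs and target directories mirror the clone commands:

```python
# Sketch: fetch both checkpoint repos with huggingface_hub instead of git-lfs.
# Repo IDs and local directories mirror the git clone commands above.
CHECKPOINT_REPOS = {
    "auffusion/auffusion-full-no-adapter": "checkpoints/auffusion",
    "ymzhang319/FoleyCrafter": "checkpoints",
}


def download_checkpoints(repos=CHECKPOINT_REPOS):
    # Imported lazily so the mapping above can be inspected without the package.
    from huggingface_hub import snapshot_download

    for repo_id, local_dir in repos.items():
        # Downloads (or reuses a cached copy of) the latest snapshot of the repo.
        snapshot_download(repo_id, local_dir=local_dir)


if __name__ == "__main__":
    download_checkpoints()
```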
Put the checkpoints as follows:
```
└── checkpoints
    ├── semantic
    │   └── semantic_adapter.bin
    ├── vocoder
    │   ├── vocoder.pt
    │   └── config.json
    ├── temporal_adapter.ckpt
    └── timestamp_detector.pth.tar
```
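To confirm the layout before running inference, here is a small sanity-check sketch; the file list comes from the tree above, and the `checkpoints` default root is an assumption:

```python
import os

# Expected files, relative to the checkpoints root (taken from the tree above).
EXPECTED_FILES = [
    "semantic/semantic_adapter.bin",
    "vocoder/vocoder.pt",
    "vocoder/config.json",
    "temporal_adapter.ckpt",
    "timestamp_detector.pth.tar",
]


def missing_checkpoints(root="checkpoints", expected=EXPECTED_FILES):
    """Return the expected checkpoint files that are absent under `root`."""
    return [rel for rel in expected if not os.path.isfile(os.path.join(root, rel))]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing checkpoint files:", ", ".join(missing))
    else:
        print("All checkpoint files found.")
```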
## Gradio demo
You can launch the Gradio interface for FoleyCrafter by running the following command:
```bash
python app.py --share
```
## Inference
### Video To Audio Generation
```bash
python inference.py --save_dir=output/sora/
```
Results:
<table class='center'>
<tr>
<td><p style="text-align: center">Input Video</p></td>
<td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309262-d7c89984-c567-4ca7-8e2d-8f49d84bda4a.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122032Z&X-Amz-Expires=300&X-Amz-Signature=5b13f216056dedca2705233038dbb22f73023d2c1deaf3b03972d7b91c1bbab5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309725-0dfa72a2-1466-46e6-9611-3e1cbff707fe.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122123Z&X-Amz-Expires=300&X-Amz-Signature=314648ed216620b2d926395d34602c70da500eb9e865e839de6907ed1b0d0bd1&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309166-16206bb8-9c5e-4e9d-9d73-bc251e5658fd.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122143Z&X-Amz-Expires=300&X-Amz-Signature=43c3e2c687846eb3ba118237628b747a78c403ea4f21739fe2d423724f7b426c&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309768-90c42af6-0d24-4a05-98d4-64e23467c4bb.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122213Z&X-Amz-Expires=300&X-Amz-Signature=cfed43cd2710bf73b84b6c3ebe8debd1e0b098bdc24a1a14f6531499d01c278e&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309601-e711b7c5-1614-4d39-8b1e-c54e28eec809.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122221Z&X-Amz-Expires=300&X-Amz-Signature=4c4680eb6c541433e4505fb2b5f5a5cc8d3e5708d9f2675a98cbb556cd5d59f5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309802-2db7f130-0c25-45c2-ad4d-bf86c5468b1f.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122243Z&X-Amz-Expires=300&X-Amz-Signature=2c32069318f60f03ee9a3185a7e2833c534ea601a02a5ff025b46c2abbc5b120&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309637-6c2f106d-6b98-41ac-80ba-734636321f8c.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122305Z&X-Amz-Expires=300&X-Amz-Signature=04adb2cd80785a245ce704837ba9932d3646d375c51c53e7b70e6861ec7f6b4a&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342309836-77391524-9b31-4602-ad42-0876e0c16794.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122317Z&X-Amz-Expires=300&X-Amz-Signature=3020c12591592a106efcb1aaa22237093737afaa55ede3880d3cbf9cd80b7482&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
</table>
- Temporal Alignment with Visual Cues
```bash
python inference.py \
--temporal_align \
--input=input/avsync \
--save_dir=output/avsync/
```
Results:
<table class='center'>
<tr>
<td><p style="text-align: center">Ground Truth</p></td>
<td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310778-bcc0f16d-6d1b-468d-a775-81b8f2d98ea6.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122327Z&X-Amz-Expires=300&X-Amz-Signature=205b8e190a428b3ddee41fe2549080b4f50fd8bb10ef78d650fc05add85ccbab&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310418-8433e05c-8600-4cd6-8a68-ead536159204.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122337Z&X-Amz-Expires=300&X-Amz-Signature=3fc37a305511c1c8b7bdfc9b9b5bd0485fd584400af087939a4c08218ab33538&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310801-3d6fd80d-de6b-4815-ac6a-f81772709e4c.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122349Z&X-Amz-Expires=300&X-Amz-Signature=fc31f139a1f9c7606657fa457f1271a8a44cb39a13454939c030ccdafe2d3068&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310491-dfaf41e7-487e-47ff-8e8a-fe7cb4fb1942.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122356Z&X-Amz-Expires=300&X-Amz-Signature=6353a935194a08bc081fa873e3c6582fb175874d3112f8f8f96614a5e542ef03&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310825-6834f00f-95e8-4a2c-b864-b4fe57801836.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122406Z&X-Amz-Expires=300&X-Amz-Signature=c67b2f113b0db790a8495d1a4ab4c0d230db5f53a9062a497bcca3e57f9600aa&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342310543-5a2c363b-623c-4329-be0e-a151e5bb56a6.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122415Z&X-Amz-Expires=300&X-Amz-Signature=d8b0bfc28716e0e03694b3e590aca29450fa7788aa39e748130ee12a15d614e9&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
</table>
### Text-based Video to Audio Generation
- Using Prompt
```bash
# case1
python inference.py \
--input=input/PromptControl/case1/ \
--seed=10201304011203481429 \
--save_dir=output/PromptControl/case1/
python inference.py \
--input=input/PromptControl/case1/ \
--seed=10201304011203481429 \
--prompt='noisy, people talking' \
--save_dir=output/PromptControl/case1_prompt/
# case2
python inference.py \
--input=input/PromptControl/case2/ \
--seed=10021049243103289113 \
--save_dir=output/PromptControl/case2/
python inference.py \
--input=input/PromptControl/case2/ \
--seed=10021049243103289113 \
--prompt='seagulls' \
--save_dir=output/PromptControl/case2_prompt/
```
Results:
<table class='center'>
<tr>
<td><p style="text-align: center">Generated Audio</p></td>
<td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
<td><p style="text-align: center">Without Prompt</p></td>
<td><p style="text-align: center">Prompt: <b>noisy, people talking</b></p></td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311425-8dd543cb-0df2-441e-b6d0-86048dbeb73d.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122426Z&X-Amz-Expires=300&X-Amz-Signature=b872b8eaf51a5022aee1daf0283d92e53a70e109c6b9f1e6a4da238a3708ea45&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311493-62a08024-581c-4716-a030-aef194beddc5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122439Z&X-Amz-Expires=300&X-Amz-Signature=647eb0dc32bf7c0d739ccbe875826b1a67f54e0dc84e0be70e0e128ae2fdb73d&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td><p style="text-align: center">Without Prompt</p></td>
<td><p style="text-align: center">Prompt: <b>seagulls</b></p></td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311538-1f81f91e-efc0-41ed-bdcb-c5c6ff976c5b.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122447Z&X-Amz-Expires=300&X-Amz-Signature=d19761788775893e77e42b9312b4c26cb85aedd7c6dc249eaf68ff1f650e1942&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311595-695668ed-46a1-47b2-b5fd-3aa4286d695e.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122500Z&X-Amz-Expires=300&X-Amz-Signature=bf7d69ab8c74154ee8ac5682f64ce29a310ea2d0365f620893a45899d62a3f80&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
</table>
- Using Negative Prompt
```bash
# case 3
python inference.py \
--input=input/PromptControl/case3/ \
--seed=10041042941301238011 \
--save_dir=output/PromptControl/case3/
python inference.py \
--input=input/PromptControl/case3/ \
--seed=10041042941301238011 \
--nprompt='river flows' \
--save_dir=output/PromptControl/case3_nprompt/
# case4
python inference.py \
--input=input/PromptControl/case4/ \
--seed=10014024412012338096 \
--save_dir=output/PromptControl/case4/
python inference.py \
--input=input/PromptControl/case4/ \
--seed=10014024412012338096 \
--nprompt='noisy, wind noise' \
--save_dir=output/PromptControl/case4_nprompt/
```
Results:
<table class='center'>
<tr>
<td><p style="text-align: center">Generated Audio</p></td>
<td><p style="text-align: center">Generated Audio</p></td>
</tr>
<tr>
<td><p style="text-align: center">Without Prompt</p></td>
<td><p style="text-align: center">Negative Prompt: <b>river flows</b></p></td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311656-cdc69cf1-88f8-4861-b888-bdb82358b9c5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122517Z&X-Amz-Expires=300&X-Amz-Signature=1731e655b2f0bb7f4a7af737ee065a01f98fccb1c54ef48ae775ad65ec67eda5&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311702-cd259522-84f4-44cb-862f-c4dcfb57e5c4.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122527Z&X-Amz-Expires=300&X-Amz-Signature=d846ac60df25ff18de1861daeff380b7bc8ca21c04e7dce139ed44abbf9aaa22&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
<tr>
<td><p style="text-align: center">Without Prompt</p></td>
<td><p style="text-align: center">Negative Prompt: <b>noisy, wind noise</b></p></td>
</tr>
<tr>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311785-5ca9c050-a928-4dc2-b620-d843a3ae72f5.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122533Z&X-Amz-Expires=300&X-Amz-Signature=151fe4da521f9f48ff245ef5bd7c6964f1dfc652be0fe8de4c151ba59e87d2d6&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
<td>
https://github-production-user-asset-6210df.s3.amazonaws.com/134203169/342311844-28d6abe3-d5a8-4a7f-9f4d-3cc8411affba.mp4?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240624%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240624T122544Z&X-Amz-Expires=300&X-Amz-Signature=39a259ca76a12a57d47ad74c8cf92af12af9e3db34084f5f82c79b6a62356e9a&X-Amz-SignedHeaders=host&actor_id=134203169&key_id=0&repo_id=812946188
</td>
</tr>
</table>
### Commandline Usage Parameters
```console
options:
-h, --help show this help message and exit
--prompt PROMPT prompt for audio generation
--nprompt NPROMPT negative prompt for audio generation
--seed SEED random seed
--temporal_align TEMPORAL_ALIGN
use temporal adapter or not
--temporal_scale TEMPORAL_SCALE
temporal align scale
--semantic_scale SEMANTIC_SCALE
visual content scale
--input INPUT input video folder path
--ckpt CKPT checkpoints folder path
--save_dir SAVE_DIR generation result save path
--pretrain PRETRAIN generator checkpoint path
--device DEVICE
```
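The options above map onto a standard `argparse` parser; the following is a minimal sketch of such a CLI, not the actual parser in `inference.py` (defaults here are illustrative assumptions):

```python
from argparse import ArgumentParser


def build_parser():
    # Mirrors the help text above; default values are illustrative, not the real ones.
    parser = ArgumentParser(description="FoleyCrafter video-to-audio inference")
    parser.add_argument("--prompt", type=str, default="", help="prompt for audio generation")
    parser.add_argument("--nprompt", type=str, default="", help="negative prompt for audio generation")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    # Boolean flag: store_true lets `--temporal_align` work with no value,
    # matching the example commands in this README.
    parser.add_argument("--temporal_align", action="store_true", help="use temporal adapter or not")
    parser.add_argument("--temporal_scale", type=float, default=0.2, help="temporal align scale")
    parser.add_argument("--semantic_scale", type=float, default=1.0, help="visual content scale")
    parser.add_argument("--input", type=str, default="input/", help="input video folder path")
    parser.add_argument("--ckpt", type=str, default="checkpoints/", help="checkpoints folder path")
    parser.add_argument("--save_dir", type=str, default="output/", help="generation result save path")
    parser.add_argument("--pretrain", type=str, default=None, help="generator checkpoint path")
    parser.add_argument("--device", type=str, default="cuda")
    return parser
```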
## BibTeX
```
@misc{zhang2024foleycrafter,
      title={FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds},
      author={Yiming Zhang and Yicheng Gu and Yanhong Zeng and Zhening Xing and Yuancheng Wang and Zhizheng Wu and Kai Chen},
      year={2024},
      eprint={2407.01494},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
## Contact Us
**Yiming Zhang**: zhangyiming@pjlab.org.cn
**Yicheng Gu**: yichenggu@link.cuhk.edu.cn
**Yanhong Zeng**: zengyanhong@pjlab.org.cn
## LICENSE
Please see [LICENSE](./LICENSE) for details on the FoleyCrafter code.
If you are using it for commercial purposes, please also check the license of [Auffusion](https://github.com/happylittlecat2333/Auffusion).
## Acknowledgements
The code is built upon [Auffusion](https://github.com/happylittlecat2333/Auffusion), [CondFoleyGen](https://github.com/XYPB/CondFoleyGen) and [SpecVQGAN](https://github.com/v-iashin/SpecVQGAN).
We also recommend [Amphion](https://github.com/open-mmlab/Amphion), a toolkit for audio, music, and speech generation :gift_heart:.
================================================
FILE: app.py
================================================
import os
import os.path as osp
import random
from argparse import ArgumentParser
from datetime import datetime
import gradio as gr
import soundfile as sf
import torch
import torchvision
from huggingface_hub import snapshot_download
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from foleycrafter.models.onset import torch_utils
from foleycrafter.models.time_detector.model import VideoOnsetNet
from foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram
from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
sample_idx = 0
scheduler_dict = {
"DDIM": DDIMScheduler,
"Euler": EulerDiscreteScheduler,
"PNDM": PNDMScheduler,
}
css = """
.toolbutton {
    margin: 0em 0em 0em 0em;
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.5em;
}
"""
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="example/config/base.yaml")
parser.add_argument("--server-name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true")
parser.add_argument("--save-path", default="samples")
parser.add_argument("--ckpt", type=str, default="checkpoints/")
args = parser.parse_args()
N_PROMPT = ""
class FoleyController:
def __init__(self):
# config dirs
self.basedir = os.getcwd()
self.model_dir = os.path.join(self.basedir, args.ckpt)
self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
self.savedir_sample = os.path.join(self.savedir, "sample")
os.makedirs(self.savedir, exist_ok=True)
self.pipeline = None
self.loaded = False
self.load_model()
def load_model(self):
gr.Info("Start Load Models...")
print("Start Load Models...")
# download ckpt
pretrained_model_name_or_path = "auffusion/auffusion-full-no-adapter"
if not os.path.isdir(pretrained_model_name_or_path):
pretrained_model_name_or_path = snapshot_download(
pretrained_model_name_or_path, local_dir=osp.join(self.model_dir, "auffusion")
)
fc_ckpt = "ymzhang319/FoleyCrafter"
if not os.path.isdir(fc_ckpt):
fc_ckpt = snapshot_download(fc_ckpt, local_dir=self.model_dir)
# set model config
temporal_ckpt_path = osp.join(self.model_dir, "temporal_adapter.ckpt")
# load vocoder
vocoder_config_path = osp.join(self.model_dir, "auffusion")
self.vocoder = Generator.from_pretrained(vocoder_config_path, subfolder="vocoder")
# load time detector
time_detector_ckpt = osp.join(self.model_dir, "timestamp_detector.pth.tar")
time_detector = VideoOnsetNet(False)
self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)
self.pipeline = build_foleycrafter()
ckpt = torch.load(temporal_ckpt_path)
# load temporal adapter
if "state_dict" in ckpt.keys():
ckpt = ckpt["state_dict"]
load_gligen_ckpt = {}
for key, value in ckpt.items():
if key.startswith("module."):
load_gligen_ckpt[key[len("module.") :]] = value
else:
load_gligen_ckpt[key] = value
m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
self.image_processor = CLIPImageProcessor()
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter", subfolder="models/image_encoder"
)
self.pipeline.load_ip_adapter(
fc_ckpt, subfolder="semantic", weight_name="semantic_adapter.bin", image_encoder_folder=None
)
gr.Info("Load Finish!")
print("Load Finish!")
self.loaded = True
return "Load"
def foley(
self,
input_video,
prompt_textbox,
negative_prompt_textbox,
ip_adapter_scale,
temporal_scale,
sampler_dropdown,
sample_step_slider,
cfg_scale_slider,
seed_textbox,
):
device = "cuda"
# move to gpu
self.time_detector = self.time_detector.to(device)
self.pipeline = self.pipeline.to(device)
self.vocoder = self.vocoder.to(device)
self.image_encoder = self.image_encoder.to(device)
vision_transform_list = [
torchvision.transforms.Resize((128, 128)),
torchvision.transforms.CenterCrop((112, 112)),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
video_transform = torchvision.transforms.Compose(vision_transform_list)
# if not self.loaded:
# raise gr.Error("Error with loading model")
generator = torch.Generator()
if seed_textbox != "":
torch.manual_seed(int(seed_textbox))
generator.manual_seed(int(seed_textbox))
max_frame_nums = 150
frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
if duration >= 10:
duration = 10
time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(device)
time_frames = video_transform(time_frames)
time_frames = {"frames": time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
preds = self.time_detector(time_frames)
preds = torch.sigmoid(preds)
# map per-frame onset predictions onto the 1024-wide spectrogram time axis:
# +1 where the matching frame's onset probability >= 0.5, -1 otherwise,
# and -1 for columns beyond the clip duration (10 s max)
time_condition = [
-1 if preds[0][int(i / (1024 / 10 * duration) * max_frame_nums)] < 0.5 else 1
for i in range(int(1024 / 10 * duration))
]
time_condition = time_condition + [-1] * (1024 - len(time_condition))
# w -> b c h w
time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)
# Note that CLIP needs fewer frames, so subsample every 10th frame
frames = frames[::10]
images = self.image_processor(images=frames, return_tensors="pt").to(device)
image_embeddings = self.image_encoder(**images).image_embeds
image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
neg_image_embeddings = torch.zeros_like(image_embeddings)
image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)
self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
sample = self.pipeline(
prompt=prompt_textbox,
negative_prompt=negative_prompt_textbox,
ip_adapter_image_embeds=image_embeddings,
image=time_condition,
controlnet_conditioning_scale=float(temporal_scale),
num_inference_steps=sample_step_slider,
height=256,
width=1024,
output_type="pt",
generator=generator,
)
name = "output"
audio_img = sample.images[0]
audio = denormalize_spectrogram(audio_img)
audio = self.vocoder.inference(audio, lengths=160000)[0]
audio_save_path = osp.join(self.savedir_sample, "audio")
os.makedirs(audio_save_path, exist_ok=True)
audio = audio[: int(duration * 16000)]
save_path = osp.join(audio_save_path, f"{name}.wav")
sf.write(save_path, audio, 16000)
audio = AudioFileClip(osp.join(audio_save_path, f"{name}.wav"))
video = VideoFileClip(input_video)
audio = audio.subclip(0, duration)
video.audio = audio
video = video.subclip(0, duration)
video.write_videofile(osp.join(self.savedir_sample, f"{name}.mp4"))
save_sample_path = os.path.join(self.savedir_sample, f"{name}.mp4")
return save_sample_path
controller = FoleyController()
device = "cuda" if torch.cuda.is_available() else "cpu"
with gr.Blocks(css=css) as demo:
gr.HTML(
'<h1 style="height: 136px; display: flex; align-items: center; justify-content: space-around;"><span style="height: 100%; width:136px;"><img src="file/assets/foleycrafter.png" alt="logo" style="height: 100%; width:auto; object-fit: contain; margin: 0px 0px; padding: 0px 0px;"></span><strong style="font-size: 36px;">FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</strong></h1>'
)
gr.HTML(
'<p id="authors" style="text-align:center; font-size:24px;"> \
<a href="https://github.com/ymzhang0319">Yiming Zhang</a><sup>1</sup>,  \
<a href="https://github.com/VocodexElysium">Yicheng Gu</a><sup>2</sup>,  \
<a href="https://zengyh1900.github.io/">Yanhong Zeng</a><sup>1 †</sup>,  \
<a href="https://github.com/LeoXing1996/">Zhening Xing</a><sup>1</sup>,  \
<a href="https://github.com/HeCheng0625">Yuancheng Wang</a><sup>2</sup>,  \
<a href="https://drwuz.com/">Zhizheng Wu</a><sup>2</sup>,  \
<a href="https://chenkai.site/">Kai Chen</a><sup>1 †</sup>\
<br>\
<span>\
<sup>1</sup>Shanghai AI Laboratory \
<sup>2</sup>Chinese University of Hong Kong, Shenzhen \
†Corresponding author\
</span>\
</p>'
)
with gr.Row():
gr.Markdown(
"<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a>  " # noqa
"<a href='https://arxiv.org/abs/2407.01494/'>Paper</a>  "
"<a href='https://github.com/open-mmlab/foleycrafter'>Code</a>  "
"<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>"
)
with gr.Column(variant="panel"):
with gr.Row(equal_height=False):
with gr.Column():
with gr.Row():
init_img = gr.Video(label="Input Video")
with gr.Row():
prompt_textbox = gr.Textbox(value="", label="Prompt", lines=1)
with gr.Row():
negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
with gr.Row():
ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
temporal_scale = gr.Slider(label="Temporal Align Scale", value=0.2, minimum=0.0, maximum=1.0)
with gr.Accordion("Sampling Settings", open=False):
with gr.Row():
sampler_dropdown = gr.Dropdown(
label="Sampling method",
choices=list(scheduler_dict.keys()),
value=list(scheduler_dict.keys())[0],
)
sample_step_slider = gr.Slider(
label="Sampling steps", value=25, minimum=10, maximum=100, step=1
)
cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
with gr.Row():
seed_textbox = gr.Textbox(label="Seed", value=42)
seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
seed_button.click(fn=lambda: random.randint(1, 10**8), outputs=[seed_textbox], queue=False)
generate_button = gr.Button(value="Generate", variant="primary")
with gr.Column():
result_video = gr.Video(label="Generated Audio", interactive=False)
with gr.Row():
gr.Markdown(
"<div style='word-spacing: 6px;'><font size='5'><b>Tips</b>: <br> \
1. With strong temporal visual cues in input video, you can scale up the <b>Temporal Align Scale</b>. <br>\
2. <b>Visual content scale</b> is the level of semantic alignment with visual content.</font></div> \
"
)
generate_button.click(
fn=controller.foley,
inputs=[
init_img,
prompt_textbox,
negative_prompt_textbox,
ip_adapter_scale,
temporal_scale,
sampler_dropdown,
sample_step_slider,
cfg_scale_slider,
seed_textbox,
],
outputs=[result_video],
)
gr.Examples(
examples=[
["examples/gen3/case1.mp4", "", "", 1.0, 0.2, "DDIM", 25, 7.5, 33817921],
["examples/gen3/case3.mp4", "", "", 1.0, 0.2, "DDIM", 25, 7.5, 94667578],
["examples/gen3/case5.mp4", "", "", 0.75, 0.2, "DDIM", 25, 7.5, 92890876],
["examples/gen3/case6.mp4", "", "", 1.0, 0.2, "DDIM", 25, 7.5, 77015909],
],
inputs=[
init_img,
prompt_textbox,
negative_prompt_textbox,
ip_adapter_scale,
temporal_scale,
sampler_dropdown,
sample_step_slider,
cfg_scale_slider,
seed_textbox,
],
cache_examples=True,
outputs=[result_video],
fn=controller.foley,
)
demo.queue(10)
demo.launch(
server_name=args.server_name,
server_port=args.port,
share=args.share,
allowed_paths=["./assets/foleycrafter.png"],
)
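For reference, the onset-to-condition mapping inside `foley` above can be sketched in isolation. This is a minimal sketch, not part of the repo: the function name and the plain-list representation are illustrative, and it assumes one onset probability per sampled frame.

```python
def onset_to_time_condition(onset_probs, duration, max_frame_nums=150, width=1024):
    """Map per-frame onset probabilities onto a `width`-wide spectrogram axis.

    Columns covering the clip get +1 where the corresponding frame's onset
    probability is >= 0.5 and -1 otherwise; columns past the clip are -1.
    """
    # number of spectrogram columns covered by the clip (10 s fills all columns)
    active = int(width / 10 * duration)
    condition = [
        1 if onset_probs[int(i / active * max_frame_nums)] >= 0.5 else -1
        for i in range(active)
    ]
    # pad the remainder (silence past the clip) with -1
    return condition + [-1] * (width - len(condition))
```

For a 5-second clip only the first 512 of the 1024 columns are driven by the onset detector; the rest are forced to -1.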
================================================
FILE: foleycrafter/data/__init__.py
================================================
from .dataset import AudioSetStrong, CPU_Unpickler, VGGSound, dynamic_range_compression, get_mel, zero_rank_print
from .video_transforms import (
CenterCropVideo,
KineticsRandomCropResizeVideo,
NormalizeVideo,
RandomHorizontalFlipVideo,
TemporalRandomCrop,
ToTensorVideo,
UCFCenterCropVideo,
)
__all__ = [
"zero_rank_print",
"get_mel",
"dynamic_range_compression",
"CPU_Unpickler",
"AudioSetStrong",
"VGGSound",
"UCFCenterCropVideo",
"KineticsRandomCropResizeVideo",
"CenterCropVideo",
"NormalizeVideo",
"ToTensorVideo",
"RandomHorizontalFlipVideo",
"TemporalRandomCrop",
]
================================================
FILE: foleycrafter/data/dataset.py
================================================
import glob
import io
import pickle
import random
import numpy as np
import torch
import torch.distributed as dist
import torchaudio
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
def zero_rank_print(s):
if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
print("### " + s, flush=True)
@torch.no_grad()
def get_mel(audio_data, audio_cfg):
# mel shape: (n_mels, T)
mel = torchaudio.transforms.MelSpectrogram(
sample_rate=audio_cfg["sample_rate"],
n_fft=audio_cfg["window_size"],
win_length=audio_cfg["window_size"],
hop_length=audio_cfg["hop_size"],
center=True,
pad_mode="reflect",
power=2.0,
norm=None,
onesided=True,
n_mels=64,
f_min=audio_cfg["fmin"],
f_max=audio_cfg["fmax"],
).to(audio_data.device)
mel = mel(audio_data)
# we use log mel spectrogram as input
mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
return mel # (n_mels, T)
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return normalize_fun(torch.clamp(x, min=clip_val) * C)
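`dynamic_range_compression` above is just a clamped logarithm. A scalar sketch of the same formula with plain `math` (the function name here is illustrative, not part of the repo):

```python
import math


def dynamic_range_compression_scalar(x, C=1, clip_val=1e-5):
    # log(clamp(x, min=clip_val) * C) -- mirrors the torch version above,
    # clamping avoids log(0) for silent spectrogram bins
    return math.log(max(x, clip_val) * C)
```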
class CPU_Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == "torch.storage" and name == "_load_from_bytes":
return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
else:
return super().find_class(module, name)
class AudioSetStrong(Dataset):
# read feature and audio
def __init__(
self,
data_path="data/AudioSetStrong/train/feature",
video_path="data/AudioSetStrong/train/video",
):
super().__init__()
self.data_path = data_path
self.data_list = list(self.data_path)
self.length = len(self.data_list)
# get video feature
self.video_path = video_path
vision_transform_list = [
transforms.Resize((128, 128)),
transforms.CenterCrop((112, 112)),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
self.video_transform = transforms.Compose(vision_transform_list)
def get_batch(self, idx):
embeds = self.data_list[idx]
mel = embeds["mel"]
save_bsz = mel.shape[0]
audio_info = embeds["audio_info"]
text_embeds = embeds["text_embeds"]
# audio_info['label_list'] = np.array(audio_info['label_list'])
audio_info_array = np.array(audio_info["label_list"])
prompts = []
for i in range(save_bsz):
prompts.append(", ".join(audio_info_array[i, : audio_info["event_num"][i]].tolist()))
return mel, audio_info, text_embeds, prompts
def __len__(self):
return self.length
def __getitem__(self, idx):
while True:
try:
# get_batch returns four values; unpack accordingly
mel, audio_info, text_embeds, prompts = self.get_batch(idx)
break
except Exception:
zero_rank_print(" >>> load error <<<")
idx = random.randint(0, self.length - 1)
sample = {
"mel": mel,
"audio_info": audio_info,
"text_embeds": text_embeds,
"prompts": prompts,
}
return sample
class VGGSound(Dataset):
# read feature and audio
def __init__(
self,
data_path="data/VGGSound/train/video",
visual_data_path="data/VGGSound/train/feature",
):
super().__init__()
self.data_path = data_path
self.visual_data_path = visual_data_path
self.embeds_list = glob.glob(f"{self.data_path}/*.pt")
self.visual_list = glob.glob(f"{self.visual_data_path}/*.pt")
self.length = len(self.embeds_list)
def get_batch(self, idx):
embeds = torch.load(self.embeds_list[idx], map_location="cpu")
visual_embeds = torch.load(self.visual_list[idx], map_location="cpu")
# audio_embeds = embeds['audio_embeds']
visual_embeds = visual_embeds["visual_embeds"]
# video_name = embeds["video_name"]
text = embeds["text"]
mel = embeds["mel"]
audio = mel
return visual_embeds, audio, text
def __len__(self):
return self.length
def __getitem__(self, idx):
while True:
try:
visual_embeds, audio, text = self.get_batch(idx)
break
except Exception:
zero_rank_print("load error")
idx = random.randint(0, self.length - 1)
sample = {"visual_embeds": visual_embeds, "audio": audio, "text": text}
return sample
================================================
FILE: foleycrafter/data/video_transforms.py
================================================
import numbers
import random
import torch
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
_, _, H, W = clip.shape
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw)
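The corner arithmetic in `center_crop` above is easy to check on its own. A small sketch (the helper name is illustrative, not part of the repo):

```python
def center_crop_indices(h, w, th, tw):
    # upper-left corner (i, j) of a centered (th, tw) crop, as in center_crop
    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return i, j
```

For the 128x128 -> 112x112 crop used elsewhere in this repo, the crop starts at (8, 8).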
def random_shift_crop(clip):
"""
Slide along the long edge, with the short edge as crop size
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
# long_edge = w
short_edge = h
else:
# long_edge = h
short_edge = w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw)
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class UCFCenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
"""
Slide along the long edge, with the short edge as crop size, then resize to the desired size.
"""
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
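Whenever the clip is at least `size` frames long, `TemporalRandomCrop` returns a window of exactly `size` frames, since `begin_index <= total_frames - size - 1`; shorter clips yield a truncated window. A stand-alone sketch of that invariant, mirroring the arithmetic above (the function name is illustrative):

```python
import random


def temporal_random_crop(total_frames, size):
    # same arithmetic as TemporalRandomCrop.__call__ above
    rand_end = max(0, total_frames - size - 1)
    begin_index = random.randint(0, rand_end)
    end_index = min(begin_index + size, total_frames)
    return begin_index, end_index


random.seed(0)
for total in (32, 33, 100):
    begin, end = temporal_random_crop(total, 32)
    assert 0 <= begin and end <= total and end - begin == 32
```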
if __name__ == "__main__":
import os
import numpy as np
import torchvision.io as io
from torchvision import transforms
from torchvision.utils import save_image
vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW")
trans = transforms.Compose(
[
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
select_vframes = vframes[frame_indices]
select_vframes_trans = trans(select_vframes)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(
select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1)
)
================================================
FILE: foleycrafter/models/adapters/attention_processor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapter.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
The weight scale of the image prompt.
num_tokens (`int`, defaults to 4; use 16 for IP-Adapter Plus):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
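`IPAttnProcessor` above relies on the image-prompt tokens being concatenated after the text tokens, then split off at `end_pos`. That split can be illustrated on its own; this is a sketch with illustrative names, using plain lists in place of tensors:

```python
def split_text_and_image_tokens(tokens, num_tokens=4):
    # tokens: per-position embeddings, text tokens first, then `num_tokens`
    # image-prompt embeddings appended at the end (as IPAttnProcessor expects)
    end_pos = len(tokens) - num_tokens
    return tokens[:end_pos], tokens[end_pos:]
```

With a 77-token text prompt and 4 image tokens, the first 77 positions go through the regular `to_k`/`to_v` branch and the last 4 through `to_k_ip`/`to_v_ip`.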
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0WithProjection(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.before_proj_size = 1024
self.after_proj_size = 768
self.visual_proj = nn.Linear(self.before_proj_size, self.after_proj_size)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
# encoder_hidden_states = self.visual_proj(encoder_hidden_states)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
# parentheses are required here: softmax must apply to the attention scores, not to the transposed key
self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
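The decoupled cross-attention above can be sketched in isolation (all sizes below are hypothetical, chosen only for illustration): the last `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt tokens, attended to separately, and merged back with a tunable `scale` weight.

```python
import torch
import torch.nn.functional as F

# hypothetical sizes for illustration only
batch, q_len, text_len, num_tokens = 2, 10, 77, 4
heads, head_dim = 8, 8
dim = heads * head_dim

encoder_hidden_states = torch.randn(batch, text_len + num_tokens, dim)
query = torch.randn(batch, heads, q_len, head_dim)

# split off the trailing image-prompt tokens, as in __call__ above
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_states = encoder_hidden_states[:, :end_pos, :]
ip_states = encoder_hidden_states[:, end_pos:, :]

def to_heads(t):
    # (batch, seq, dim) -> (batch, heads, seq, head_dim)
    return t.view(t.shape[0], -1, heads, head_dim).transpose(1, 2)

# text branch and image branch run through SDPA independently...
text_out = F.scaled_dot_product_attention(query, to_heads(text_states), to_heads(text_states))
ip_out = F.scaled_dot_product_attention(query, to_heads(ip_states), to_heads(ip_states))

# ...then the image branch is merged in with a tunable weight
scale = 1.0
merged = text_out + scale * ip_out
```

In the real processor the image branch uses its own `to_k_ip`/`to_v_ip` projections; the sketch skips those to keep the splitting and merging visible.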
## for controlnet
class CNAttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __init__(self, num_tokens=4):
self.num_tokens = num_tokens
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CNAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self, num_tokens=4):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.num_tokens = num_tokens
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
================================================
FILE: foleycrafter/models/adapters/ip_adapter.py
================================================
import torch
import torch.nn as nn
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
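A minimal standalone sketch of what `ImageProjModel` computes with its default sizes (batch size and values are hypothetical): a single pooled image embedding is expanded into `clip_extra_context_tokens` context tokens that can be concatenated onto the text token sequence.

```python
import torch
from torch import nn

cross_attention_dim, clip_embeddings_dim, n_tokens = 1024, 1024, 4
proj = nn.Linear(clip_embeddings_dim, n_tokens * cross_attention_dim)
norm = nn.LayerNorm(cross_attention_dim)

image_embeds = torch.randn(2, clip_embeddings_dim)  # one pooled CLIP embedding per sample
# project, then reshape the flat output into n_tokens separate context tokens
tokens = norm(proj(image_embeds).reshape(-1, n_tokens, cross_attention_dim))
```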
class MLPProjModel(torch.nn.Module):
"""SD model with image prompt"""
def zero_initialize(self):
for param in self.parameters():
param.data.zero_()
def zero_initialize_last_layer(self):
last_layer = None
for module_name, layer in self.named_modules():
if isinstance(layer, torch.nn.Linear):
last_layer = layer
if last_layer is not None:
last_layer.weight.data.zero_()
last_layer.bias.data.zero_()
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim),
)
# zero initialize the last layer
# self.zero_initialize_last_layer()
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class V2AMapperMLP(torch.nn.Module):
def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim),
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class TimeProjModel(torch.nn.Module):
def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums: int = 64):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.position_dim = frame_nums
if isinstance(out_dim, tuple):
out_dim = out_dim[0]
if feature_type == "text-only":
self.linears = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
elif feature_type == "text-image":
self.linears_text = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.linears_image = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
# self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
def forward(
self,
boxes,
masks,
positive_embeddings=None,
):
masks = masks.unsqueeze(-1)
# # embedding position (it may includes padding as placeholder)
# xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
# # learnable null embedding
# xyxy_null = self.null_position_feature.view(1, 1, -1)
# # replace padding with learnable null embedding
# xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
time_embeds = boxes
# positionet with text only information
if positive_embeddings is not None:
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
# positionet with text and image information
else:
raise NotImplementedError
return objs
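The masked-replacement step in `TimeProjModel.forward` swaps padded slots for the learnable null embedding via `positive_embeddings * masks + (1 - masks) * positive_null`. A tiny worked example (sizes and values hypothetical; a zero tensor stands in for the learned null parameter):

```python
import torch

positive_len = 4
positive_embeddings = torch.ones(1, 3, positive_len)    # batch of 1, 3 slots
masks = torch.tensor([[1.0, 1.0, 0.0]]).unsqueeze(-1)   # third slot is padding
null_feature = torch.zeros(positive_len)                # stands in for null_positive_feature
positive_null = null_feature.view(1, 1, -1)

# padded slots are replaced by the (learnable) null embedding; real slots pass through
mixed = positive_embeddings * masks + (1 - masks) * positive_null
```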
================================================
FILE: foleycrafter/models/adapters/resampler.py
================================================
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# reshape is shape-preserving here: still (bs, n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
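The split scaling in `PerceiverAttention` (multiplying both `q` and `k` by `dim_head**-0.25` before the matmul, rather than dividing the product by `sqrt(dim_head)`) is numerically equivalent but keeps intermediate magnitudes smaller, which is what makes it more stable in float16. A quick check of the equivalence in float32:

```python
import math
import torch

d = 64
q = torch.randn(2, 8, 16, d)
k = torch.randn(2, 8, 16, d)

scale = 1 / math.sqrt(math.sqrt(d))  # d ** -0.25
pre_scaled = (q * scale) @ (k * scale).transpose(-2, -1)  # scale before the matmul
post_scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d)    # textbook scaling

equivalent = torch.allclose(pre_scaled, post_scaled, atol=1e-5)
```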
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq is not None:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
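`masked_mean` averages only over the unmasked positions. A tiny worked example (values hypothetical; the helper is restated here, with `unsqueeze(-1)` in place of the einops rearrange, so the sketch runs standalone):

```python
import torch

def masked_mean(t, *, dim, mask=None):
    # same logic as the helper above; "b n -> b n 1" == mask.unsqueeze(-1)
    if mask is None:
        return t.mean(dim=dim)
    denom = mask.sum(dim=dim, keepdim=True)
    masked_t = t.masked_fill(~mask.unsqueeze(-1), 0.0)
    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

t = torch.tensor([[[1.0], [3.0], [100.0]]])   # (b=1, n=3, d=1)
mask = torch.tensor([[True, True, False]])    # last position is padding
result = masked_mean(t, dim=1, mask=mask)     # the padded 100.0 is ignored
```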
================================================
FILE: foleycrafter/models/adapters/transformer.py
================================================
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.utils.checkpoint
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = attention_head_dim
self.scale = self.head_dim**-0.5
self.dropout = attention_dropout
self.inner_dim = self.head_dim * self.num_heads
self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights has to be reshaped
# twice and reused in the following computation
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class MLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mult=4):
super().__init__()
self.activation_fn = nn.SiLU()
self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)
self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Transformer(nn.Module):
def __init__(self, depth=12):
super().__init__()
self.layers = nn.ModuleList([TransformerBlock() for _ in range(depth)])
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None,
causal_attention_mask: torch.Tensor = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
for layer in self.layers:
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
return hidden_states
class TransformerBlock(nn.Module):
def __init__(
self,
hidden_size=512,
num_attention_heads=12,
attention_head_dim=64,
attention_dropout=0.0,
dropout=0.0,
eps=1e-5,
):
super().__init__()
self.embed_dim = hidden_size
self.self_attn = Attention(
hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None,
causal_attention_mask: torch.Tensor = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs[0]
class DiffusionTransformerBlock(nn.Module):
def __init__(
self,
hidden_size=512,
num_attention_heads=12,
attention_head_dim=64,
attention_dropout=0.0,
dropout=0.0,
eps=1e-5,
):
super().__init__()
self.embed_dim = hidden_size
self.self_attn = Attention(
hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
self.output_token = nn.Parameter(torch.randn(1, hidden_size))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None,
causal_attention_mask: torch.Tensor = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
output_token = self.output_token.unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)
hidden_states = torch.cat([output_token, hidden_states], dim=1)
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs[0][:, 0:1, ...]
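`DiffusionTransformerBlock` prepends a learnable output token and returns only that token's final hidden state, a CLS-style readout. A minimal, self-contained sketch of the pattern (shapes and names are illustrative, assuming only `torch`):

```python
import torch

batch, seq_len, hidden = 2, 16, 8

# learnable readout token shared across the batch (illustrative)
output_token = torch.nn.Parameter(torch.randn(1, hidden))
hidden_states = torch.randn(batch, seq_len, hidden)

# prepend one copy of the token per batch element
token = output_token.unsqueeze(0).repeat(batch, 1, 1)     # (batch, 1, hidden)
hidden_states = torch.cat([token, hidden_states], dim=1)  # (batch, seq_len + 1, hidden)

# ... self-attention + MLP would mix the token with the sequence here ...

# keep only the readout token's state, as in `return outputs[0][:, 0:1, ...]`
readout = hidden_states[:, 0:1, ...]                      # (batch, 1, hidden)
```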
class V2AMapperMLP(nn.Module):
def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
super().__init__()
self.linear = nn.Linear(input_dim, input_dim * expansion_rate)
self.silu = nn.SiLU()
self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)
self.linear2 = nn.Linear(input_dim * expansion_rate, output_dim)
def forward(self, x):
x = self.linear(x)
x = self.silu(x)
x = self.layer_norm(x)
x = self.linear2(x)
return x
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
self.zero_initialize_last_layer()
    def zero_initialize_last_layer(self):
        # zero-initialize the final linear layer so the projected tokens start as zeros
        last_layer = None
        for _name, layer in self.named_modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer
        if last_layer is not None:
            last_layer.weight.data.zero_()
            last_layer.bias.data.zero_()
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class VisionAudioAdapter(torch.nn.Module):
def __init__(
self,
embedding_size=768,
expand_dim=4,
token_num=4,
):
super().__init__()
self.mapper = V2AMapperMLP(
embedding_size,
embedding_size,
expansion_rate=expand_dim,
)
self.proj = ImageProjModel(
cross_attention_dim=embedding_size,
clip_embeddings_dim=embedding_size,
clip_extra_context_tokens=token_num,
)
def forward(self, image_embeds):
image_embeds = self.mapper(image_embeds)
image_embeds = self.proj(image_embeds)
return image_embeds
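`ImageProjModel` expands one global CLIP embedding into `clip_extra_context_tokens` cross-attention tokens with a single linear layer plus a reshape. A standalone sketch of that expansion (dimensions are illustrative, assuming only `torch`):

```python
import torch
import torch.nn as nn

clip_dim, cross_dim, num_tokens = 768, 768, 4

proj = nn.Linear(clip_dim, num_tokens * cross_dim)
norm = nn.LayerNorm(cross_dim)

image_embeds = torch.randn(2, clip_dim)                         # one embedding per image
tokens = proj(image_embeds).reshape(-1, num_tokens, cross_dim)  # expand to token sequence
tokens = norm(tokens)                                           # (batch, num_tokens, cross_dim)
```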
================================================
FILE: foleycrafter/models/adapters/utils.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
attn_maps = {}
def hook_fn(name):
def forward_hook(module, input, output):
if hasattr(module.processor, "attn_map"):
attn_maps[name] = module.processor.attn_map
del module.processor.attn_map
return forward_hook
def register_cross_attention_hook(unet):
for name, module in unet.named_modules():
if name.split(".")[-1].startswith("attn2"):
module.register_forward_hook(hook_fn(name))
return unet
def upscale(attn_map, target_size):
attn_map = torch.mean(attn_map, dim=0)
attn_map = attn_map.permute(1, 0)
temp_size = None
for i in range(0, 5):
scale = 2**i
if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
break
    assert temp_size is not None, "temp_size must not be None"
attn_map = attn_map.view(attn_map.shape[0], *temp_size)
attn_map = F.interpolate(
attn_map.unsqueeze(0).to(dtype=torch.float32), size=target_size, mode="bilinear", align_corners=False
)[0]
attn_map = torch.softmax(attn_map, dim=0)
return attn_map
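The loop in `upscale` searches for the power-of-two scale at which the flattened attention map was produced: the latent grid is `8 * scale` times smaller than the image, so `(H // scale) * (W // scale)` must equal `tokens * 64`. A pure-Python sketch of that search (the function name is ours):

```python
def find_temp_size(target_size, num_tokens):
    # replicate the scale search in `upscale`: try scales 1, 2, 4, 8, 16
    for i in range(5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == num_tokens * 64:
            # the latent resolution is a further 8x below the matched scale
            return (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
    return None
```

For a 512x512 target, a 4096-token map matches at scale 1 and maps to a 64x64 latent grid, while a 1024-token map matches at scale 2 and yields 32x32.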
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
idx = 0 if instance_or_negative else 1
net_attn_maps = []
for name, attn_map in attn_maps.items():
attn_map = attn_map.cpu() if detach else attn_map
attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
attn_map = upscale(attn_map, image_size)
net_attn_maps.append(attn_map)
net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
return net_attn_maps
def attnmaps2images(net_attn_maps):
# total_attn_scores = 0
images = []
for attn_map in net_attn_maps:
attn_map = attn_map.cpu().numpy()
# total_attn_scores += attn_map.mean().item()
normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
normalized_attn_map = normalized_attn_map.astype(np.uint8)
# print("norm: ", normalized_attn_map.shape)
image = Image.fromarray(normalized_attn_map)
# image = fix_save_attn_map(attn_map)
images.append(image)
# print(total_attn_scores)
return images
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
================================================
FILE: foleycrafter/models/auffusion/attention.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
from diffusers.utils import USE_PEFT_BACKEND
from diffusers.utils.torch_utils import maybe_allow_in_graph
from foleycrafter.models.auffusion.attention_processor import Attention
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
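Because the feed-forward acts position-wise, `_chunked_feed_forward` can split the sequence axis, run the FF per chunk, and concatenate without changing the result, trading peak memory for sequential compute. A self-contained check of that equivalence (the small `ff` is a stand-in, assuming only `torch`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ff = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))  # stand-in FF
hidden_states = torch.randn(2, 12, 8)  # (batch, seq, dim); seq divisible by chunk_size

chunk_dim, chunk_size = 1, 4
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
chunked = torch.cat(
    [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)

full = ff(hidden_states)  # unchunked reference: identical up to float tolerance
```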
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
A gated self-attention dense layer that combines visual features and object features.
Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()
        # a linear projection is needed since we concatenate the visual and object features
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
                "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if (only_cross_attention and not double_self_attention) else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if self.use_ada_layer_norm:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if self.use_ada_layer_norm_continuous:
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
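`BasicTransformerBlock` is a pre-LN block: each sub-layer normalizes first, computes, then adds the residual. A minimal standalone sketch of the same skeleton, using `torch.nn.MultiheadAttention` in place of the diffusers `Attention` class (all names and sizes are illustrative):

```python
import torch
import torch.nn as nn

class PreLNBlock(nn.Module):
    """Pre-LN transformer block: norm -> self-attn -> residual, norm -> FF -> residual."""

    def __init__(self, dim=8, heads=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        h = self.norm1(x)                 # normalize before attention
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out                  # residual
        x = x + self.ff(self.norm2(x))    # normalize before FF, then residual
        return x

out = PreLNBlock()(torch.randn(2, 16, 8))  # (batch, seq, dim) in and out
```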
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
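The reshapes at the start and end of `TemporalBasicTransformerBlock.forward` move the frame axis into the attention dimension: `(batch * frames, seq, ch)` becomes `(batch * seq, frames, ch)`, so attention mixes across time at each spatial position, after which the layout is restored. A standalone roundtrip of just the reshape (shapes are illustrative):

```python
import torch

batch, frames, seq, ch = 2, 4, 6, 3
x = torch.randn(batch * frames, seq, ch)

# fold frames out of the batch axis and swap with the sequence axis
t = x.reshape(batch, frames, seq, ch).permute(0, 2, 1, 3).reshape(batch * seq, frames, ch)

# ... temporal attention over dim=1 (frames) would run here ...

# invert the permutation to recover the original (batch * frames, seq, ch) layout
y = t.reshape(batch, seq, frames, ch).permute(0, 2, 1, 3).reshape(batch * frames, seq, ch)
```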
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
)
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
for module in self.net:
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
return hidden_states
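`FeedForward`'s default `"geglu"` activation projects to twice the inner width, splits the result, and gates one half with GELU of the other. A standalone sketch of that gating (the class name is ours, mirroring the diffusers GEGLU behavior):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLUSketch(nn.Module):
    # geglu: project to 2 * inner_dim, split, gate one half with gelu(other)
    def __init__(self, dim, inner_dim):
        super().__init__()
        self.proj = nn.Linear(dim, inner_dim * 2)

    def forward(self, x):
        hidden, gate = self.proj(x).chunk(2, dim=-1)
        return hidden * F.gelu(gate)

# full FF as in `FeedForward`: gated projection in, dropout, linear projection out
ff = nn.Sequential(GEGLUSketch(8, 32), nn.Dropout(0.0), nn.Linear(32, 8))
out = ff(torch.randn(2, 5, 8))
```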
================================================
FILE: foleycrafter/models/auffusion/attention_processor.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from importlib import import_module
from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
@maybe_allow_in_graph
class Attention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`):
The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
upcast_attention (`bool`, *optional*, defaults to False):
Set to `True` to upcast the attention computation to `float32`.
upcast_softmax (`bool`, *optional*, defaults to False):
Set to `True` to upcast the softmax computation to `float32`.
cross_attention_norm (`str`, *optional*, defaults to `None`):
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the group norm in the cross attention.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
norm_num_groups (`int`, *optional*, defaults to `None`):
The number of groups to use for the group norm in the attention.
spatial_norm_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the spatial normalization.
out_bias (`bool`, *optional*, defaults to `True`):
Set to `True` to use a bias in the output linear layer.
scale_qk (`bool`, *optional*, defaults to `True`):
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
only_cross_attention (`bool`, *optional*, defaults to `False`):
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
`added_kv_proj_dim` is not `None`.
eps (`float`, *optional*, defaults to 1e-5):
An additional value added to the denominator in group normalization that is used for numerical stability.
rescale_output_factor (`float`, *optional*, defaults to 1.0):
A factor to rescale the output by dividing it with this value.
residual_connection (`bool`, *optional*, defaults to `False`):
Set to `True` to add the residual connection to the output.
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
Set to `True` if the attention block is loaded from a deprecated state dict.
processor (`AttnProcessor`, *optional*, defaults to `None`):
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
`AttnProcessor` otherwise.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
# with a deprecated state dict so that we can convert it on the fly
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
else:
self.group_norm = None
if spatial_norm_dim is not None:
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
else:
self.spatial_norm = None
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
elif cross_attention_norm == "group_norm":
if self.added_kv_proj_dim is not None:
# The given `encoder_hidden_states` are initially of shape
# (batch_size, seq_len, added_kv_proj_dim) before being projected
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
# before the projection, so we need to use `added_kv_proj_dim` as
# the number of channels for the group norm.
norm_cross_num_channels = added_kv_proj_dim
else:
norm_cross_num_channels = self.cross_attention_dim
self.norm_cross = nn.GroupNorm(
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
)
else:
raise ValueError(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
r"""
Set whether to use memory efficient attention from `xformers` or not.
Args:
use_memory_efficient_attention_xformers (`bool`):
Whether to use memory efficient attention from `xformers` or not.
attention_op (`Callable`, *optional*):
The attention operation to use. Defaults to `None` which uses the default attention operation from
`xformers`.
"""
is_lora = hasattr(self, "processor") and isinstance(
self.processor,
LORA_ATTENTION_PROCESSORS,
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor,
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
)
is_added_kv_processor = hasattr(self, "processor") and isinstance(
self.processor,
(
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
),
)
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
)
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception:
# re-raise so the caller sees why the xformers sanity check failed
raise
if is_lora:
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
processor = LoRAXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
processor = CustomDiffusionXFormersAttnProcessor(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
elif is_added_kv_processor:
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
# throw warning
logger.info(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
if is_lora:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
processor = attn_processor_class(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
attn_processor_class = (
CustomDiffusionAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else CustomDiffusionAttnProcessor
)
processor = attn_processor_class(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessor2_0()
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
else AttnProcessor()
)
self.set_processor(processor)
def set_attention_slice(self, slice_size: int) -> None:
r"""
Set the slice size for attention computation.
Args:
slice_size (`int`):
The slice size for attention computation.
"""
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
if slice_size is not None and self.added_kv_proj_dim is not None:
processor = SlicedAttnAddedKVProcessor(slice_size)
elif slice_size is not None:
processor = SlicedAttnProcessor(slice_size)
elif self.added_kv_proj_dim is not None:
processor = AttnAddedKVProcessor()
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
r"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
_remove_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to remove LoRA layers from the model.
"""
if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
)
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
# We need to remove all LoRA layers
# Don't forget to remove ALL `_remove_lora` from the codebase
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are overwriting possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
r"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
# serialization format for LoRA Attention Processors. It should be deleted once the integration
# with PEFT is completed.
is_lora_activated = {
name: module.lora_layer is not None
for name, module in self.named_modules()
if hasattr(module, "lora_layer")
}
# 1. if no layer has a LoRA activated we can return the processor as usual
if not any(is_lora_activated.values()):
return self.processor
# LoRA may legitimately be absent from `add_k_proj` and `add_v_proj`, so exclude them from the check
is_lora_activated.pop("add_k_proj", None)
is_lora_activated.pop("add_v_proj", None)
# 2. else it is not possible that only some layers have LoRA activated
if not all(is_lora_activated.values()):
raise ValueError(
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
)
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
non_lora_processor_cls_name = self.processor.__class__.__name__
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
hidden_size = self.inner_dim
# now create a LoRA attention processor from the LoRA layers
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
kwargs = {
"cross_attention_dim": self.cross_attention_dim,
"rank": self.to_q.lora_layer.rank,
"network_alpha": self.to_q.lora_layer.network_alpha,
"q_rank": self.to_q.lora_layer.rank,
"q_hidden_size": self.to_q.lora_layer.out_features,
"k_rank": self.to_k.lora_layer.rank,
"k_hidden_size": self.to_k.lora_layer.out_features,
"v_rank": self.to_v.lora_layer.rank,
"v_hidden_size": self.to_v.lora_layer.out_features,
"out_rank": self.to_out[0].lora_layer.rank,
"out_hidden_size": self.to_out[0].lora_layer.out_features,
}
if hasattr(self.processor, "attention_op"):
kwargs["attention_op"] = self.processor.attention_op
lora_processor = lora_processor_cls(hidden_size, **kwargs)
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
lora_processor = lora_processor_cls(
hidden_size,
cross_attention_dim=self.add_k_proj.weight.shape[0],
rank=self.to_q.lora_layer.rank,
network_alpha=self.to_q.lora_layer.network_alpha,
)
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
# only save if used
if self.add_k_proj.lora_layer is not None:
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
else:
lora_processor.add_k_proj_lora = None
lora_processor.add_v_proj_lora = None
else:
raise ValueError(f"{lora_processor_cls} does not exist.")
return lora_processor
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
r"""
The forward method of the `Attention` class.
Args:
hidden_states (`torch.Tensor`):
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*):
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.
Returns:
`torch.Tensor`: The output of the attention layer.
"""
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
r"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
is the number of heads initialized while constructing the `Attention` class.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
r"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, heads, seq_len, dim // heads]`, where
`heads` is the number of heads initialized while constructing the `Attention` class.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
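As a sanity check of the two reshape helpers above, here is a minimal standalone sketch (plain PyTorch, with a hypothetical `heads=8` and toy shapes) showing that `head_to_batch_dim` with `out_dim=3` and `batch_to_head_dim` are inverses of each other:

```python
import torch

def head_to_batch_dim(tensor, heads):
    # [B, S, D] -> [B * heads, S, D // heads], mirroring Attention.head_to_batch_dim (out_dim=3)
    b, s, d = tensor.shape
    tensor = tensor.reshape(b, s, heads, d // heads).permute(0, 2, 1, 3)
    return tensor.reshape(b * heads, s, d // heads)

def batch_to_head_dim(tensor, heads):
    # inverse: [B * heads, S, D // heads] -> [B, S, D], mirroring Attention.batch_to_head_dim
    bh, s, dh = tensor.shape
    tensor = tensor.reshape(bh // heads, heads, s, dh).permute(0, 2, 1, 3)
    return tensor.reshape(bh // heads, s, dh * heads)

x = torch.randn(2, 5, 64)
y = head_to_batch_dim(x, heads=8)
assert y.shape == (16, 5, 8)
# the round trip recovers the original tensor exactly
assert torch.equal(batch_to_head_dim(y, heads=8), x)
```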
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
) -> torch.Tensor:
r"""
Compute the attention scores.
Args:
query (`torch.Tensor`): The query tensor.
key (`torch.Tensor`): The key tensor.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
Returns:
`torch.Tensor`: The attention probabilities/scores.
"""
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
del baddbmm_input
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
attention_probs = attention_probs.to(dtype)
return attention_probs
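The `baddbmm` call above folds the optional additive mask and the `1 / sqrt(dim_head)` scale into a single kernel. A small sketch with assumed toy shapes (and no mask, i.e. `beta=0`, where `baddbmm` ignores its uninitialized `input`) showing it matches a plain scaled matmul followed by softmax:

```python
import torch

B, S, D = 2, 4, 8
scale = D ** -0.5
q = torch.randn(B, S, D)
k = torch.randn(B, S, D)

# beta=0: the (uninitialized) input tensor is ignored, only alpha * (q @ k^T) remains
empty = torch.empty(B, S, S)
scores = torch.baddbmm(empty, q, k.transpose(-1, -2), beta=0, alpha=scale)
probs = scores.softmax(dim=-1)

# reference: explicit scaled matmul + softmax
ref = (q @ k.transpose(-1, -2) * scale).softmax(dim=-1)
assert torch.allclose(probs, ref, atol=1e-6)
```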
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
r"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`):
The attention mask to prepare.
target_length (`int`):
The target length of the attention mask. This is the length of the attention mask after padding.
batch_size (`int`):
The batch size, which is used to repeat the attention mask.
out_dim (`int`, *optional*, defaults to `3`):
The output dimension of the attention mask. Can be either `3` or `4`.
Returns:
`torch.Tensor`: The prepared attention mask.
"""
head_size = self.heads
if attention_mask is None:
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
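A condensed sketch of the mask preparation above, with hypothetical shapes. Note this sketch pads by the *remaining* length, as the TODO in the source suggests, rather than reproducing the current `(0, target_length)` over-padding:

```python
import torch
import torch.nn.functional as F

def prepare_mask(mask, target_length, batch_size, heads):
    # pad the last dimension up to target_length with zeros (additive mask: 0 = attend)
    if mask.shape[-1] != target_length:
        mask = F.pad(mask, (0, target_length - mask.shape[-1]), value=0.0)
    # repeat the mask once per attention head along the batch axis (out_dim=3 layout)
    if mask.shape[0] < batch_size * heads:
        mask = mask.repeat_interleave(heads, dim=0)
    return mask

m = torch.zeros(2, 1, 5)
out = prepare_mask(m, target_length=7, batch_size=2, heads=4)
assert out.shape == (8, 1, 7)
```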
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
`Attention` class.
Args:
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
Returns:
`torch.Tensor`: The normalized encoder hidden states.
"""
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False
return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
self.fused_projections = fuse
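`fuse_projections` relies on the fact that concatenating projection weights along the output dimension and running one matmul is equivalent to running the projections separately. A minimal sketch with hypothetical weight shapes:

```python
import torch

wq, wk, wv = (torch.randn(64, 32) for _ in range(3))  # (out_features, in_features)
x = torch.randn(2, 5, 32)

# separate projections
q, k, v = x @ wq.T, x @ wk.T, x @ wv.T

# fused: concatenate weights along the output dim, one matmul, then split the result
w_qkv = torch.cat([wq, wk, wv])          # (192, 32)
q2, k2, v2 = (x @ w_qkv.T).chunk(3, dim=-1)

assert torch.allclose(q, q2) and torch.allclose(k, k2) and torch.allclose(v, v2)
```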
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CustomDiffusionAttnProcessor(nn.Module):
r"""
Processor for implementing attention for the Custom Diffusion method.
Args:
train_kv (`bool`, defaults to `True`):
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`):
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`):
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
"""
def __init__(
self,
train_kv: bool = True,
train_q_out: bool = True,
hidden_size: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
dropout: float = 0.0,
):
super().__init__()
self.train_kv = train_kv
self.train_q_out = train_q_out
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
# `_custom_diffusion` id for easy serialization and loading.
if self.train_kv:
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
if self.train_q_out:
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
self.to_out_custom_diffusion = nn.ModuleList([])
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else:
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None:
crossattn = False
encoder_hidden_states = hidden_states
else:
crossattn = True
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if crossattn:
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
if self.train_q_out:
# linear proj
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class AttnAddedKVProcessor:
r"""
Processor for performing attention-related computations with extra learnable key and value matrices for the text
encoder.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class AttnAddedKVProcessor2_0:
r"""
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
learnable key and value matrices for the text encoder.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnAddedKVProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
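`AttnAddedKVProcessor2_0` delegates the actual attention computation to `F.scaled_dot_product_attention`. A minimal sketch with assumed toy shapes, checking that the fused call matches a manual `softmax(QK^T / sqrt(d)) V`, which is what the pre-2.0 processors compute step by step:

```python
import torch
import torch.nn.functional as F

# Toy shapes (assumptions for illustration, not taken from this file).
batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Fused kernel, as used by the 2_0 processors above.
out_sdpa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Manual reference: softmax(QK^T / sqrt(head_dim)) V.
scores = q @ k.transpose(-1, -2) / head_dim**0.5
out_manual = scores.softmax(dim=-1) @ v

assert torch.allclose(out_sdpa, out_manual, atol=1e-5)
```

Both paths produce `(batch, num_heads, seq_len, head_dim)`, which the processor then transposes and reshapes back to `(batch, seq_len, inner_dim)`.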
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
operator.
"""
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
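Note that `XFormersAttnAddedKVProcessor` calls `head_to_batch_dim` with the default `out_dim=3`, folding the heads into the batch dimension; that is why the extra key/value projections are concatenated along `dim=1` (the sequence axis) here, versus `dim=2` in the `out_dim=4` processor above. A sketch with assumed toy sizes, mirroring what `head_to_batch_dim`/`batch_to_head_dim` do:

```python
import torch

# Toy sizes (assumptions for illustration, not taken from this file).
batch, seq, heads, head_dim = 2, 10, 8, 64
x = torch.randn(batch, seq, heads * head_dim)

# head_to_batch_dim with out_dim=3: [batch, seq, inner_dim] -> [batch*heads, seq, head_dim].
h = x.reshape(batch, seq, heads, head_dim).permute(0, 2, 1, 3).reshape(batch * heads, seq, head_dim)
assert h.shape == (batch * heads, seq, head_dim)

# batch_to_head_dim: the inverse, back to [batch, seq, inner_dim].
y = h.reshape(batch, heads, seq, head_dim).permute(0, 2, 1, 3).reshape(batch, seq, heads * head_dim)
assert torch.equal(y, x)
```

With heads in the batch axis, the sequence dimension is `dim=1`, so concatenating the added key/value projections along `dim=1` prepends the text-encoder tokens to the self-attention tokens.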
class XFormersAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
operator.
"""
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
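The mask expansion in `XFormersAttnProcessor` exists because xFormers does not broadcast the singleton query dimension of the attention bias. A minimal sketch with assumed toy sizes:

```python
import torch

# Toy sizes (assumptions for illustration, not taken from this file).
# prepare_attention_mask yields [batch*heads, 1, key_tokens]; xformers expects
# the bias to already cover every query token, so the singleton dimension is
# expanded to [batch*heads, query_tokens, key_tokens].
mask = torch.zeros(8, 1, 77)
query_tokens = 64
expanded = mask.expand(-1, query_tokens, -1)
assert expanded.shape == (8, 64, 77)

# expand() returns a broadcast view: no memory is copied, the storage is shared.
assert expanded.data_ptr() == mask.data_ptr()
```

Because `expand` only adjusts strides, this per-call expansion is essentially free even for large query lengths.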
class AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
method __init__ (line 2981) | def __init__(
method forward (line 3079) | def forward(
class KUpBlock2D (line 3146) | class KUpBlock2D(nn.Module):
method __init__ (line 3147) | def __init__(
method forward (line 3196) | def forward(
class KCrossAttnUpBlock2D (line 3235) | class KCrossAttnUpBlock2D(nn.Module):
method __init__ (line 3236) | def __init__(
method forward (line 3321) | def forward(
class KAttentionBlock (line 3383) | class KAttentionBlock(nn.Module):
method __init__ (line 3407) | def __init__(
method _to_3d (line 3450) | def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight...
method _to_4d (line 3453) | def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight...
method forward (line 3456) | def forward(
FILE: foleycrafter/models/auffusion_unet.py
class UNet2DConditionOutput (line 64) | class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel (line 76) | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
method __init__ (line 173) | def __init__(
method load_attention (line 639) | def load_attention(self):
method get_writer_feature (line 652) | def get_writer_feature(self):
method clear_writer_feature (line 655) | def clear_writer_feature(self):
method disable_feature_adapters (line 658) | def disable_feature_adapters(self):
method set_reader_feature (line 661) | def set_reader_feature(self, features: list):
method attn_processors (line 665) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
method set_attn_processor (line 688) | def set_attn_processor(
method set_default_attn_processor (line 724) | def set_default_attn_processor(self):
method set_attention_slice (line 739) | def set_attention_slice(self, slice_size):
method _set_gradient_checkpointing (line 804) | def _set_gradient_checkpointing(self, module, value=False):
method enable_freeu (line 808) | def enable_freeu(self, s1, s2, b1, b2):
method disable_freeu (line 832) | def disable_freeu(self):
method fuse_qkv_projections (line 840) | def fuse_qkv_projections(self):
method unfuse_qkv_projections (line 863) | def unfuse_qkv_projections(self):
method forward (line 876) | def forward(
FILE: foleycrafter/models/onset/r2plus1d_18.py
class r2plus1d18KeepTemp (line 9) | class r2plus1d18KeepTemp(nn.Module):
method __init__ (line 10) | def __init__(self, pretrained=True):
method forward (line 39) | def forward(self, x):
FILE: foleycrafter/models/onset/resnet.py
class Conv3DSimple (line 15) | class Conv3DSimple(nn.Conv3d):
method __init__ (line 16) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
method get_downsample_stride (line 27) | def get_downsample_stride(stride):
class Conv2Plus1D (line 31) | class Conv2Plus1D(nn.Sequential):
method __init__ (line 32) | def __init__(self, in_planes, out_planes, midplanes, stride=1, padding...
method get_downsample_stride (line 55) | def get_downsample_stride(stride):
class Conv3DNoTemporal (line 59) | class Conv3DNoTemporal(nn.Conv3d):
method __init__ (line 60) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
method get_downsample_stride (line 71) | def get_downsample_stride(stride):
class BasicBlock (line 75) | class BasicBlock(nn.Module):
method __init__ (line 78) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
method forward (line 90) | def forward(self, x):
class Bottleneck (line 104) | class Bottleneck(nn.Module):
method __init__ (line 107) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
method forward (line 129) | def forward(self, x):
class BasicStem (line 145) | class BasicStem(nn.Sequential):
method __init__ (line 148) | def __init__(self):
class R2Plus1dStem (line 156) | class R2Plus1dStem(nn.Sequential):
method __init__ (line 159) | def __init__(self):
class VideoResNet (line 170) | class VideoResNet(nn.Module):
method __init__ (line 171) | def __init__(self, block, conv_makers, layers, stem, num_classes=400, ...
method forward (line 202) | def forward(self, x):
method _make_layer (line 221) | def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
method _initialize_weights (line 239) | def _initialize_weights(self):
function _video_resnet (line 253) | def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
function r3d_18 (line 262) | def r3d_18(pretrained=False, progress=True, **kwargs):
function mc3_18 (line 284) | def mc3_18(pretrained=False, progress=True, **kwargs):
function r2plus1d_18 (line 305) | def r2plus1d_18(pretrained=False, progress=True, **kwargs):
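The `Conv2Plus1D` block indexed above implements the (2+1)D factorization from the R(2+1)D video architecture: a full 3D convolution is split into a 2D spatial convolution followed by a 1D temporal convolution, with `midplanes` setting the intermediate channel width. A minimal sketch of that idea (names and defaults here are illustrative assumptions, not the repository's exact code):

```python
import torch
import torch.nn as nn


def conv2plus1d(in_planes: int, out_planes: int, midplanes: int, stride: int = 1):
    """Factorized (2+1)D convolution: spatial 2D conv, then temporal 1D conv."""
    return nn.Sequential(
        # spatial 2D conv over (H, W); identity over the time axis
        nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                  stride=(1, stride, stride), padding=(0, 1, 1), bias=False),
        nn.BatchNorm3d(midplanes),
        nn.ReLU(inplace=True),
        # temporal 1D conv over T; identity over the spatial axes
        nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
                  stride=(stride, 1, 1), padding=(1, 0, 0), bias=False),
    )


block = conv2plus1d(3, 16, midplanes=8)
x = torch.randn(2, 3, 8, 32, 32)  # (batch, channels, T, H, W)
print(block(x).shape)             # torch.Size([2, 16, 8, 32, 32])
```

With stride 1 and "same" padding on both sub-convolutions, the temporal and spatial extents are preserved, matching how the downsample-stride helpers above keep the two paths in sync.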
FILE: foleycrafter/models/onset/torch_utils.py
function load_model (line 14) | def load_model(cp_path, net, device=None, strict=True):
function binary_acc (line 42) | def binary_acc(pred, target, threshold):
function calc_acc (line 48) | def calc_acc(prob, labels, k):
function get_dataloader (line 57) | def get_dataloader(args, pr, split="train", shuffle=False, drop_last=Fal...
function make_optimizer (line 81) | def make_optimizer(model, args):
function adjust_learning_rate (line 105) | def adjust_learning_rate(optimizer, epoch, args):
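Judging from its signature, `binary_acc(pred, target, threshold)` above plausibly thresholds predicted probabilities and reports the fraction that matches the binary onset labels. A pure-Python stand-in for that behavior (the real helper operates on tensors, so this is an assumption about its semantics, not a copy):

```python
def binary_acc(pred, target, threshold):
    """Fraction of predictions whose thresholded value matches the binary target."""
    hits = sum(int((p > threshold) == bool(t)) for p, t in zip(pred, target))
    return hits / len(target)


print(binary_acc([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 1], threshold=0.5))  # 0.5
```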
FILE: foleycrafter/models/onset/video_onset_net.py
class VideoOnsetNet (line 9) | class VideoOnsetNet(nn.Module):
method __init__ (line 11) | def __init__(self, pretrained):
method forward (line 16) | def forward(self, inputs, loss=False, evaluate=False):
FILE: foleycrafter/models/time_detector/model.py
class TimeDetector (line 6) | class TimeDetector(nn.Module):
method __init__ (line 7) | def __init__(self, video_length=150, audio_length=1024):
method forward (line 13) | def forward(self, inputs):
FILE: foleycrafter/models/time_detector/resnet.py
class Conv3DSimple (line 14) | class Conv3DSimple(nn.Conv3d):
method __init__ (line 15) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
method get_downsample_stride (line 26) | def get_downsample_stride(stride):
class Conv2Plus1D (line 30) | class Conv2Plus1D(nn.Sequential):
method __init__ (line 31) | def __init__(self, in_planes, out_planes, midplanes, stride=1, padding...
method get_downsample_stride (line 54) | def get_downsample_stride(stride):
class Conv3DNoTemporal (line 58) | class Conv3DNoTemporal(nn.Conv3d):
method __init__ (line 59) | def __init__(self, in_planes, out_planes, midplanes=None, stride=1, pa...
method get_downsample_stride (line 70) | def get_downsample_stride(stride):
class BasicBlock (line 74) | class BasicBlock(nn.Module):
method __init__ (line 77) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
method forward (line 89) | def forward(self, x):
class Bottleneck (line 103) | class Bottleneck(nn.Module):
method __init__ (line 106) | def __init__(self, inplanes, planes, conv_builder, stride=1, downsampl...
method forward (line 128) | def forward(self, x):
class BasicStem (line 144) | class BasicStem(nn.Sequential):
method __init__ (line 147) | def __init__(self):
class R2Plus1dStem (line 155) | class R2Plus1dStem(nn.Sequential):
method __init__ (line 158) | def __init__(self):
class VideoResNet (line 169) | class VideoResNet(nn.Module):
method __init__ (line 170) | def __init__(self, block, conv_makers, layers, stem, num_classes=400, ...
method forward (line 201) | def forward(self, x):
method _make_layer (line 220) | def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
method _initialize_weights (line 238) | def _initialize_weights(self):
function _video_resnet (line 252) | def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
function r3d_18 (line 261) | def r3d_18(pretrained=False, progress=True, **kwargs):
function mc3_18 (line 283) | def mc3_18(pretrained=False, progress=True, **kwargs):
function r2plus1d_18 (line 304) | def r2plus1d_18(pretrained=False, progress=True, **kwargs):
FILE: foleycrafter/pipelines/auffusion_pipeline.py
function json_dump (line 66) | def json_dump(data_json, json_save_path):
function json_load (line 72) | def json_load(json_path):
function import_model_class_from_model_name_or_path (line 79) | def import_model_class_from_model_name_or_path(pretrained_model_name_or_...
class ConditionAdapter (line 99) | class ConditionAdapter(nn.Module):
method __init__ (line 100) | def __init__(self, config):
method forward (line 107) | def forward(self, x):
method from_pretrained (line 113) | def from_pretrained(cls, pretrained_model_name_or_path):
method save_pretrained (line 122) | def save_pretrained(self, pretrained_model_name_or_path):
function rescale_noise_cfg (line 131) | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class AttrDict (line 149) | class AttrDict(dict):
method __init__ (line 150) | def __init__(self, *args, **kwargs):
function get_config (line 155) | def get_config(config_path):
function init_weights (line 161) | def init_weights(m, mean=0.0, std=0.01):
function apply_weight_norm (line 167) | def apply_weight_norm(m):
function get_padding (line 173) | def get_padding(kernel_size, dilation=1):
class ResBlock1 (line 177) | class ResBlock1(torch.nn.Module):
method __init__ (line 178) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
method forward (line 232) | def forward(self, x):
method remove_weight_norm (line 241) | def remove_weight_norm(self):
class ResBlock2 (line 248) | class ResBlock2(torch.nn.Module):
method __init__ (line 249) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
method forward (line 278) | def forward(self, x):
method remove_weight_norm (line 285) | def remove_weight_norm(self):
class Generator (line 290) | class Generator(torch.nn.Module):
method __init__ (line 291) | def __init__(self, h):
method device (line 347) | def device(self) -> torch.device:
method dtype (line 351) | def dtype(self):
method forward (line 354) | def forward(self, x):
method remove_weight_norm (line 372) | def remove_weight_norm(self):
method from_pretrained (line 382) | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
method inference (line 398) | def inference(self, mels, lengths=None):
function normalize_spectrogram (line 411) | def normalize_spectrogram(
function denormalize_spectrogram (line 432) | def denormalize_spectrogram(
function pt_to_numpy (line 457) | def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
function numpy_to_pil (line 466) | def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
function image_add_color (line 482) | def image_add_color(spec_img):
class PipelineOutput (line 492) | class PipelineOutput(BaseOutput):
class AuffusionPipeline (line 506) | class AuffusionPipeline(DiffusionPipeline):
method __init__ (line 551) | def __init__(
method from_pretrained (line 588) | def from_pretrained(
method to (line 642) | def to(self, device, dtype=None):
method enable_vae_slicing (line 656) | def enable_vae_slicing(self):
method disable_vae_slicing (line 665) | def disable_vae_slicing(self):
method enable_vae_tiling (line 672) | def enable_vae_tiling(self):
method disable_vae_tiling (line 681) | def disable_vae_tiling(self):
method enable_sequential_cpu_offload (line 688) | def enable_sequential_cpu_offload(self, gpu_id=0):
method enable_model_cpu_offload (line 713) | def enable_model_cpu_offload(self, gpu_id=0):
method _execution_device (line 742) | def _execution_device(self):
method _encode_prompt (line 759) | def _encode_prompt(
method run_safety_checker (line 862) | def run_safety_checker(self, image, device, dtype):
method decode_latents (line 876) | def decode_latents(self, latents):
method prepare_extra_step_kwargs (line 889) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 906) | def check_inputs(
method prepare_latents (line 953) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
method __call__ (line 971) | def __call__(
function retrieve_timesteps (line 1120) | def retrieve_timesteps(
class AuffusionNoAdapterPipeline (line 1164) | class AuffusionNoAdapterPipeline(
method __init__ (line 1205) | def __init__(
method enable_vae_slicing (line 1297) | def enable_vae_slicing(self):
method disable_vae_slicing (line 1304) | def disable_vae_slicing(self):
method enable_vae_tiling (line 1311) | def enable_vae_tiling(self):
method disable_vae_tiling (line 1319) | def disable_vae_tiling(self):
method _encode_prompt (line 1326) | def _encode_prompt(
method encode_prompt (line 1358) | def encode_prompt(
method prepare_ip_adapter_image_embeds (line 1529) | def prepare_ip_adapter_image_embeds(
method encode_image (line 1580) | def encode_image(self, image, device, num_images_per_prompt, output_hi...
method run_safety_checker (line 1604) | def run_safety_checker(self, image, device, dtype):
method decode_latents (line 1618) | def decode_latents(self, latents):
method prepare_extra_step_kwargs (line 1629) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 1646) | def check_inputs(
method prepare_latents (line 1698) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
method enable_freeu (line 1715) | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
method disable_freeu (line 1737) | def disable_freeu(self):
method fuse_qkv_projections (line 1742) | def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
method unfuse_qkv_projections (line 1774) | def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
method get_guidance_scale_embedding (line 1803) | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=tor...
method guidance_scale (line 1832) | def guidance_scale(self):
method guidance_rescale (line 1836) | def guidance_rescale(self):
method clip_skip (line 1840) | def clip_skip(self):
method do_classifier_free_guidance (line 1847) | def do_classifier_free_guidance(self):
method cross_attention_kwargs (line 1851) | def cross_attention_kwargs(self):
method num_timesteps (line 1855) | def num_timesteps(self):
method interrupt (line 1859) | def interrupt(self):
method __call__ (line 1863) | def __call__(
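The `normalize_spectrogram`/`denormalize_spectrogram` pair indexed above maps log-magnitude spectrograms into the [0, 1] image range the diffusion UNet consumes, and back again after sampling. A self-consistent sketch of that mapping, where the dB bounds `MIN_DB`/`MAX_DB` are illustrative assumptions rather than the pipeline's actual constants:

```python
import torch

MIN_DB, MAX_DB = -80.0, 20.0  # assumed dynamic-range bounds, for illustration


def normalize_spectrogram(spec_db: torch.Tensor) -> torch.Tensor:
    """Clamp a log-magnitude spectrogram to the dB range and rescale to [0, 1]."""
    return (spec_db.clamp(MIN_DB, MAX_DB) - MIN_DB) / (MAX_DB - MIN_DB)


def denormalize_spectrogram(img: torch.Tensor) -> torch.Tensor:
    """Invert the normalization: map [0, 1] image values back to dB."""
    return img * (MAX_DB - MIN_DB) + MIN_DB


spec = torch.linspace(MIN_DB, MAX_DB, steps=8)
assert torch.allclose(denormalize_spectrogram(normalize_spectrogram(spec)), spec)
```

The round-trip property is what matters in the pipeline: values the UNet generates in image space must land back on a valid dB scale before the vocoder turns them into a waveform.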
FILE: foleycrafter/pipelines/pipeline_controlnet.py
function retrieve_timesteps (line 97) | def retrieve_timesteps(
class StableDiffusionControlNetPipeline (line 141) | class StableDiffusionControlNetPipeline(
method __init__ (line 186) | def __init__(
method enable_vae_slicing (line 239) | def enable_vae_slicing(self):
method disable_vae_slicing (line 247) | def disable_vae_slicing(self):
method enable_vae_tiling (line 255) | def enable_vae_tiling(self):
method disable_vae_tiling (line 264) | def disable_vae_tiling(self):
method _encode_prompt (line 272) | def _encode_prompt(
method encode_prompt (line 305) | def encode_prompt(
method prepare_ip_adapter_image_embeds (line 486) | def prepare_ip_adapter_image_embeds(
method encode_image (line 538) | def encode_image(self, image, device, num_images_per_prompt, output_hi...
method run_safety_checker (line 563) | def run_safety_checker(self, image, device, dtype):
method decode_latents (line 578) | def decode_latents(self, latents):
method prepare_extra_step_kwargs (line 590) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 607) | def check_inputs(
method check_image (line 753) | def check_image(self, image, prompt, prompt_embeds):
method prepare_image (line 790) | def prepare_image(
method prepare_latents (line 821) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
method enable_freeu (line 839) | def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
method disable_freeu (line 862) | def disable_freeu(self):
method get_guidance_scale_embedding (line 867) | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=tor...
method guidance_scale (line 896) | def guidance_scale(self):
method clip_skip (line 900) | def clip_skip(self):
method do_classifier_free_guidance (line 907) | def do_classifier_free_guidance(self):
method cross_attention_kwargs (line 911) | def cross_attention_kwargs(self):
method num_timesteps (line 915) | def num_timesteps(self):
method __call__ (line 920) | def __call__(
FILE: foleycrafter/utils/converter.py
function load_wav (line 23) | def load_wav(full_path):
function dynamic_range_compression (line 28) | def dynamic_range_compression(x, C=1, clip_val=1e-5):
function dynamic_range_decompression (line 32) | def dynamic_range_decompression(x, C=1):
function dynamic_range_compression_torch (line 36) | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
function dynamic_range_decompression_torch (line 40) | def dynamic_range_decompression_torch(x, C=1):
function spectral_normalize_torch (line 44) | def spectral_normalize_torch(magnitudes):
function spectral_de_normalize_torch (line 49) | def spectral_de_normalize_torch(magnitudes):
function mel_spectrogram (line 58) | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_siz...
function spectrogram (line 97) | def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, f...
function normalize_spectrogram (line 130) | def normalize_spectrogram(
function denormalize_spectrogram (line 161) | def denormalize_spectrogram(
function get_mel_spectrogram_from_audio (line 195) | def get_mel_spectrogram_from_audio(audio, device="cpu"):
class AttrDict (line 222) | class AttrDict(dict):
method __init__ (line 223) | def __init__(self, *args, **kwargs):
function get_config (line 228) | def get_config(config_path):
function init_weights (line 234) | def init_weights(m, mean=0.0, std=0.01):
function apply_weight_norm (line 240) | def apply_weight_norm(m):
function get_padding (line 246) | def get_padding(kernel_size, dilation=1):
class ResBlock1 (line 250) | class ResBlock1(torch.nn.Module):
method __init__ (line 251) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
method forward (line 305) | def forward(self, x):
method remove_weight_norm (line 314) | def remove_weight_norm(self):
class ResBlock2 (line 321) | class ResBlock2(torch.nn.Module):
method __init__ (line 322) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
method forward (line 351) | def forward(self, x):
method remove_weight_norm (line 358) | def remove_weight_norm(self):
class Generator (line 363) | class Generator(torch.nn.Module):
method __init__ (line 364) | def __init__(self, h):
method forward (line 416) | def forward(self, x):
method remove_weight_norm (line 434) | def remove_weight_norm(self):
method from_pretrained (line 443) | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
method inference (line 459) | def inference(self, mels, lengths=None):
function normalize (line 472) | def normalize(images):
function pad_spec (line 482) | def pad_spec(spec, spec_length, pad_value=0, random_crop=True): # spec:...
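Two of the small vocoder helpers indexed above have standard definitions worth spelling out: `get_padding` is the usual "same" padding for a dilated 1D convolution, and the dynamic-range compression is log-compression with a clip floor, as in HiFi-GAN-style vocoders. A hedged pure-Python sketch (the repository's versions operate on tensors; exact defaults are assumptions):

```python
import math


def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Padding that keeps output length equal to input length for a dilated conv."""
    return (kernel_size * dilation - dilation) // 2


def dynamic_range_compression(x: float, C: float = 1.0, clip_val: float = 1e-5) -> float:
    """Log-compress a magnitude value, flooring it at clip_val to avoid log(0)."""
    return math.log(max(x, clip_val) * C)


print(get_padding(3, dilation=1))                 # 1
print(get_padding(7, dilation=3))                 # 9
print(round(dynamic_range_compression(1.0), 4))   # 0.0
```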
FILE: foleycrafter/utils/spec_to_mel.py
class STFT (line 14) | class STFT(torch.nn.Module):
method __init__ (line 17) | def __init__(self, filter_length, hop_length, win_length, window="hann"):
method transform (line 47) | def transform(self, input_data):
method inverse (line 78) | def inverse(self, magnitude, phase):
method forward (line 111) | def forward(self, input_data):
function window_sumsquare (line 117) | def window_sumsquare(
function griffin_lim (line 176) | def griffin_lim(magnitudes, stft_fn, n_iters=30):
function dynamic_range_compression (line 195) | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=...
function dynamic_range_decompression (line 204) | def dynamic_range_decompression(x, C=1):
class TacotronSTFT (line 213) | class TacotronSTFT(torch.nn.Module):
method __init__ (line 214) | def __init__(
method spectral_normalize (line 232) | def spectral_normalize(self, magnitudes, normalize_fun):
method spectral_de_normalize (line 236) | def spectral_de_normalize(self, magnitudes):
method mel_spectrogram (line 240) | def mel_spectrogram(self, y, normalize_fun=torch.log):
function pad_wav (line 264) | def pad_wav(waveform, segment_length):
function normalize_wav (line 277) | def normalize_wav(waveform):
function _pad_spec (line 283) | def _pad_spec(fbank, target_length=1024):
function get_mel_from_wav (line 299) | def get_mel_from_wav(audio, _stft):
function read_wav_file_io (line 309) | def read_wav_file_io(bytes):
function load_audio (line 323) | def load_audio(bytes, sample_rate=16000):
function read_wav_file (line 329) | def read_wav_file(filename):
function norm_wav_tensor (line 343) | def norm_wav_tensor(waveform: torch.FloatTensor):
function wav_to_fbank (line 352) | def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
function wav_tensor_to_fbank (line 380) | def wav_tensor_to_fbank(waveform, target_length=512, fn_STFT=None):
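The `_pad_spec(fbank, target_length=1024)` helper indexed above fits a (time, mel) filterbank to a fixed number of frames so batches have uniform shape. A list-based sketch of the likely behavior (zero-pad when short, truncate when long); the tensor details of the real function are not reproduced here:

```python
def pad_spec(fbank, target_length=1024):
    """Zero-pad or truncate a (time, mel) filterbank to exactly target_length frames."""
    n_frames = len(fbank)
    if n_frames < target_length:
        n_mels = len(fbank[0])
        fbank = fbank + [[0.0] * n_mels for _ in range(target_length - n_frames)]
    return fbank[:target_length]


short = [[1.0, 2.0]] * 3
padded = pad_spec(short, target_length=5)
print(len(padded))   # 5
print(padded[-1])    # [0.0, 0.0]
```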
FILE: foleycrafter/utils/util.py
function zero_rank_print (line 35) | def zero_rank_print(s):
function build_foleycrafter (line 40) | def build_foleycrafter(
function save_videos_grid (line 66) | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_r...
function save_videos_from_pil_list (line 82) | def save_videos_from_pil_list(videos: list, path: str, fps=7):
function seed_everything (line 89) | def seed_everything(seed: int) -> None:
function get_video_frames (line 102) | def get_video_frames(video: np.ndarray, num_frames: int = 200):
function random_audio_video_clip (line 109) | def random_audio_video_clip(
function get_full_indices (line 142) | def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReade...
function get_frames (line 149) | def get_frames(video_path: str, onset_list, frame_nums=1024):
function get_frames_in_video (line 163) | def get_frames_in_video(video_path: str, onset_list, frame_nums=1024, au...
function save_multimodal (line 183) | def save_multimodal(video, audio, output_path, audio_fps: int = 16000, v...
function save_multimodal_by_frame (line 203) | def save_multimodal_by_frame(video, audio, output_path, audio_fps: int =...
function sanity_check (line 221) | def sanity_check(data: dict, save_path: str = "sanity_check", batch_size...
function video_tensor_to_np (line 242) | def video_tensor_to_np(video: torch.Tensor, rescale: bool = True, scale:...
function composite_audio_video (line 255) | def composite_audio_video(video: str, audio: str, path: str, video_fps: ...
function append_dims (line 265) | def append_dims(x, target_dims):
function resize_with_antialiasing (line 273) | def resize_with_antialiasing(input, size, interpolation="bicubic", align...
function _gaussian_blur2d (line 302) | def _gaussian_blur2d(input, kernel_size, sigma):
function _filter2d (line 318) | def _filter2d(input, kernel):
function _gaussian (line 341) | def _gaussian(window_size: int, sigma):
function _compute_padding (line 357) | def _compute_padding(kernel_size):
function print_gpu_memory_usage (line 380) | def print_gpu_memory_usage(info: str, cuda_id: int = 0):
class SpectrogramParams (line 392) | class SpectrogramParams:
class ExifTags (line 427) | class ExifTags(Enum):
method n_fft (line 446) | def n_fft(self) -> int:
method win_length (line 453) | def win_length(self) -> int:
method hop_length (line 460) | def hop_length(self) -> int:
method to_exif (line 466) | def to_exif(self) -> T.Dict[int, T.Any]:
class SpectrogramImageConverter (line 483) | class SpectrogramImageConverter:
method __init__ (line 491) | def __init__(self, params: SpectrogramParams, device: str = "cuda"):
method spectrogram_image_from_audio (line 496) | def spectrogram_image_from_audio(
method audio_from_spectrogram_image (line 538) | def audio_from_spectrogram_image(
function image_from_spectrogram (line 567) | def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25)...
function spectrogram_from_image (line 613) | def spectrogram_from_image(
class SpectrogramConverter (line 667) | class SpectrogramConverter:
method __init__ (line 689) | def __init__(self, params: SpectrogramParams, device: str = "cuda"):
method spectrogram_from_audio (line 756) | def spectrogram_from_audio(
method audio_from_spectrogram (line 782) | def audio_from_spectrogram(
method mel_amplitudes_from_waveform (line 820) | def mel_amplitudes_from_waveform(
method waveform_from_mel_amplitudes (line 842) | def waveform_from_mel_amplitudes(
function check_device (line 862) | def check_device(device: str, backup: str = "cpu") -> str:
function audio_from_waveform (line 876) | def audio_from_waveform(samples: np.ndarray, sample_rate: int, normalize...
function apply_filters_func (line 900) | def apply_filters_func(segment: pydub.AudioSegment, compression: bool = ...
function shave_segments (line 936) | def shave_segments(path, n_shave_prefix_segments=1):
function renew_resnet_paths (line 946) | def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
function renew_vae_resnet_paths (line 968) | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
function renew_attention_paths (line 984) | def renew_attention_paths(old_list, n_shave_prefix_segments=0):
function renew_vae_attention_paths (line 1005) | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
function assign_to_checkpoint (line 1034) | def assign_to_checkpoint(
function conv_attn_to_linear (line 1089) | def conv_attn_to_linear(checkpoint):
function create_unet_diffusers_config (line 1101) | def create_unet_diffusers_config(original_config, image_size: int, contr...
function create_vae_diffusers_config (line 1170) | def create_vae_diffusers_config(original_config, image_size: int):
function create_diffusers_schedular (line 1194) | def create_diffusers_schedular(original_config):
function convert_ldm_unet_checkpoint (line 1204) | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_e...
function convert_ldm_vae_checkpoint (line 1435) | def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, o...
function convert_ldm_clip_checkpoint (line 1550) | def convert_ldm_clip_checkpoint(checkpoint):
function convert_lora_model_level (line 1561) | def convert_lora_model_level(
function denormalize_spectrogram (line 1643) | def denormalize_spectrogram(
class ToTensor1D (line 1677) | class ToTensor1D(torchvision.transforms.ToTensor):
method __call__ (line 1678) | def __call__(self, tensor: np.ndarray):
function scale (line 1684) | def scale(old_value, old_min, old_max, new_min, new_max):
function read_frames_with_moviepy (line 1692) | def read_frames_with_moviepy(video_path, max_frame_nums=None):
function read_frames_with_moviepy_resample (line 1703) | def read_frames_with_moviepy_resample(video_path, save_path):
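The `seed_everything(seed)` utility indexed above follows a common pattern: seed every RNG source the codebase touches so runs are reproducible. A plausible minimal shape (the real function may additionally set CUDA seeds or cuDNN determinism flags):

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


seed_everything(42)
a = torch.randn(3)
seed_everything(42)
b = torch.randn(3)
print(torch.equal(a, b))  # True
```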
FILE: inference.py
function args_parse (line 28) | def args_parse():
function build_models (line 49) | def build_models(config):
function run_inference (line 97) | def run_inference(config, pipe, vocoder, time_detector):
},
{
"path": "foleycrafter/models/onset/video_onset_net.py",
"chars": 883,
"preview": "# Copied from specvqgan/onset_baseline/models/video_onset_net.py\n\nimport torch\nimport torch.nn as nn\n\nfrom .r2plus1d_18 "
},
{
"path": "foleycrafter/models/time_detector/model.py",
"chars": 491,
"preview": "import torch.nn as nn\n\nfrom ..onset import VideoOnsetNet\n\n\nclass TimeDetector(nn.Module):\n def __init__(self, video_l"
},
{
"path": "foleycrafter/models/time_detector/resnet.py",
"chars": 10443,
"preview": "import torch.nn as nn\nfrom torch.hub import load_state_dict_from_url\n\n\n__all__ = [\"r3d_18\", \"mc3_18\", \"r2plus1d_18\"]\n\nmo"
},
{
"path": "foleycrafter/pipelines/auffusion_pipeline.py",
"chars": 97874,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "foleycrafter/pipelines/pipeline_controlnet.py",
"chars": 68360,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "foleycrafter/utils/converter.py",
"chars": 14768,
"preview": "# Copy from https://github.com/happylittlecat2333/Auffusion/blob/main/converter.py\nimport json\nimport os\nimport random\n\n"
},
{
"path": "foleycrafter/utils/spec_to_mel.py",
"chars": 13197,
"preview": "import librosa.util as librosa_util\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchaudio\nfr"
},
{
"path": "foleycrafter/utils/util.py",
"chars": 65887,
"preview": "import glob\nimport io\nimport os\nimport os.path as osp\nimport random\nimport typing as T\nimport warnings\nfrom dataclasses "
},
{
"path": "inference.py",
"chars": 7717,
"preview": "import argparse\nimport glob\nimport os\nimport os.path as osp\nfrom pathlib import Path\n\nimport soundfile as sf\nimport torc"
},
{
"path": "pyproject.toml",
"chars": 749,
"preview": "[tool.ruff]\n# Never enforce `E501` (line length violations).\nignore = [\"C901\", \"E501\", \"E741\", \"F402\", \"F823\"]\nselect = "
},
{
"path": "requirements/environment.yaml",
"chars": 444,
"preview": "name: foleycrafter\nchannels:\n - pytorch\n - nvidia\ndependencies:\n - python=3.10\n - pytorch=2.2.0\n - torchvision=0.17"
}
]
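The file list above is a JSON array of manifest entries, each carrying a `path` and a `chars` count. As a minimal sketch of how such a manifest can be consumed programmatically — the entries below are a small excerpt copied from the listing, and the variable names (`manifest`, `entry`) are illustrative, not part of any GitExtract API:

```python
import json

# Excerpt of the GitExtract-style manifest shown above
# (preview fields omitted; only path and chars are needed here).
manifest_json = """
[
  {"path": "foleycrafter/models/onset/__init__.py", "chars": 139},
  {"path": "inference.py", "chars": 7717},
  {"path": "pyproject.toml", "chars": 749}
]
"""

manifest = json.loads(manifest_json)

# Total character count across the listed files.
total_chars = sum(entry["chars"] for entry in manifest)

# Largest file in this excerpt of the manifest.
largest = max(manifest, key=lambda entry: entry["chars"])

print(total_chars)      # 8605
print(largest["path"])  # inference.py
```

Summing `chars` over the full 37-entry manifest is how the 869.2 KB total reported for this extraction would be reproduced.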
About this extraction
This page contains the full source code of the open-mmlab/FoleyCrafter GitHub repository, extracted and formatted as plain text. The extraction includes 37 files (869.2 KB), approximately 201.2k tokens, and a symbol index with 700 extracted functions, classes, methods, constants, and types. Extracted by GitExtract.